diff --git a/.agents/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md index 0ed18d71d1..98a94592ab 100644 --- a/.agents/skills/component-refactoring/SKILL.md +++ b/.agents/skills/component-refactoring/SKILL.md @@ -367,7 +367,7 @@ For each extraction: ┌────────────────────────────────────────┐ │ 1. Extract code │ │ 2. Run: pnpm lint:fix │ - │ 3. Run: pnpm type-check:tsgo │ + │ 3. Run: pnpm type-check │ │ 4. Run: pnpm test │ │ 5. Test functionality manually │ │ 6. PASS? → Next extraction │ diff --git a/.agents/skills/e2e-cucumber-playwright/SKILL.md b/.agents/skills/e2e-cucumber-playwright/SKILL.md new file mode 100644 index 0000000000..de6b58f26d --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/SKILL.md @@ -0,0 +1,79 @@ +--- +name: e2e-cucumber-playwright +description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository. +--- + +# Dify E2E Cucumber + Playwright + +Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite. + +## Scope + +- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`. +- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead. +- Do not use this skill for backend test or API review tasks under `api/`. + +## Read Order + +1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first. +2. Read only the files directly involved in the task: + - target `.feature` files under `e2e/features/` + - related step files under `e2e/features/step-definitions/` + - `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters + - `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter +3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved. +4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved. +5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern. + +## Local Rules + +- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer. +- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions. +- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps. +- Browser session behavior comes from `features/support/hooks.ts`: + - default: authenticated session with shared storage state + - `@unauthenticated`: clean browser context + - `@authenticated`: readability/selective-run tag only unless implementation changes + - `@fresh`: only for `e2e:full*` flows +- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture. + +## Workflow + +1. Rebuild local context. 
+ - Inspect the target feature area. + - Reuse an existing step when wording and behavior already match. + - Add a new step only for a genuinely new user action or assertion. + - Keep edits close to the current capability folder unless the step is broadly reusable. +2. Write behavior-first scenarios. + - Describe user-observable behavior, not DOM mechanics. + - Keep each scenario focused on one workflow or outcome. + - Keep scenarios independent and re-runnable. +3. Write step definitions in the local style. + - Keep one step to one user-visible action or one assertion. + - Prefer Cucumber Expressions such as `{string}` and `{int}`. + - Scope locators to stable containers when the page has repeated elements. + - Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them. +4. Use Playwright in the local style. + - Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts. + - Use web-first `expect(...)` assertions. + - Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior. +5. Validate narrowly. + - Run the narrowest tagged scenario or flow that exercises the change. + - Run `pnpm -C e2e check`. + - Broaden verification only when the change affects hooks, tags, setup, or shared step semantics. + +## Review Checklist + +- Does the scenario describe behavior rather than implementation? +- Does it fit the current session model, tags, and `DifyWorld` usage? +- Should an existing step be reused instead of adding a new one? +- Are locators user-facing and assertions web-first? +- Does the change introduce hidden coupling across scenarios, tags, or instance state? +- Does it document or implement behavior that differs from the real hooks or configuration? + +Lead findings with correctness, flake risk, and architecture drift. + +## References + +- [`references/playwright-best-practices.md`](references/playwright-best-practices.md) +- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) diff --git a/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml b/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml new file mode 100644 index 0000000000..605cce041d --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "E2E Cucumber + Playwright" + short_description: "Write and review Dify E2E scenarios." + default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/." diff --git a/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md b/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md new file mode 100644 index 0000000000..d7a1a52852 --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/references/cucumber-best-practices.md @@ -0,0 +1,93 @@ +# Cucumber Best Practices For Dify E2E + +Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite. + +Official sources: + +- https://cucumber.io/docs/guides/10-minute-tutorial/ +- https://cucumber.io/docs/cucumber/step-definitions/ +- https://cucumber.io/docs/cucumber/cucumber-expressions/ + +## What Matters Most + +### 1. Treat scenarios as executable specifications + +Cucumber scenarios should describe examples of behavior, not test implementation recipes. 
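+
+A minimal sketch of how behavior-first wording carries into the glue layer (the step text, the relative import path, and the `page` handle on `DifyWorld` are illustrative assumptions, not pieces of the real suite):
+
+```typescript
+import { Then } from '@cucumber/cucumber'
+import { expect } from '@playwright/test'
+import type { DifyWorld } from '../support/world'
+
+// The step names a user-visible outcome, not a selector or DOM detail.
+Then('the app {string} appears in the app list', async function (this: DifyWorld, appName: string) {
+  // Web-first assertion: retries until the app entry is rendered.
+  await expect(this.page.getByText(appName)).toBeVisible()
+})
+```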
+ +Apply it like this: + +- write what the user does and what should happen +- avoid UI-internal wording such as selector details, DOM structure, or component names +- keep language concrete enough that the scenario reads like living documentation + +### 2. Keep scenarios focused + +A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it. + +In Dify's suite, this means: + +- one capability-focused scenario per feature path +- no long setup chains when existing bootstrap or reusable steps already cover them +- no hidden dependency on another scenario's side effects + +### 3. Reuse steps, but only when behavior really matches + +Good reuse reduces duplication. Bad reuse hides meaning. + +Prefer reuse when: + +- the user action is genuinely the same +- the expected outcome is genuinely the same +- the wording stays natural across features + +Write a new step when: + +- the behavior is materially different +- reusing the old wording would make the scenario misleading +- a supposedly generic step would become an implementation-detail wrapper + +### 4. Prefer Cucumber Expressions + +Use Cucumber Expressions for parameters unless regex is clearly necessary. + +Common examples: + +- `{string}` for labels, names, and visible text +- `{int}` for counts +- `{float}` for decimal values +- `{word}` only when the value is truly a single token + +Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler. + +### 5. Keep step definitions thin and meaningful + +Step definitions are glue between Gherkin and automation, not a second abstraction language. + +For Dify: + +- type `this` as `DifyWorld` +- use `async function` +- keep each step to one user-visible action or assertion +- rely on `DifyWorld` and existing support code for shared context +- avoid leaking cross-scenario state + +### 6. Use tags intentionally + +Tags should communicate run scope or session semantics, not become ad hoc metadata. + +In Dify's current suite: + +- capability tags group related scenarios +- `@unauthenticated` changes session behavior +- `@authenticated` is descriptive/selective, not a behavior switch by itself +- `@fresh` belongs to reset/full-install flows only + +If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it. + +## Review Questions + +- Does the scenario read like a real example of product behavior? +- Are the steps behavior-oriented instead of implementation-oriented? +- Is a reused step still truthful in this feature? +- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement? +- Would a new reader understand the outcome without opening the step-definition file? diff --git a/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md b/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md new file mode 100644 index 0000000000..02e763d46b --- /dev/null +++ b/.agents/skills/e2e-cucumber-playwright/references/playwright-best-practices.md @@ -0,0 +1,96 @@ +# Playwright Best Practices For Dify E2E + +Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite. 
+ +Official sources: + +- https://playwright.dev/docs/best-practices +- https://playwright.dev/docs/locators +- https://playwright.dev/docs/test-assertions +- https://playwright.dev/docs/browser-contexts + +## What Matters Most + +### 1. Keep scenarios isolated + +Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`. + +Apply it like this: + +- do not depend on another scenario having run first +- do not persist ad hoc scenario state outside `DifyWorld` +- do not couple ordinary scenarios to `@fresh` behavior +- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes + +### 2. Prefer user-facing locators + +Playwright recommends built-in locators that reflect what users perceive on the page. + +Preferred order in this repository: + +1. `getByRole` +2. `getByLabel` +3. `getByPlaceholder` +4. `getByText` +5. `getByTestId` when an explicit test contract is the most stable option + +Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical. + +Also remember: + +- repeated content usually needs scoping to a stable container +- exact text matching is often too brittle when role/name or label already exists +- `getByTestId` is acceptable when semantics are weak but the contract is intentional + +### 3. Use web-first assertions + +Playwright assertions auto-wait and retry. Prefer them over manual state inspection. + +Prefer: + +- `await expect(page).toHaveURL(...)` +- `await expect(locator).toBeVisible()` +- `await expect(locator).toBeHidden()` +- `await expect(locator).toBeEnabled()` +- `await expect(locator).toHaveText(...)` + +Avoid: + +- `expect(await locator.isVisible()).toBe(true)` +- custom polling loops for DOM state +- `waitForTimeout` as synchronization + +If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit. + +### 4. Let actions wait for actionability + +Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity. + +Good pattern: + +- assert a meaningful visible state when that is part of the behavior +- then click/fill/select via locator APIs + +Bad pattern: + +- stack arbitrary waits before every action +- wait on unstable implementation details instead of the visible state the user cares about + +### 5. Match debugging to the current suite + +Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures: + +- full-page screenshots +- page HTML +- console errors +- page errors + +Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling. + +## Review Questions + +- Would this locator survive DOM refactors that do not change user-visible behavior? +- Is this assertion using Playwright's retrying semantics? +- Is any explicit wait masking a real readiness problem? +- Does this code preserve per-scenario isolation? +- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model? 
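+
+## Example
+
+A minimal sketch combining sections 2, 3, and 4: a user-facing locator, an auto-waiting action, and a retrying web-first assertion. The button name and status text are illustrative assumptions, not locators from the real suite:
+
+```typescript
+import { expect, type Page } from '@playwright/test'
+
+// Hypothetical helper: save a form and verify the user-visible result.
+export async function saveAndVerify(page: Page) {
+  // The click auto-waits for actionability; no manual delay is needed.
+  await page.getByRole('button', { name: 'Save' }).click()
+
+  // Retrying assertion replaces patterns like
+  // `expect(await locator.isVisible()).toBe(true)` or `waitForTimeout`.
+  await expect(page.getByRole('status')).toHaveText('Saved')
+}
+```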
diff --git a/.agents/skills/frontend-query-mutation/SKILL.md b/.agents/skills/frontend-query-mutation/SKILL.md index 49888bdb66..10c49d222e 100644 --- a/.agents/skills/frontend-query-mutation/SKILL.md +++ b/.agents/skills/frontend-query-mutation/SKILL.md @@ -1,6 +1,6 @@ --- name: frontend-query-mutation -description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers. +description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers. --- # Frontend Query & Mutation @@ -9,22 +9,24 @@ description: Guide for implementing Dify frontend query and mutation patterns wi - Keep contract as the single source of truth in `web/contract/*`. - Prefer contract-shaped `queryOptions()` and `mutationOptions()`. -- Keep invalidation and mutation flow knowledge in the service layer. +- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough. +- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination. - Keep abstractions minimal to preserve TypeScript inference. ## Workflow 1. Identify the change surface. - Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape. - - Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations. + - Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations. - Read both references when a task spans contract shape and runtime behavior. 2. Implement the smallest abstraction that fits the task. - Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site. - Extract a small shared query helper only when multiple call sites share the same extra options. - - Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior. + - Create or keep feature hooks only for real orchestration or shared domain behavior. + - When touching thin `web/service/use-*` wrappers, migrate them away when feasible. 3. Preserve Dify conventions. - Keep contract inputs in `{ params, query?, body? }` shape. - - Bind invalidation in the service-layer mutation definition. + - Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior. 
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required. ## Files Commonly Touched @@ -33,7 +35,7 @@ description: Guide for implementing Dify frontend query and mutation patterns wi - `web/contract/marketplace.ts` - `web/contract/router.ts` - `web/service/client.ts` -- `web/service/use-*.ts` +- legacy `web/service/use-*.ts` files when migrating wrappers away - component and hook call sites using `consoleQuery` or `marketplaceQuery` ## References diff --git a/.agents/skills/frontend-query-mutation/agents/openai.yaml b/.agents/skills/frontend-query-mutation/agents/openai.yaml index 87f7ae6ea4..79e7e7d214 100644 --- a/.agents/skills/frontend-query-mutation/agents/openai.yaml +++ b/.agents/skills/frontend-query-mutation/agents/openai.yaml @@ -1,4 +1,4 @@ interface: display_name: "Frontend Query & Mutation" - short_description: "Dify TanStack Query and oRPC patterns" - default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations." + short_description: "Dify TanStack Query, oRPC, and default option patterns" + default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations." diff --git a/.agents/skills/frontend-query-mutation/references/contract-patterns.md b/.agents/skills/frontend-query-mutation/references/contract-patterns.md index 08016ed2cc..25ccfc81d7 100644 --- a/.agents/skills/frontend-query-mutation/references/contract-patterns.md +++ b/.agents/skills/frontend-query-mutation/references/contract-patterns.md @@ -7,6 +7,7 @@ - Core workflow - Query usage decision rule - Mutation usage decision rule +- Thin hook decision rule - Anti-patterns - Contract rules - Type export @@ -55,9 +56,13 @@ const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({ 1. Default to direct `*.queryOptions(...)` usage at the call site. 2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook. -3. Create `web/service/use-{domain}.ts` only for orchestration. +3. Create or keep feature hooks only for orchestration. - Combine multiple queries or mutations. - Share domain-level derived state or invalidation helpers. + - Prefer `web/features/{domain}/hooks/*` for feature-owned workflows. +4. Treat `web/service/use-{domain}.ts` as legacy. + - Do not create new thin service wrappers for oRPC contracts. + - When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough. ```typescript const invoicesBaseQueryOptions = () => @@ -74,11 +79,37 @@ const invoiceQuery = useQuery({ 1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`. 2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic. +```typescript +const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions()) +``` + +## Thin Hook Decision Rule + +Remove thin hooks when they only rename a single oRPC query or mutation helper. 
+Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API. +Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`. + +Use: + +```typescript +const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions()) +``` + +Keep: + +```typescript +const applyTagBindingsMutation = useApplyTagBindingsMutation() +``` + +`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough. + ## Anti-Patterns - Do not wrap `useQuery` with `options?: Partial`. - Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case. - Do not create thin `use-*` passthrough hooks for a single endpoint. +- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`. +- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs. - These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection. ## Contract Rules diff --git a/.agents/skills/frontend-query-mutation/references/runtime-rules.md b/.agents/skills/frontend-query-mutation/references/runtime-rules.md index 73d6fbdded..91b484d438 100644 --- a/.agents/skills/frontend-query-mutation/references/runtime-rules.md +++ b/.agents/skills/frontend-query-mutation/references/runtime-rules.md @@ -3,6 +3,7 @@ ## Table of Contents - Conditional queries +- oRPC default options - Cache invalidation - Key API guide - `mutate` vs `mutateAsync` @@ -35,9 +36,50 @@ function useBadAccessMode(appId: string | undefined) { } ``` +## oRPC Default Options + +Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation. + +Place defaults at the query utility creation point in `web/service/client.ts`: + +```typescript +export const consoleQuery = createTanstackQueryUtils(consoleClient, { + path: ['console'], + experimental_defaults: { + tags: { + create: { + mutationOptions: { + onSuccess: (tag, _variables, _result, context) => { + context.client.setQueryData( + consoleQuery.tags.list.queryKey({ + input: { + query: { + type: tag.type, + }, + }, + }), + (oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags, + ) + }, + }, + }, + }, + }, +}) +``` + +Rules: + +- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders. +- Do not create a wrapper function solely to host `createTanstackQueryUtils`. +- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`. +- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility. +- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation. + ## Cache Invalidation -Bind invalidation in the service-layer mutation definition. +Bind shared invalidation in oRPC defaults when it is tied to a contract operation. +Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default. 
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate. Use: @@ -49,7 +91,7 @@ Use: Do not use deprecated `useInvalid` from `use-base.ts`. ```typescript -// Service layer owns cache invalidation. +// Feature orchestration owns cache invalidation only when defaults are not enough. export const useUpdateAccessMode = () => { const queryClient = useQueryClient() @@ -124,7 +166,7 @@ When touching old code, migrate it toward these rules: | Old pattern | New pattern | |---|---| -| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` | -| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition | +| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration | +| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook | | imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` | | `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` | diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 4da070bdbf..105c979c58 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path: - ✅ **Import real project components** directly (including base components and siblings) - ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers -- ❌ **DO NOT mock** base components (`@/app/components/base/*`) +- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`) - ❌ **DO NOT mock** sibling/child components in the same directory > See [Test Structure Template](#test-structure-template) for correct import/mock patterns. @@ -325,12 +325,12 @@ For more detailed information, refer to: ### Reference Examples in Codebase - `web/utils/classnames.spec.ts` - Utility function tests -- `web/app/components/base/button/index.spec.tsx` - Component tests +- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests - `web/__mocks__/provider-context.ts` - Mock factory example ### Project Configuration -- `web/vitest.config.ts` - Vitest configuration +- `web/vite.config.ts` - Vite/Vitest configuration - `web/vitest.setup.ts` - Test environment setup - `web/scripts/analyze-component.js` - Component analysis tool - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files. diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md index 10b8fb66f9..519c3f166f 100644 --- a/.agents/skills/frontend-testing/references/checklist.md +++ b/.agents/skills/frontend-testing/references/checklist.md @@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Integration vs Mocking -- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.) +- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.) 
- [ ] Import real project components instead of mocking - [ ] Only mock: API calls, complex context providers, third-party libs with side effects - [ ] Prefer integration testing when using single spec file @@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen ### Mocks -- [ ] **DO NOT mock base components** (`@/app/components/base/*`) +- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`) - [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`) - [ ] Shared mock state reset in `beforeEach` - [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations @@ -127,7 +127,7 @@ For the current file being tested: - [ ] Run full directory test: `pnpm test path/to/directory/` - [ ] Check coverage report: `pnpm test:coverage` - [ ] Run `pnpm lint:fix` on all test files -- [ ] Run `pnpm type-check:tsgo` +- [ ] Run `pnpm type-check` ## Common Issues to Watch diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index f58377c4a5..8c2f1c0c58 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -2,29 +2,27 @@ ## ⚠️ Important: What NOT to Mock -### DO NOT Mock Base Components +### DO NOT Mock Base Components or dify-ui Primitives -**Never mock components from `@/app/components/base/`** such as: +**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as: -- `Loading`, `Spinner` -- `Button`, `Input`, `Select` -- `Tooltip`, `Modal`, `Dropdown` -- `Icon`, `Badge`, `Tag` +- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag` +- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast` **Why?** -- Base components will have their own dedicated tests +- These components have their own dedicated tests - Mocking them creates false positives (tests pass but real integration fails) - Using real components tests actual integration behavior ```typescript -// ❌ WRONG: Don't mock base components +// ❌ WRONG: Don't mock base components or dify-ui primitives vi.mock('@/app/components/base/loading', () => () =>
Loading
)
-vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
+vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => <button>{children}</button> }))
 
-// ✅ CORRECT: Import and use real base components
+// ✅ CORRECT: Import and use the real components
 import Loading from '@/app/components/base/loading'
-import Button from '@/app/components/base/button'
+import { Button } from '@langgenius/dify-ui/button'
 // They will render normally in tests
 ```
 
@@ -319,7 +317,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
 
 ### ✅ DO
 
-1. **Use real base components** - Import from `@/app/components/base/` directly
+1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly
 1. **Use real project components** - Prefer importing over mocking
 1. **Use real Zustand stores** - Set test state via `store.setState()`
 1. **Reset mocks in `beforeEach`**, not `afterEach`
@@ -330,7 +328,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
 
 ### ❌ DON'T
 
-1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
+1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.)
 1. **Don't mock Zustand store modules** - Use real stores with `setState()`
 1. Don't mock components you can import directly
 1. Don't create overly simplified mocks that miss conditional logic
@@ -342,7 +340,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
 
 ```
 Need to use a component in test?
 │
-├─ Is it from @/app/components/base/*?
+├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*?
 │   └─ YES → Import real component, DO NOT mock
 │
 ├─ Is it a project component?
diff --git a/.claude/skills/e2e-cucumber-playwright b/.claude/skills/e2e-cucumber-playwright
new file mode 120000
index 0000000000..71b0eae34f
--- /dev/null
+++ b/.claude/skills/e2e-cucumber-playwright
@@ -0,0 +1 @@
+../../.agents/skills/e2e-cucumber-playwright
\ No newline at end of file
diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index b92d4c35a8..7460636824 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -7,7 +7,7 @@ cd web && pnpm install
 pipx install uv
 
 echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
-echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention\"" >> ~/.bashrc
 echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
 echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
 echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> 
~/.bashrc diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 94e857f93a..98b7e9f119 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,6 +6,9 @@ * @crazywoola @laipz8200 @Yeuoly +# ESLint suppression file is maintained by autofix.ci pruning. +/eslint-suppressions.json + # CODEOWNERS file /.github/CODEOWNERS @laipz8200 @crazywoola diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml index 673155bcf7..085b39ebfb 100644 --- a/.github/actions/setup-web/action.yml +++ b/.github/actions/setup-web/action.yml @@ -4,7 +4,7 @@ runs: using: composite steps: - name: Setup Vite+ - uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0 + uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0 with: node-version-file: .nvmrc cache: true diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a183f0b58c..266fa17c29 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,106 +1,6 @@ version: 2 updates: - - package-ecosystem: "pip" - directory: "/api" - open-pull-requests-limit: 10 - schedule: - interval: "weekly" - groups: - flask: - patterns: - - "flask" - - "flask-*" - - "werkzeug" - - "gunicorn" - google: - patterns: - - "google-*" - - "googleapis-*" - opentelemetry: - patterns: - - "opentelemetry-*" - pydantic: - patterns: - - "pydantic" - - "pydantic-*" - llm: - patterns: - - "langfuse" - - "langsmith" - - "litellm" - - "mlflow*" - - "opik" - - "weave*" - - "arize*" - - "tiktoken" - - "transformers" - database: - patterns: - - "sqlalchemy" - - "psycopg2*" - - "psycogreen" - - "redis*" - - "alembic*" - storage: - patterns: - - "boto3*" - - "botocore*" - - "azure-*" - - "bce-*" - - "cos-python-*" - - "esdk-obs-*" - - "google-cloud-storage" - - "opendal" - - "oss2" - - "supabase*" - - "tos*" - vdb: - patterns: - - "alibabacloud*" - - "chromadb" - - "clickhouse-*" - - "clickzetta-*" - - "couchbase" - - "elasticsearch" - - "opensearch-py" - - "oracledb" - - "pgvect*" - - "pymilvus" - - "pymochow" - - "pyobvector" - - "qdrant-client" - - "intersystems-*" - - "tablestore" - - "tcvectordb" - - "tidb-vector" - - "upstash-*" - - "volcengine-*" - - "weaviate-*" - - "xinference-*" - - "mo-vector" - - "mysql-connector-*" - dev: - patterns: - - "coverage" - - "dotenv-linter" - - "faker" - - "lxml-stubs" - - "basedpyright" - - "ruff" - - "pytest*" - - "types-*" - - "boto3-stubs" - - "hypothesis" - - "pandas-stubs" - - "scipy-stubs" - - "import-linter" - - "celery-types" - - "mypy*" - - "pyrefly" - python-packages: - patterns: - - "*" - package-ecosystem: "uv" directory: "/api" open-pull-requests-limit: 10 diff --git a/.github/labeler.yml b/.github/labeler.yml index 3b9dc24749..e226bafccc 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -6,5 +6,4 @@ web: - 'package.json' - 'pnpm-lock.yaml' - 'pnpm-workspace.yaml' - - '.npmrc' - '.nvmrc' diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index a069b6cbc7..1e848612ec 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -7,6 +7,7 @@ ## Summary + ## Screenshots @@ -17,7 +18,7 @@ ## Checklist - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs) -- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) -- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. 
-- [x] I've updated the documentation accordingly. -- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods +- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) +- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. +- [ ] I've updated the documentation accordingly. +- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml deleted file mode 100644 index b0f0a36bc9..0000000000 --- a/.github/workflows/anti-slop.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Anti-Slop PR Check - -on: - pull_request_target: - types: [opened, edited, synchronize] - -permissions: - pull-requests: write - contents: read - -jobs: - anti-slop: - runs-on: ubuntu-latest - steps: - - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - close-pr: false - failure-add-pr-labels: "needs-revision" diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index cd967b76cf..bd47abc710 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -16,7 +16,7 @@ concurrency: jobs: api-unit: name: API Unit Tests - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 env: COVERAGE_FILE: coverage-unit defaults: @@ -35,7 +35,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -54,7 +54,7 @@ jobs: run: uv run --project api bash dev/pytest/pytest_unit_tests.sh - name: Upload unit coverage data - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: api-coverage-unit path: coverage-unit @@ -62,7 +62,7 @@ jobs: api-integration: name: API Integration Tests - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 env: COVERAGE_FILE: coverage-integration STORAGE_TYPE: opendal @@ -84,7 +84,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -105,7 +105,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -129,7 +129,7 @@ jobs: api/tests/test_containers_integration_tests - name: Upload integration coverage data - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: api-coverage-integration path: coverage-integration @@ -137,7 +137,7 @@ jobs: api-coverage: name: API Coverage - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 needs: - api-unit - api-integration @@ -156,7 +156,7 @@ jobs: 
persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 772ab8dd56..76fbd18f47 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -13,7 +13,7 @@ permissions: jobs: autofix: if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Complete merge group check if: github.event_name == 'merge_group' @@ -25,7 +25,7 @@ jobs: - name: Check Docker Compose inputs if: github.event_name != 'merge_group' id: docker-compose-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | docker/generate_docker_compose @@ -35,7 +35,7 @@ jobs: - name: Check web inputs if: github.event_name != 'merge_group' id: web-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | web/** @@ -43,12 +43,11 @@ jobs: package.json pnpm-lock.yaml pnpm-workspace.yaml - .npmrc .nvmrc - name: Check api inputs if: github.event_name != 'merge_group' id: api-changes - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | api/** @@ -58,7 +57,7 @@ jobs: python-version: "3.11" - if: github.event_name != 'merge_group' - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 - name: Generate Docker Compose if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true' @@ -114,14 +113,13 @@ jobs: find . 
-name "*.py.bak" -type f -delete - name: Setup web environment - if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' uses: ./.github/actions/setup-web - name: ESLint autofix if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' run: | - cd web vp exec eslint --concurrency=2 --prune-suppressions --quiet || true - if: github.event_name != 'merge_group' - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 + uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4 diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 79ecdb5938..915ed6cfe8 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -26,6 +26,9 @@ jobs: build: runs-on: ${{ matrix.runs_on }} if: github.repository == 'langgenius/dify' + permissions: + contents: read + id-token: write strategy: matrix: include: @@ -35,28 +38,28 @@ jobs: build_context: "{{defaultContext}}:api" file: "Dockerfile" platform: linux/amd64 - runs_on: ubuntu-latest + runs_on: depot-ubuntu-24.04-4 - service_name: "build-api-arm64" image_name_env: "DIFY_API_IMAGE_NAME" artifact_context: "api" build_context: "{{defaultContext}}:api" file: "Dockerfile" platform: linux/arm64 - runs_on: ubuntu-24.04-arm + runs_on: depot-ubuntu-24.04-4 - service_name: "build-web-amd64" image_name_env: "DIFY_WEB_IMAGE_NAME" artifact_context: "web" build_context: "{{defaultContext}}" file: "web/Dockerfile" platform: linux/amd64 - runs_on: ubuntu-latest + runs_on: depot-ubuntu-24.04-4 - service_name: "build-web-arm64" image_name_env: "DIFY_WEB_IMAGE_NAME" artifact_context: "web" build_context: "{{defaultContext}}" file: "web/Dockerfile" platform: linux/arm64 - runs_on: ubuntu-24.04-arm + runs_on: depot-ubuntu-24.04-4 steps: - name: Prepare @@ -70,8 +73,8 @@ jobs: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 + - name: Set up Depot CLI + uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1 - name: Extract metadata for Docker id: meta @@ -81,16 +84,15 @@ jobs: - name: Build Docker image id: build - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 + uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0 with: + project: ${{ vars.DEPOT_PROJECT_ID }} context: ${{ matrix.build_context }} file: ${{ matrix.file }} platforms: ${{ matrix.platform }} build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }} labels: ${{ steps.meta.outputs.labels }} outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true - cache-from: type=gha,scope=${{ matrix.service_name }} - cache-to: type=gha,mode=max,scope=${{ matrix.service_name }} - name: Export digest env: @@ -101,16 +103,40 @@ jobs: touch "/tmp/digests/${sanitized_digest}" - name: Upload digest - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }} path: /tmp/digests/* if-no-files-found: error retention-days: 1 + fork-build-validate: + if: github.repository != 'langgenius/dify' + runs-on: ubuntu-24.04 + strategy: + matrix: + include: + - 
service_name: "validate-api-amd64" + build_context: "{{defaultContext}}:api" + file: "Dockerfile" + - service_name: "validate-web-amd64" + build_context: "{{defaultContext}}" + file: "web/Dockerfile" + steps: + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 + + - name: Validate Docker image + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 + with: + push: false + context: ${{ matrix.build_context }} + file: ${{ matrix.file }} + platforms: linux/amd64 + create-manifest: needs: build - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: github.repository == 'langgenius/dify' strategy: matrix: diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index 5991abe3ba..65f0149a74 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -9,7 +9,7 @@ concurrency: jobs: db-migration-test-postgres: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Checkout code @@ -19,7 +19,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -40,7 +40,7 @@ jobs: cp middleware.env.example middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -59,7 +59,7 @@ jobs: run: uv run --directory api flask upgrade-db db-migration-test-mysql: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Checkout code @@ -69,7 +69,7 @@ jobs: persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -94,7 +94,7 @@ jobs: sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -110,6 +110,28 @@ jobs: sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env + # hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d` + # to return (container processes started); it does not wait on healthcheck + # status. mysql:8.0's first-time init takes 15-30s, so without an explicit + # wait the migration runs while InnoDB is still initialising and gets + # killed with "Lost connection during query". Poll a real SELECT until it + # succeeds. 
+ - name: Wait for MySQL to accept queries + run: | + set +e + for i in $(seq 1 60); do + if docker run --rm --network host mysql:8.0 \ + mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \ + -e 'SELECT 1' >/dev/null 2>&1; then + echo "MySQL ready after ${i}s" + exit 0 + fi + sleep 1 + done + echo "MySQL not ready after 60s; dumping container logs:" + docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql + exit 1 + - name: Run DB Migration env: DEBUG: true diff --git a/.github/workflows/deploy-agent-dev.yml b/.github/workflows/deploy-agent-dev.yml index cd5fe9242e..9b9b77e0a2 100644 --- a/.github/workflows/deploy-agent-dev.yml +++ b/.github/workflows/deploy-agent-dev.yml @@ -13,7 +13,7 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'deploy/agent-dev' diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 954537663a..c2ff8c6332 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -10,7 +10,7 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'deploy/dev' diff --git a/.github/workflows/deploy-enterprise.yml b/.github/workflows/deploy-enterprise.yml index 9cff3a3482..2740541f0f 100644 --- a/.github/workflows/deploy-enterprise.yml +++ b/.github/workflows/deploy-enterprise.yml @@ -13,7 +13,7 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'deploy/enterprise' diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml index c6f1cc7e6f..0da241cf95 100644 --- a/.github/workflows/deploy-hitl.yml +++ b/.github/workflows/deploy-hitl.yml @@ -10,7 +10,7 @@ on: jobs: deploy: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 if: | github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'build/feat/hitl' diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index cd9d69d871..5144510be5 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -6,14 +6,7 @@ on: - "main" paths: - api/Dockerfile - - web/docker/** - web/Dockerfile - - packages/** - - package.json - - pnpm-lock.yaml - - pnpm-workspace.yaml - - .npmrc - - .nvmrc concurrency: group: docker-build-${{ github.head_ref || github.run_id }} @@ -21,28 +14,59 @@ concurrency: jobs: build-docker: + if: github.event.pull_request.head.repo.full_name == github.repository runs-on: ${{ matrix.runs_on }} + permissions: + contents: read + id-token: write strategy: matrix: include: - service_name: "api-amd64" platform: linux/amd64 - runs_on: ubuntu-latest + runs_on: depot-ubuntu-24.04-4 context: "{{defaultContext}}:api" file: "Dockerfile" - service_name: "api-arm64" platform: linux/arm64 - runs_on: ubuntu-24.04-arm + runs_on: depot-ubuntu-24.04-4 context: "{{defaultContext}}:api" file: "Dockerfile" - service_name: "web-amd64" platform: linux/amd64 - runs_on: ubuntu-latest + runs_on: depot-ubuntu-24.04-4 context: "{{defaultContext}}" file: "web/Dockerfile" - service_name: "web-arm64" platform: linux/arm64 - runs_on: ubuntu-24.04-arm + runs_on: depot-ubuntu-24.04-4 + context: "{{defaultContext}}" + file: "web/Dockerfile" + steps: + - name: Set up Depot 
CLI + uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1 + + - name: Build Docker Image + uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0 + with: + project: ${{ vars.DEPOT_PROJECT_ID }} + push: false + context: ${{ matrix.context }} + file: ${{ matrix.file }} + platforms: ${{ matrix.platform }} + + build-docker-fork: + if: github.event.pull_request.head.repo.full_name != github.repository + runs-on: ubuntu-24.04 + permissions: + contents: read + strategy: + matrix: + include: + - service_name: "api-amd64" + context: "{{defaultContext}}:api" + file: "Dockerfile" + - service_name: "web-amd64" context: "{{defaultContext}}" file: "web/Dockerfile" steps: @@ -50,11 +74,9 @@ jobs: uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 - name: Build Docker Image - uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 + uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0 with: push: false context: ${{ matrix.context }} file: ${{ matrix.file }} - platforms: ${{ matrix.platform }} - cache-from: type=gha - cache-to: type=gha,mode=max + platforms: linux/amd64 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 278e10bc04..f59cc6be48 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -7,7 +7,7 @@ jobs: permissions: contents: read pull-requests: write - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 with: diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index 59c38b6e7e..8071d6204d 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -23,7 +23,7 @@ concurrency: jobs: pre_job: name: Skip Duplicate Checks - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 outputs: should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }} steps: @@ -39,7 +39,7 @@ jobs: name: Check Changed Files needs: pre_job if: needs.pre_job.outputs.should_skip != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 outputs: api-changed: ${{ steps.changes.outputs.api }} e2e-changed: ${{ steps.changes.outputs.e2e }} @@ -69,7 +69,6 @@ jobs: - 'package.json' - 'pnpm-lock.yaml' - 'pnpm-workspace.yaml' - - '.npmrc' - '.nvmrc' - '.github/workflows/web-tests.yml' - '.github/actions/setup-web/**' @@ -83,7 +82,6 @@ jobs: - 'package.json' - 'pnpm-lock.yaml' - 'pnpm-workspace.yaml' - - '.npmrc' - '.nvmrc' - 'docker/docker-compose.middleware.yaml' - 'docker/middleware.env.example' @@ -92,6 +90,7 @@ jobs: vdb: - 'api/core/rag/datasource/**' - 'api/tests/integration_tests/vdb/**' + - 'api/providers/vdb/*/tests/**' - '.github/workflows/vdb-tests.yml' - '.github/workflows/expose_service_ports.sh' - 'docker/.env.example' @@ -140,7 +139,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped API tests run: echo "No API-related changes detected; skipping API tests." 
@@ -153,7 +152,7 @@ jobs: - check-changes - api-tests-run - api-tests-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize API Tests status env: @@ -200,7 +199,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped web tests run: echo "No web-related changes detected; skipping web tests." @@ -213,7 +212,7 @@ jobs: - check-changes - web-tests-run - web-tests-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize Web Tests status env: @@ -259,7 +258,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped web full-stack e2e run: echo "No E2E-related changes detected; skipping web full-stack E2E." @@ -272,7 +271,7 @@ jobs: - check-changes - web-e2e-run - web-e2e-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize Web Full-Stack E2E status env: @@ -324,7 +323,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped VDB tests run: echo "No VDB-related changes detected; skipping VDB tests." @@ -337,7 +336,7 @@ jobs: - check-changes - vdb-tests-run - vdb-tests-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize VDB Tests status env: @@ -383,7 +382,7 @@ jobs: - pre_job - check-changes if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Report skipped DB migration tests run: echo "No migration-related changes detected; skipping DB migration tests." @@ -396,7 +395,7 @@ jobs: - check-changes - db-migration-test-run - db-migration-test-skip - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Finalize DB Migration Test status env: diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index 0278e1e0d3..7f82942e7e 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -12,7 +12,7 @@ permissions: {} jobs: comment: name: Comment PR with pyrefly diff - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 permissions: actions: read contents: read @@ -21,7 +21,7 @@ jobs: if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} steps: - name: Download pyrefly diff artifact - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -49,7 +49,7 @@ jobs: run: unzip -o pyrefly_diff.zip - name: Post comment - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -76,13 +76,11 @@ jobs: diff += '\\n\\n... (truncated) ...'; } - const body = diff.trim() - ? '### Pyrefly Diff\n
<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>
' - : '### Pyrefly Diff\nNo changes detected.'; - - await github.rest.issues.createComment({ - issue_number: prNumber, - owner: context.repo.owner, - repo: context.repo.repo, - body, - }); + if (diff.trim()) { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body: '### Pyrefly Diff\n
<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>
', + }); + } diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index 8623d35b04..0cf54e3585 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -10,7 +10,7 @@ permissions: jobs: pyrefly-diff: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 permissions: contents: read issues: write @@ -22,7 +22,7 @@ jobs: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true @@ -66,7 +66,7 @@ jobs: echo ${{ github.event.pull_request.number }} > pr_number.txt - name: Upload pyrefly diff - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: pyrefly_diff path: | @@ -75,7 +75,7 @@ jobs: - name: Comment PR with pyrefly diff if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }} - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml new file mode 100644 index 0000000000..52c16f3153 --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage-comment.yml @@ -0,0 +1,118 @@ +name: Comment with Pyrefly Type Coverage + +on: + workflow_run: + workflows: + - Pyrefly Type Coverage + types: + - completed + +permissions: {} + +jobs: + comment: + name: Comment PR with type coverage + runs-on: depot-ubuntu-24.04 + permissions: + actions: read + contents: read + issues: write + pull-requests: write + if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} + steps: + - name: Checkout default branch (trusted code) + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Download type coverage artifact + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const artifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: ${{ github.event.workflow_run.id }}, + }); + const match = artifacts.data.artifacts.find((artifact) => + artifact.name === 'pyrefly_type_coverage' + ); + if (!match) { + throw new Error('pyrefly_type_coverage artifact not found'); + } + const download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: match.id, + archive_format: 'zip', + }); + fs.writeFileSync('pyrefly_type_coverage.zip', Buffer.from(download.data)); + + - name: Unzip artifact + run: unzip -o pyrefly_type_coverage.zip + + - name: Render coverage markdown from structured data + id: render + run: | + comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \ + --base base_report.json \ + < pr_report.json)" + + { + echo "### Pyrefly Type Coverage" 
+ echo "" + echo "$comment_body" + } > /tmp/type_coverage_comment.md + + - name: Post comment + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + let prNumber = null; + try { + prNumber = parseInt(fs.readFileSync('pr_number.txt', { encoding: 'utf8' }), 10); + } catch (err) { + const prs = context.payload.workflow_run.pull_requests || []; + if (prs.length > 0 && prs[0].number) { + prNumber = prs[0].number; + } + } + if (!prNumber) { + throw new Error('PR number not found in artifact or workflow_run payload'); + } + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const marker = '### Pyrefly Type Coverage'; + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml new file mode 100644 index 0000000000..eae8debf1a --- /dev/null +++ b/.github/workflows/pyrefly-type-coverage.yml @@ -0,0 +1,120 @@ +name: Pyrefly Type Coverage + +on: + pull_request: + paths: + - 'api/**/*.py' + +permissions: + contents: read + +jobs: + pyrefly-type-coverage: + runs-on: depot-ubuntu-24.04 + permissions: + contents: read + issues: write + pull-requests: write + steps: + - name: Checkout PR branch + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + + - name: Setup Python & UV + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Run pyrefly report on PR branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_pr.tmp && \ + mv /tmp/pyrefly_report_pr.tmp /tmp/pyrefly_report_pr.json || \ + echo '{}' > /tmp/pyrefly_report_pr.json + + - name: Save helper script from base branch + run: | + git show ${{ github.event.pull_request.base.sha }}:api/libs/pyrefly_type_coverage.py > /tmp/pyrefly_type_coverage.py 2>/dev/null \ + || cp api/libs/pyrefly_type_coverage.py /tmp/pyrefly_type_coverage.py + + - name: Checkout base branch + run: git checkout ${{ github.base_ref }} + + - name: Run pyrefly report on base branch + run: | + uv run --directory api --dev pyrefly report 2>/dev/null > /tmp/pyrefly_report_base.tmp && \ + mv /tmp/pyrefly_report_base.tmp /tmp/pyrefly_report_base.json || \ + echo '{}' > /tmp/pyrefly_report_base.json + + - name: Generate coverage comparison + id: coverage + run: | + comment_body="$(uv run --directory api python /tmp/pyrefly_type_coverage.py \ + --base /tmp/pyrefly_report_base.json \ + < /tmp/pyrefly_report_pr.json)" + + { + echo "### Pyrefly Type Coverage" + echo "" + echo "$comment_body" + } | tee -a "$GITHUB_STEP_SUMMARY" > /tmp/type_coverage_comment.md + + # Save structured data for the fork-PR comment workflow + cp /tmp/pyrefly_report_pr.json pr_report.json + cp 
/tmp/pyrefly_report_base.json base_report.json + + - name: Save PR number + run: | + echo ${{ github.event.pull_request.number }} > pr_number.txt + + - name: Upload type coverage artifact + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: pyrefly_type_coverage + path: | + pr_report.json + base_report.json + pr_number.txt + + - name: Comment PR with type coverage + if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const marker = '### Pyrefly Type Coverage'; + let body; + try { + body = fs.readFileSync('/tmp/type_coverage_comment.md', { encoding: 'utf8' }); + } catch { + body = `${marker}\n\n_Coverage report unavailable._`; + } + const prNumber = context.payload.pull_request.number; + + // Update existing comment if one exists, otherwise create new + const { data: comments } = await github.rest.issues.listComments({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + }); + const existing = comments.find(c => c.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml index 49d2e94695..6f3193bbf5 100644 --- a/.github/workflows/semantic-pull-request.yml +++ b/.github/workflows/semantic-pull-request.yml @@ -16,7 +16,7 @@ jobs: name: Validate PR title permissions: pull-requests: read - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Complete merge group check if: github.event_name == 'merge_group' diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 5cf52daed2..b23648c7c6 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -12,7 +12,7 @@ on: jobs: stale: - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 permissions: issues: write pull-requests: write @@ -23,8 +23,8 @@ jobs: days-before-issue-stale: 15 days-before-issue-close: 3 repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it." - stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it." + stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it." + stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it." 
stale-issue-label: 'no-issue-activity' stale-pr-label: 'no-pr-activity' - any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted' + any-of-labels: '🌚 invalid,🙋‍♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted' diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c32fc9d0cb..4ce121ba60 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -15,7 +15,7 @@ permissions: jobs: python-style: name: Python Style - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Checkout code @@ -25,7 +25,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | api/** @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: false python-version: "3.12" @@ -57,7 +57,7 @@ jobs: web-style: name: Web Style - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 defaults: run: working-directory: ./web @@ -73,15 +73,16 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | web/** + e2e/** + sdks/nodejs-client/** packages/** package.json pnpm-lock.yaml pnpm-workspace.yaml - .npmrc .nvmrc .github/workflows/style.yml .github/actions/setup-web/** @@ -93,26 +94,28 @@ jobs: - name: Restore ESLint cache if: steps.changed-files.outputs.any_changed == 'true' id: eslint-cache-restore - uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: - path: web/.eslintcache - key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} + path: .eslintcache + key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }} restore-keys: | - ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- + ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}- - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . run: vp run lint:ci - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web + env: + NODE_OPTIONS: --max-old-space-size=4096 run: vp run lint:tss - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web + working-directory: . 
run: vp run type-check - name: Web dead code check @@ -122,14 +125,14 @@ jobs: - name: Save ESLint cache if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true' - uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5 with: - path: web/.eslintcache + path: .eslintcache key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }} superlinter: name: SuperLinter - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 steps: - name: Checkout code @@ -140,7 +143,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6 with: files: | **.sh diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index 467f31fccf..adaf99f33a 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -9,7 +9,6 @@ on: - package.json - pnpm-lock.yaml - pnpm-workspace.yaml - - .npmrc concurrency: group: sdk-tests-${{ github.head_ref || github.run_id }} @@ -18,7 +17,7 @@ concurrency: jobs: build: name: unit test for Node.js SDK - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 defaults: run: @@ -30,7 +29,7 @@ jobs: persist-credentials: false - name: Use Node.js - uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0 with: node-version: 22 cache: '' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index e001f4d677..7bb6fc1bbd 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -35,7 +35,7 @@ concurrency: jobs: translate: if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 timeout-minutes: 120 steps: @@ -158,7 +158,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.context.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89 + uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml index 9a11d3e8df..87c88e2023 100644 --- a/.github/workflows/trigger-i18n-sync.yml +++ b/.github/workflows/trigger-i18n-sync.yml @@ -16,7 +16,7 @@ concurrency: jobs: trigger: if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 timeout-minutes: 5 steps: @@ -56,7 +56,7 @@ jobs: - name: Trigger i18n sync workflow if: steps.detect.outputs.has_changes == 'true' - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 env: BASE_SHA: ${{ steps.detect.outputs.base_sha }} HEAD_SHA: ${{ steps.detect.outputs.head_sha }} diff --git a/.github/workflows/vdb-tests-full.yml b/.github/workflows/vdb-tests-full.yml index 72b3ea9aac..5c241af5c5 100644 --- a/.github/workflows/vdb-tests-full.yml +++ b/.github/workflows/vdb-tests-full.yml @@ -16,7 +16,7 @@ jobs: test: name: Full VDB Tests if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest + 
runs-on: depot-ubuntu-24.04 strategy: matrix: python-version: @@ -36,7 +36,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -65,7 +65,7 @@ jobs: # tiflash - name: Set up Full Vector Store Matrix - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.yaml @@ -89,7 +89,7 @@ jobs: cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env # - name: Check VDB Ready (TiDB) -# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py +# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py - name: Test Vector Stores run: uv run --project api bash dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 47ec70f603..38ec96f00f 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -13,7 +13,7 @@ concurrency: jobs: test: name: VDB Smoke Tests - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04 strategy: matrix: python-version: @@ -33,7 +33,7 @@ jobs: remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -62,7 +62,7 @@ jobs: # tiflash - name: Set up Vector Stores for Smoke Coverage - uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 + uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0 with: compose-file: | docker/docker-compose.yaml @@ -81,12 +81,12 @@ jobs: cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env # - name: Check VDB Ready (TiDB) -# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py +# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py - name: Test Vector Stores run: | uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \ - api/tests/integration_tests/vdb/chroma \ - api/tests/integration_tests/vdb/pgvector \ - api/tests/integration_tests/vdb/qdrant \ - api/tests/integration_tests/vdb/weaviate + api/providers/vdb/vdb-chroma/tests/integration_tests \ + api/providers/vdb/vdb-pgvector/tests/integration_tests \ + api/providers/vdb/vdb-qdrant/tests/integration_tests \ + api/providers/vdb/vdb-weaviate/tests/integration_tests diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml index eb752619be..bdc24887db 100644 --- a/.github/workflows/web-e2e.yml +++ b/.github/workflows/web-e2e.yml @@ -13,7 +13,7 @@ concurrency: jobs: test: name: Web Full-Stack E2E - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04-4 defaults: run: shell: bash @@ -28,7 +28,7 @@ jobs: uses: ./.github/actions/setup-web - name: Setup UV and Python - uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0 + uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 with: enable-cache: true python-version: "3.12" @@ -53,7 +53,7 @@ jobs: - 
name: Upload Cucumber report if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: cucumber-report path: e2e/cucumber-report @@ -61,7 +61,7 @@ jobs: - name: Upload E2E logs if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: e2e-logs path: e2e/.logs diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 3c36335e79..4619f3c104 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -16,7 +16,7 @@ concurrency: jobs: test: name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }}) - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04-4 env: VITEST_COVERAGE_SCOPE: app-components strategy: @@ -43,7 +43,7 @@ jobs: - name: Upload blob report if: ${{ !cancelled() }} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 with: name: blob-report-${{ matrix.shardIndex }} path: web/.vitest-reports/* @@ -54,7 +54,7 @@ jobs: name: Merge Test Reports if: ${{ !cancelled() }} needs: [test] - runs-on: ubuntu-latest + runs-on: depot-ubuntu-24.04-4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} defaults: @@ -89,3 +89,37 @@ jobs: flags: web env: CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} + + dify-ui-test: + name: dify-ui Tests + runs-on: depot-ubuntu-24.04-4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + defaults: + run: + shell: bash + working-directory: ./packages/dify-ui + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web environment + uses: ./.github/actions/setup-web + + - name: Install Chromium for Browser Mode + run: vp exec playwright install --with-deps chromium + + - name: Run dify-ui tests + run: vp test run --coverage --silent=passed-only + + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0 + with: + directory: packages/dify-ui/coverage + flags: dify-ui + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 53dea88899..dc3b3f284f 100644 --- a/.gitignore +++ b/.gitignore @@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info .vscode/* !.vscode/launch.json.template +!.vscode/settings.example.json !.vscode/README.md api/.vscode # vscode Code History Extension @@ -218,6 +219,9 @@ node_modules # plugin migrate plugins.jsonl +# generated API OpenAPI specs +packages/contracts/openapi/ + # mise mise.toml @@ -236,9 +240,15 @@ scripts/stress-test/reports/ .playwright-mcp/ .serena/ +# vitest browser mode attachments (failure screenshots, traces, etc.) 
+.vitest-attachments/ +**/__screenshots__/ + # settings *.local.json *.local.md # Code Agent Folder .qoder/* + +.eslintcache diff --git a/.npmrc b/.npmrc deleted file mode 100644 index cffe8cdef1..0000000000 --- a/.npmrc +++ /dev/null @@ -1 +0,0 @@ -save-exact=true diff --git a/.vite-hooks/pre-commit b/.vite-hooks/pre-commit index 13bbd81cf6..d48381bce2 100755 --- a/.vite-hooks/pre-commit +++ b/.vite-hooks/pre-commit @@ -56,44 +56,9 @@ if $api_modified; then fi fi -if $web_modified; then - if $skip_web_checks; then - echo "Git operation in progress, skipping web checks" - exit 0 - fi - - echo "Running ESLint on web module" - - if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then - web_ts_modified=false - else - ts_diff_status=$? - if [ $ts_diff_status -eq 1 ]; then - web_ts_modified=true - else - echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)." - exit $ts_diff_status - fi - fi - - cd ./web || exit 1 - vp staged - - if $web_ts_modified; then - echo "Running TypeScript type-check:tsgo" - if ! npm run type-check:tsgo; then - echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors." - exit 1 - fi - else - echo "No staged TypeScript changes detected, skipping type-check:tsgo" - fi - - echo "Running knip" - if ! npm run knip; then - echo "Knip check failed. Please run 'npm run knip' to fix the errors." - exit 1 - fi - - cd ../ +if $skip_web_checks; then + echo "Git operation in progress, skipping web checks" + exit 0 fi + +vp staged diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index c3e2c50c52..2611b75c6c 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -2,21 +2,10 @@ "version": "0.2.0", "configurations": [ { - "name": "Python: Flask API", + "name": "Python: API (gevent)", "type": "debugpy", "request": "launch", - "module": "flask", - "env": { - "FLASK_APP": "app.py", - "FLASK_ENV": "development" - }, - "args": [ - "run", - "--host=0.0.0.0", - "--port=5001", - "--no-debugger", - "--no-reload" - ], + "program": "${workspaceFolder}/api/app.py", "jinja": true, "justMyCode": true, "cwd": "${workspaceFolder}/api", diff --git a/web/.vscode/settings.example.json b/.vscode/settings.example.json similarity index 86% rename from web/.vscode/settings.example.json rename to .vscode/settings.example.json index 4b356f5b7a..7cdbc51a3b 100644 --- a/web/.vscode/settings.example.json +++ b/.vscode/settings.example.json @@ -1,12 +1,16 @@ { - // Disable the default formatter, use eslint instead - "prettier.enable": false, - "editor.formatOnSave": false, + "cucumber.features": [ + "e2e/features/**/*.feature", + ], + "cucumber.glue": [ + "e2e/features/**/*.ts", + ], + + "tailwindCSS.experimental.configFile": "web/app/styles/globals.css", // Auto fix "editor.codeActionsOnSave": { "source.fixAll.eslint": "explicit", - "source.organizeImports": "never" }, // Silent the stylistic rules in your IDE, but still auto fix them diff --git a/AGENTS.md b/AGENTS.md index d25d2eed96..8be2daef95 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -30,7 +30,7 @@ The codebase is split into: ## Language Style - **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation. -- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types. 
+- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types. ## General Practices diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 775401bfa5..d7f007af67 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -97,8 +97,3 @@ Feel free to reach out if you encounter any issues during the setup process. ## Getting Help If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. - -## Automated Agent Contributions - -> [!NOTE] -> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in. diff --git a/README.md b/README.md index d9848a6c78..e6f8d84931 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock ```bash cd dify cd docker -cp .env.example .env -docker compose up -d +./dify-compose up -d ``` +On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory. + After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. #### Seeking help @@ -137,20 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases. ### Custom configurations -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). - -#### Customizing Suggested Questions - -You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions: - -```bash -# In your .env file -SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]' -SUGGESTED_QUESTIONS_MAX_TOKENS=512 -SUGGESTED_QUESTIONS_TEMPERATURE=0.3 -``` - -See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions. +If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). 
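+
+For example, a minimal override flow might look like the following sketch (`EXPOSE_NGINX_PORT` is only a stand-in for whichever variable you want to change):
+
+```bash
+cd docker
+# .env carries only your overrides; everything else falls back to .env.default
+echo "EXPOSE_NGINX_PORT=8080" >> .env
+# Re-run compose so the override takes effect
+./dify-compose up -d
+```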
### Metrics Monitoring with Grafana @@ -160,7 +148,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source ### Deployment with Kubernetes -If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. +If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) diff --git a/api/.env.example b/api/.env.example index a04a18944a..f6f65011ea 100644 --- a/api/.env.example +++ b/api/.env.example @@ -33,6 +33,9 @@ TRIGGER_URL=http://localhost:5001 # The time in seconds after the signature is rejected FILES_ACCESS_TIMEOUT=300 +# Collaboration mode toggle +ENABLE_COLLABORATION_MODE=false + # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 @@ -57,6 +60,9 @@ REDIS_SSL_CERTFILE= REDIS_SSL_KEYFILE= # Path to client private key file for SSL authentication REDIS_DB=0 +# Optional global prefix for Redis keys, topics, streams, and Celery Redis transport artifacts. +# Leave empty to preserve current unprefixed behavior. +REDIS_KEY_PREFIX= # redis Sentinel configuration. REDIS_USE_SENTINEL=false @@ -653,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y MARKETPLACE_ENABLED=true MARKETPLACE_API_URL=https://marketplace.dify.ai +# Creators Platform configuration +CREATORS_PLATFORM_FEATURES_ENABLED=true +CREATORS_PLATFORM_API_URL=https://creators.dify.ai +CREATORS_PLATFORM_OAUTH_CLIENT_ID= + # Endpoint configuration ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} @@ -703,22 +714,6 @@ SWAGGER_UI_PATH=/swagger-ui.html # Set to false to export dataset IDs as plain text for easier cross-environment import DSL_EXPORT_ENCRYPT_DATASET_ID=true -# Suggested Questions After Answer Configuration -# These environment variables allow customization of the suggested questions feature -# -# Custom prompt for generating suggested questions (optional) -# If not set, uses the default prompt that generates 3 questions under 20 characters each -# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. 
Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]" -# SUGGESTED_QUESTIONS_PROMPT= - -# Maximum number of tokens for suggested questions generation (default: 256) -# Adjust this value for longer questions or more questions -# SUGGESTED_QUESTIONS_MAX_TOKENS=256 - -# Temperature for suggested questions generation (default: 0.0) -# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions -# SUGGESTED_QUESTIONS_TEMPERATURE=0 - # Tenant isolated task queue configuration TENANT_ISOLATED_TASK_CONCURRENCY=1 diff --git a/api/.ruff.toml b/api/.ruff.toml index 2a825f1ef0..dd78024a02 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -69,8 +69,6 @@ ignore = [ "FURB152", # math-constant "UP007", # non-pep604-annotation "UP032", # f-string - "UP045", # non-pep604-annotation-optional - "B005", # strip-with-multi-characters "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg @@ -84,7 +82,6 @@ ignore = [ "SIM102", # collapsible-if "SIM103", # needless-bool "SIM105", # suppressible-exception - "SIM107", # return-in-try-except-finally "SIM108", # if-else-block-instead-of-if-exp "SIM113", # enumerate-for-loop "SIM117", # multiple-with-statements @@ -93,32 +90,22 @@ ignore = [ ] [lint.per-file-ignores] -"__init__.py" = [ - "F401", # unused-import - "F811", # redefined-while-unused -] "configs/*" = [ "N802", # invalid-function-name ] -"graphon/model_runtime/callbacks/base_callback.py" = ["T201"] -"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name ] "tests/*" = [ - "F811", # redefined-while-unused "T201", # allow print in tests, "S110", # allow ignoring exceptions in tests code (currently) - ] -"controllers/console/explore/trial.py" = ["TID251"] -"controllers/console/human_input_form.py" = ["TID251"] -"controllers/web/human_input_form.py" = ["TID251"] - -[lint.flake8-tidy-imports] [lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"] msg = "Use Pydantic payload/query models instead of reqparse." [lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"] msg = "Use Pydantic payload/query models instead of reqparse." 
+ +[lint.isort] +known-first-party = ["graphon"] \ No newline at end of file diff --git a/api/.vscode/launch.json.example b/api/.vscode/launch.json.example index 6bdfa2c039..1001559176 100644 --- a/api/.vscode/launch.json.example +++ b/api/.vscode/launch.json.example @@ -3,29 +3,21 @@ "compounds": [ { "name": "Launch Flask and Celery", - "configurations": ["Python: Flask", "Python: Celery"] + "configurations": ["Python: API (gevent)", "Python: Celery"] } ], "configurations": [ { - "name": "Python: Flask", - "consoleName": "Flask", + "name": "Python: API (gevent)", + "consoleName": "API", "type": "debugpy", "request": "launch", "python": "${workspaceFolder}/.venv/bin/python", "cwd": "${workspaceFolder}", "envFile": ".env", - "module": "flask", + "program": "${workspaceFolder}/app.py", "justMyCode": true, - "jinja": true, - "env": { - "FLASK_APP": "app.py", - "GEVENT_SUPPORT": "True" - }, - "args": [ - "run", - "--port=5001" - ] + "jinja": true }, { "name": "Python: Celery", diff --git a/api/Dockerfile b/api/Dockerfile index 7e0a439954..6098652573 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -21,8 +21,9 @@ RUN apt-get update \ # for building gmpy2 libmpfr-dev libmpc-dev -# Install Python dependencies +# Install Python dependencies (workspace members under providers/vdb/) COPY pyproject.toml uv.lock ./ +COPY providers ./providers RUN uv sync --locked --no-dev # production stage diff --git a/api/README.md b/api/README.md index 00562f3f78..a075bc0fa9 100644 --- a/api/README.md +++ b/api/README.md @@ -101,3 +101,11 @@ The scripts resolve paths relative to their location, so you can run them from a uv run ruff format ./ # Format code uv run basedpyright . # Type checking ``` + +## Generate TS stub + +``` +uv run dev/generate_swagger_specs.py --output-dir openapi +``` + +use https://jsontotable.org/openapi-to-typescript to convert to typescript diff --git a/api/app.py b/api/app.py index c018c8a045..e53b037be5 100644 --- a/api/app.py +++ b/api/app.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import sys from typing import TYPE_CHECKING, cast @@ -9,17 +10,35 @@ if TYPE_CHECKING: celery: Celery +HOST = "0.0.0.0" +PORT = 5001 +logger = logging.getLogger(__name__) + + def is_db_command() -> bool: if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db": return True return False +def log_startup_banner(host: str, port: int) -> None: + debugger_attached = sys.gettrace() is not None + logger.info("Serving Dify API via gevent WebSocket server") + logger.info("Bound to http://%s:%s", host, port) + logger.info("Debugger attached: %s", "on" if debugger_attached else "off") + logger.info("Press CTRL+C to quit") + + # create app +flask_app = None +socketio_app = None + if is_db_command(): from app_factory import create_migrations_app app = create_migrations_app() + socketio_app = app + flask_app = app else: # Gunicorn and Celery handle monkey patching automatically in production by # specifying the `gevent` worker class. Manual monkey patching is not required here. 
@@ -30,8 +49,14 @@ else: from app_factory import create_app - app = create_app() + socketio_app, flask_app = create_app() + app = flask_app celery = cast("Celery", app.extensions["celery"]) if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001) + from gevent import pywsgi + from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs] + + log_startup_banner(HOST, PORT) + server = pywsgi.WSGIServer((HOST, PORT), socketio_app, handler_class=WebSocketHandler) + server.serve_forever() diff --git a/api/app_factory.py b/api/app_factory.py index 76838f9925..48e50ceae9 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,6 +1,7 @@ import logging import time +import socketio # type: ignore[reportMissingTypeStubs] from flask import request from opentelemetry.trace import get_current_span from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID @@ -10,6 +11,7 @@ from contexts.wrapper import RecyclableContextVar from controllers.console.error import UnauthorizedAndForceLogout from core.logging.context import init_request_context from dify_app import DifyApp +from extensions.ext_socketio import sio from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import LicenseStatus @@ -122,14 +124,18 @@ def create_flask_app_with_configs() -> DifyApp: return dify_app -def create_app() -> DifyApp: +def create_app() -> tuple[socketio.WSGIApp, DifyApp]: start_time = time.perf_counter() app = create_flask_app_with_configs() initialize_extensions(app) + + sio.app = app + socketio_app = socketio.WSGIApp(sio, app) + end_time = time.perf_counter() if dify_config.DEBUG: logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2)) - return app + return socketio_app, app def initialize_extensions(app: DifyApp): diff --git a/api/commands/account.py b/api/commands/account.py index 84af7a5ae6..0d99ce7a0f 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -2,7 +2,7 @@ import base64 import secrets import click -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from constants.languages import languages from extensions.ext_database import db @@ -25,30 +25,32 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) - return + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. 
Must match {password_pattern}", fg="red")) + return - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + with Session(db.engine) as session: + account = session.merge(account) account.password = base64_password_hashed account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(normalized_email) - click.echo(click.style("Password reset successfully.", fg="green")) + session.commit() + AccountService.reset_login_error_rate_limit(normalized_email) + click.echo(click.style("Password reset successfully.", fg="green")) @click.command("reset-email", help="Reset the account email.") @@ -65,21 +67,23 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + account = AccountService.get_account_by_email_with_case_fallback(email.strip()) - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return - try: - email_validate(normalized_new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return + try: + email_validate(normalized_new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return + with Session(db.engine) as session: + account = session.merge(account) account.email = normalized_new_email - click.echo(click.style("Email updated successfully.", fg="green")) + session.commit() + click.echo(click.style("Email updated successfully.", fg="green")) @click.command("create-tenant", help="Create account and tenant.") @@ -109,8 +113,18 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No # Validates name encoding for non-Latin characters. name = name.strip().encode("utf-8").decode("utf-8") if name else None - # generate random password - new_password = secrets.token_urlsafe(16) + # Generate a random password that satisfies the password policy. + # The iteration limit guards against infinite loops caused by unexpected bugs in valid_password. + for _ in range(100): + new_password = secrets.token_urlsafe(16) + try: + valid_password(new_password) + break + except Exception: + continue + else: + click.echo(click.style("Failed to generate a valid password. 
Please try again.", fg="red")) + return # register account account = RegisterService.register( diff --git a/api/commands/plugin.py b/api/commands/plugin.py index c34391025a..8bd5392d7b 100644 --- a/api/commands/plugin.py +++ b/api/commands/plugin.py @@ -11,7 +11,7 @@ from configs import dify_config from core.helper import encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.plugin import PluginInstaller -from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params +from core.tools.utils.system_encryption import encrypt_system_params from extensions.ext_database import db from models import Tenant from models.oauth import DatasourceOauthParamConfig, DatasourceProvider @@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params): click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) + oauth_client_params = encrypt_system_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) @@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params): click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) + oauth_client_params = encrypt_system_params(client_params_dict) click.echo(click.style("Client params encrypted successfully.", fg="green")) except Exception as e: click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) diff --git a/api/commands/vector.py b/api/commands/vector.py index cb7eb7c452..956f20d6bb 100644 --- a/api/commands/vector.py +++ b/api/commands/vector.py @@ -341,11 +341,10 @@ def add_qdrant_index(field: str): click.echo(click.style("No dataset collection bindings found.", fg="red")) return import qdrant_client + from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import PayloadSchemaType - from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig - for binding in bindings: if dify_config.QDRANT_URL is None: raise ValueError("Qdrant URL is required.") diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d37cff63e9..52e33c1789 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings): ) +class CreatorsPlatformConfig(BaseSettings): + """ + Configuration for Creators Platform integration + """ + + CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field( + description="Enable or disable Creators Platform features", + default=True, + ) + + CREATORS_PLATFORM_API_URL: HttpUrl = Field( + description="Creators Platform API URL", + default=HttpUrl("https://creators.dify.ai"), + ) + + CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field( + description="OAuth client ID for Creators Platform integration", + default="", + ) + + class EndpointConfig(BaseSettings): """ Configuration for various application endpoints and URLs @@ -1274,6 +1295,13 @@ class PositionConfig(BaseSettings): return {item.strip() for item in 
self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} +class CollaborationConfig(BaseSettings): + ENABLE_COLLABORATION_MODE: bool = Field( + description="Whether to enable collaboration mode features across the workspace", + default=False, + ) + + class LoginConfig(BaseSettings): ENABLE_EMAIL_CODE_LOGIN: bool = Field( description="whether to enable email code login", @@ -1372,6 +1400,7 @@ class FeatureConfig( AuthConfig, # Changed from OAuthConfig to AuthConfig BillingConfig, CodeExecutionSandboxConfig, + CreatorsPlatformConfig, TriggerConfig, AsyncWorkflowConfig, PluginConfig, @@ -1399,6 +1428,7 @@ class FeatureConfig( WorkflowConfig, WorkflowNodeExecutionConfig, WorkspaceConfig, + CollaborationConfig, LoginConfig, AccountConfig, SwaggerUIConfig, diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 7140c16f9e..df7009e213 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,5 +1,5 @@ import os -from typing import Any, Literal +from typing import Any, Literal, TypedDict from urllib.parse import parse_qsl, quote_plus from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field @@ -112,6 +112,17 @@ class KeywordStoreConfig(BaseSettings): ) +class SQLAlchemyEngineOptionsDict(TypedDict): + pool_size: int + max_overflow: int + pool_recycle: int + pool_pre_ping: bool + connect_args: dict[str, str] + pool_use_lifo: bool + pool_reset_on_return: None + pool_timeout: int + + class DatabaseConfig(BaseSettings): # Database type selector DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field( @@ -154,6 +165,16 @@ class DatabaseConfig(BaseSettings): default="", ) + DB_SESSION_TIMEZONE_OVERRIDE: str = Field( + description=( + "PostgreSQL session timezone override injected via startup options." + " Default is 'UTC' for out-of-the-box consistency." + " Set to empty string to disable app-level timezone injection, for example when using RDS Proxy" + " together with a database-side default timezone." 
+ ), + default="UTC", + ) + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str: @@ -214,21 +235,22 @@ class DatabaseConfig(BaseSettings): @computed_field # type: ignore[prop-decorator] @property - def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: + def SQLALCHEMY_ENGINE_OPTIONS(self) -> SQLAlchemyEngineOptionsDict: # Parse DB_EXTRAS for 'options' db_extras_dict = dict(parse_qsl(self.DB_EXTRAS)) options = db_extras_dict.get("options", "") - connect_args = {} + connect_args: dict[str, str] = {} # Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"): - timezone_opt = "-c timezone=UTC" - if options: - merged_options = f"{options} {timezone_opt}" - else: - merged_options = timezone_opt - connect_args = {"options": merged_options} + merged_options = options.strip() + session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip() + if session_timezone_override: + timezone_opt = f"-c timezone={session_timezone_override}" + merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt + if merged_options: + connect_args = {"options": merged_options} - return { + result: SQLAlchemyEngineOptionsDict = { "pool_size": self.SQLALCHEMY_POOL_SIZE, "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, @@ -238,6 +260,7 @@ class DatabaseConfig(BaseSettings): "pool_reset_on_return": None, "pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT, } + return result class CeleryConfig(DatabaseConfig): diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index b49275758a..2def0a0d4e 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -32,6 +32,11 @@ class RedisConfig(BaseSettings): default=0, ) + REDIS_KEY_PREFIX: str = Field( + description="Optional global prefix for Redis keys, topics, and transport artifacts", + default="", + ) + REDIS_USE_SSL: bool = Field( description="Enable SSL/TLS for the Redis connection", default=False, diff --git a/api/configs/middleware/vdb/hologres_config.py b/api/configs/middleware/vdb/hologres_config.py index 9812cce268..788b3cfb78 100644 --- a/api/configs/middleware/vdb/hologres_config.py +++ b/api/configs/middleware/vdb/hologres_config.py @@ -1,4 +1,3 @@ -from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType from pydantic import Field from pydantic_settings import BaseSettings @@ -42,17 +41,17 @@ class HologresConfig(BaseSettings): default="public", ) - HOLOGRES_TOKENIZER: TokenizerType = Field( + HOLOGRES_TOKENIZER: str = Field( description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').", default="jieba", ) - HOLOGRES_DISTANCE_METHOD: DistanceType = Field( + HOLOGRES_DISTANCE_METHOD: str = Field( description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').", default="Cosine", ) - HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field( + HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field( description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').", default="rabitq", ) diff --git a/api/configs/middleware/vdb/iris_config.py b/api/configs/middleware/vdb/iris_config.py index c532d191c3..f5993dd8f8 100644 --- a/api/configs/middleware/vdb/iris_config.py +++ b/api/configs/middleware/vdb/iris_config.py @@ -1,5 +1,7 @@ """Configuration for InterSystems IRIS 
vector database.""" +from typing import Any + from pydantic import Field, PositiveInt, model_validator from pydantic_settings import BaseSettings @@ -64,7 +66,7 @@ class IrisVectorConfig(BaseSettings): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]: """Validate IRIS configuration values. Args: diff --git a/api/constants/dsl_version.py b/api/constants/dsl_version.py new file mode 100644 index 0000000000..b0fbe0075c --- /dev/null +++ b/api/constants/dsl_version.py @@ -0,0 +1 @@ +CURRENT_APP_DSL_VERSION = "0.6.0" diff --git a/api/controllers/common/controller_schemas.py b/api/controllers/common/controller_schemas.py index 39e3b5857d..c12d576473 100644 --- a/api/controllers/common/controller_schemas.py +++ b/api/controllers/common/controller_schemas.py @@ -1,4 +1,5 @@ from typing import Any, Literal +from uuid import UUID from pydantic import BaseModel, Field, model_validator @@ -23,9 +24,9 @@ class ConversationRenamePayload(BaseModel): class MessageListQuery(BaseModel): - conversation_id: UUIDStrOrEmpty - first_id: UUIDStrOrEmpty | None = None - limit: int = Field(default=20, ge=1, le=100) + conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID") + first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)") class MessageFeedbackPayload(BaseModel): @@ -69,11 +70,35 @@ class WorkflowUpdatePayload(BaseModel): marked_comment: str | None = Field(default=None, max_length=100) +# --- Dataset schemas --- + + +DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 + + +class ChildChunkCreatePayload(BaseModel): + content: str + + +class ChildChunkUpdatePayload(BaseModel): + content: str + + +class DocumentBatchDownloadZipPayload(BaseModel): + """Request payload for bulk downloading documents as a zip archive.""" + + document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS) + + +class MetadataUpdatePayload(BaseModel): + name: str + + # --- Audio schemas --- class TextToAudioPayload(BaseModel): - message_id: str | None = None - voice: str | None = None - text: str | None = None - streaming: bool | None = None + message_id: str | None = Field(default=None, description="Message ID") + voice: str | None = Field(default=None, description="Voice to use for TTS") + text: str | None = Field(default=None, description="Text to convert to audio") + streaming: bool | None = Field(default=None, description="Enable streaming response") diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 4fe3fc9062..8e665c1386 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -2,9 +2,9 @@ from __future__ import annotations from typing import Any -from graphon.file import helpers as file_helpers from pydantic import BaseModel, ConfigDict, computed_field +from graphon.file import helpers as file_helpers from models.model import IconType type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index ef89e66980..84903733b5 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -41,7 +41,8 @@ def guess_file_info_from_response(response: httpx.Response): # Try to extract filename from URL parsed_url = urllib.parse.urlparse(url) 
url_path = parsed_url.path - filename = os.path.basename(url_path) + # Decode percent-encoded characters in the path segment + filename = urllib.parse.unquote(os.path.basename(url_path)) # If filename couldn't be extracted, use Content-Disposition header if not filename: diff --git a/api/controllers/common/human_input.py b/api/controllers/common/human_input.py new file mode 100644 index 0000000000..5d6f4efb95 --- /dev/null +++ b/api/controllers/common/human_input.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, JsonValue + + +class HumanInputFormSubmitPayload(BaseModel): + inputs: dict[str, JsonValue] + action: str diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index d624b10b22..980e828945 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -65,6 +65,7 @@ from .app import ( statistic, workflow, workflow_app_log, + workflow_comment, workflow_draft_variable, workflow_run, workflow_statistic, @@ -116,6 +117,7 @@ from .explore import ( saved_message, trial, ) +from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport] # Import tag controllers from .tag import tags @@ -201,6 +203,7 @@ __all__ = [ "saved_message", "setup", "site", + "socketio_workflow", "spec", "statistic", "tags", @@ -211,6 +214,7 @@ __all__ = [ "website", "workflow", "workflow_app_log", + "workflow_comment", "workflow_draft_variable", "workflow_run", "workflow_statistic", diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 772bb9d0f1..b03d9b4a4c 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,12 +1,16 @@ +from datetime import datetime + import flask_restx -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from flask_restx._http import HTTPStatus +from pydantic import field_validator from sqlalchemy import delete, func, select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from extensions.ext_database import db -from libs.helper import TimestampField +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.dataset import Dataset from models.enums import ApiTokenType @@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache from . 
import console_ns from .wraps import account_initialization_required, edit_permission_required, setup_required -api_key_fields = { - "id": fields.String, - "type": fields.String, - "token": fields.String, - "last_used_at": TimestampField, - "created_at": TimestampField, -} -api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} -api_key_list_model = console_ns.model( - "ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")} -) +class ApiKeyItem(ResponseModel): + id: str + type: str + token: str + last_used_at: int | None = None + created_at: int | None = None + + @field_validator("last_used_at", "created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class ApiKeyList(ResponseModel): + data: list[ApiKeyItem] + + +register_schema_models(console_ns, ApiKeyItem, ApiKeyList) def _get_resource(resource_id, tenant_id, resource_model): @@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource): token_prefix: str | None = None max_keys = 10 - @marshal_with(api_key_list_model) def get(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" resource_id = str(resource_id) @@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource): ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id ) ).all() - return {"items": keys} + return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json") - @marshal_with(api_key_item_model) @edit_permission_required def post(self, resource_id): assert self.resource_id_field is not None, "resource_id_field must be set" @@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource): api_token.type = self.resource_type db.session.add(api_token) db.session.commit() - return api_token, 201 + return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201 class BaseApiKeyResource(Resource): @@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("get_app_api_keys") @console_ns.doc(description="Get all API keys for an app") @console_ns.doc(params={"resource_id": "App ID"}) - @console_ns.response(200, "Success", api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) def get(self, resource_id): # type: ignore """Get all API keys for an app""" return super().get(resource_id) @@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("create_app_api_key") @console_ns.doc(description="Create a new API key for an app") @console_ns.doc(params={"resource_id": "App ID"}) - @console_ns.response(201, "API key created successfully", api_key_item_model) + @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") def post(self, resource_id): # type: ignore """Create a new API key for an app""" @@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get all API keys for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) - @console_ns.response(200, "Success", 
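`ApiKeyItem` above replaces flask-restx's `TimestampField` with a `mode="before"` validator that turns ORM datetimes into epoch seconds. The same pattern in isolation, with a plain `BaseModel` standing in for the project's `ResponseModel` base class:

```python
from datetime import datetime, timezone
from pydantic import BaseModel, field_validator

class ApiKeyItemSketch(BaseModel):
    id: str
    token: str
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # ORM rows carry datetimes; the API contract wants epoch seconds.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value

row = {"id": "key-1", "token": "app-abc", "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc)}
assert ApiKeyItemSketch.model_validate(row).created_at == 1704067200
```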
api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) def get(self, resource_id): # type: ignore """Get all API keys for a dataset""" return super().get(resource_id) @@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource): @console_ns.doc("create_dataset_api_key") @console_ns.doc(description="Create a new API key for a dataset") @console_ns.doc(params={"resource_id": "Dataset ID"}) - @console_ns.response(201, "API key created successfully", api_key_item_model) + @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__]) @console_ns.response(400, "Maximum keys exceeded") def post(self, resource_id): # type: ignore """Create a new API key for a dataset""" diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index 3bd61feb44..ed66da1be5 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required -from services.advanced_prompt_template_service import AdvancedPromptTemplateService +from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService class AdvancedPromptTemplateQuery(BaseModel): @@ -35,5 +35,10 @@ class AdvancedPromptTemplateList(Resource): @account_initialization_required def get(self): args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - - return AdvancedPromptTemplateService.get_prompt(args.model_dump()) + prompt_args: AdvancedPromptTemplateArgs = { + "app_mode": args.app_mode, + "model_mode": args.model_mode, + "model_name": args.model_name, + "has_context": args.has_context, + } + return AdvancedPromptTemplateService.get_prompt(prompt_args) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 9931bb5dd7..528785931e 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -25,7 +25,13 @@ from fields.annotation_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from services.annotation_service import AppAnnotationService +from services.annotation_service import ( + AppAnnotationService, + EnableAnnotationArgs, + UpdateAnnotationArgs, + UpdateAnnotationSettingArgs, + UpsertAnnotationArgs, +) DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource): args = AnnotationReplyPayload.model_validate(console_ns.payload) match action: case "enable": - result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id) + enable_args: EnableAnnotationArgs = { + "score_threshold": args.score_threshold, + "embedding_provider_name": args.embedding_provider_name, + "embedding_model_name": args.embedding_model_name, + } + result = AppAnnotationService.enable_app_annotation(enable_args, app_id) case "disable": result = AppAnnotationService.disable_app_annotation(app_id) return result, 200 @@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource): args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) - result = AppAnnotationService.update_app_annotation_setting(app_id, 
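`AdvancedPromptTemplateList` above now builds a typed args mapping key by key instead of passing `args.model_dump()`, so the keys the service expects are checked statically. A sketch of the idea with a hypothetical two-key TypedDict (the real `AdvancedPromptTemplateArgs` is defined in the service module and not shown in this diff):

```python
from typing import TypedDict
from pydantic import BaseModel

class PromptArgs(TypedDict):
    # Hypothetical stand-in for AdvancedPromptTemplateArgs.
    app_mode: str
    model_mode: str

class PromptQuery(BaseModel):
    app_mode: str
    model_mode: str

def get_prompt(args: PromptArgs) -> str:
    return f"{args['app_mode']}/{args['model_mode']}"

query = PromptQuery(app_mode="chat", model_mode="completion")
# Building the TypedDict explicitly (vs. model_dump()) lets the type checker
# verify every key the service expects is present and correctly typed.
args: PromptArgs = {"app_mode": query.app_mode, "model_mode": query.model_mode}
assert get_prompt(args) == "chat/completion"
```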
annotation_setting_id, args.model_dump()) + setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold} + result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args) return result, 200 @@ -237,8 +249,16 @@ class AnnotationApi(Resource): def post(self, app_id): app_id = str(app_id) args = CreateAnnotationPayload.model_validate(console_ns.payload) - data = args.model_dump(exclude_none=True) - annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id) + upsert_args: UpsertAnnotationArgs = {} + if args.answer is not None: + upsert_args["answer"] = args.answer + if args.content is not None: + upsert_args["content"] = args.content + if args.message_id is not None: + upsert_args["message_id"] = args.message_id + if args.question is not None: + upsert_args["question"] = args.question + annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) args = UpdateAnnotationPayload.model_validate(console_ns.payload) - annotation = AppAnnotationService.update_app_annotation_directly( - args.model_dump(exclude_none=True), app_id, annotation_id - ) + update_args: UpdateAnnotationArgs = {} + if args.answer is not None: + update_args["answer"] = args.answer + if args.question is not None: + update_args["question"] = args.question + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2018f60215..c8334bfd18 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,15 +1,15 @@ import logging +import re import uuid from datetime import datetime from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.enums import WorkflowExecutionStatus -from graphon.file import helpers as file_helpers from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session +from werkzeug.datastructures import MultiDict from werkzeug.exceptions import BadRequest from controllers.common.helpers import FileInfo @@ -31,13 +31,15 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES from extensions.ext_database import db from fields.base import ResponseModel +from graphon.enums import WorkflowExecutionStatus +from libs.helper import build_icon_url from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow from models.model import IconType from services.app_dsl_service import AppDslService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService -from services.entities.dsl_entities import ImportMode +from services.entities.dsl_entities import ImportMode, ImportStatus from services.entities.knowledge_entities.knowledge_entities import ( DataSource, InfoList, @@ -57,6 +59,7 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co 
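The annotation endpoints above replace `model_dump(exclude_none=True)` with conditional key insertion into a TypedDict, keeping absent fields out of the mapping while staying type-checkable. A standalone sketch, assuming the real `UpdateAnnotationArgs` is declared `total=False` (its definition is not shown here):

```python
from typing import TypedDict

class UpdateAnnotationArgs(TypedDict, total=False):
    # total=False makes every key optional, so an empty dict is valid —
    # the statically-checked equivalent of model_dump(exclude_none=True).
    answer: str
    question: str

def build_update_args(answer: str | None, question: str | None) -> UpdateAnnotationArgs:
    args: UpdateAnnotationArgs = {}
    if answer is not None:
        args["answer"] = answer
    if question is not None:
        args["question"] = question
    return args

assert build_update_args("yes", None) == {"answer": "yes"}
assert build_update_args(None, None) == {}
```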
register_enum_models(console_ns, IconType) _logger = logging.getLogger(__name__) +_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$") class AppListQuery(BaseModel): @@ -66,22 +69,19 @@ class AppListQuery(BaseModel): default="all", description="App mode filter" ) name: str | None = Field(default=None, description="Filter by app name") - tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs") + tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs") is_created_by_me: bool | None = Field(default=None, description="Filter by creator") @field_validator("tag_ids", mode="before") @classmethod - def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None: + def validate_tag_ids(cls, value: list[str] | None) -> list[str] | None: if not value: return None - if isinstance(value, str): - items = [item.strip() for item in value.split(",") if item.strip()] - elif isinstance(value, list): - items = [str(item).strip() for item in value if item and str(item).strip()] - else: - raise TypeError("Unsupported tag_ids type.") + if not isinstance(value, list): + raise ValueError("Unsupported tag_ids type.") + items = [str(item).strip() for item in value if item and str(item).strip()] if not items: return None @@ -91,6 +91,26 @@ class AppListQuery(BaseModel): raise ValueError("Invalid UUID format in tag_ids.") from exc +def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]: + normalized: dict[str, str | list[str]] = {} + indexed_tag_ids: list[tuple[int, str]] = [] + + for key in query_args: + match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key) + if match: + indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key)) + continue + + value = query_args.get(key) + if value is not None: + normalized[key] = value + + if indexed_tag_ids: + normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)] + + return normalized + + class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) @@ -129,6 +149,7 @@ class AppNamePayload(BaseModel): class AppIconPayload(BaseModel): icon: str | None = Field(default=None, description="Icon data") + icon_type: IconType | None = Field(default=None, description="Icon type") icon_background: str | None = Field(default=None, description="Icon background color") @@ -161,15 +182,6 @@ def _to_timestamp(value: datetime | int | None) -> int | None: return value -def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None: - if icon is None or icon_type is None: - return None - icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type) - if icon_type_value.lower() != IconType.IMAGE: - return None - return file_helpers.get_signed_file_url(icon) - - class Tag(ResponseModel): id: str name: str @@ -292,7 +304,7 @@ class Site(ResponseModel): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, self.icon) @field_validator("icon_type", mode="before") @classmethod @@ -342,7 +354,7 @@ class AppPartial(ResponseModel): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, 
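`_normalize_app_list_query_args` above collects `tag_ids[0]`, `tag_ids[1]`, ... style bracketed query keys from the werkzeug `MultiDict` into one index-ordered list. The same logic as a runnable sketch:

```python
import re
from werkzeug.datastructures import MultiDict

_BRACKET = re.compile(r"^tag_ids\[(\d+)\]$")

def normalize_args(query: MultiDict) -> dict[str, str | list[str]]:
    # Gather indexed tag_ids[N] keys into one list, sorted by N;
    # every other key passes through as its first (flat) value.
    out: dict[str, str | list[str]] = {}
    indexed: list[tuple[int, str]] = []
    for key in query:
        m = _BRACKET.fullmatch(key)
        if m:
            indexed.extend((int(m.group(1)), v) for v in query.getlist(key))
            continue
        value = query.get(key)
        if value is not None:
            out[key] = value
    if indexed:
        out["tag_ids"] = [v for _, v in sorted(indexed)]
    return out

args = MultiDict([("page", "1"), ("tag_ids[1]", "b"), ("tag_ids[0]", "a")])
assert normalize_args(args) == {"page": "1", "tag_ids": ["a", "b"]}
```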
self.icon) @field_validator("created_at", "updated_at", mode="before") @classmethod @@ -390,7 +402,7 @@ class AppDetailWithSite(AppDetail): @computed_field(return_type=str | None) # type: ignore @property def icon_url(self) -> str | None: - return _build_icon_url(self.icon_type, self.icon) + return build_icon_url(self.icon_type, self.icon) class AppPagination(ResponseModel): @@ -463,7 +475,7 @@ class AppListApi(Resource): """Get app list""" current_user, current_tenant_id = current_account_with_tenant() - args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args)) args_dict = args.model_dump() # get app list @@ -632,7 +644,7 @@ class AppCopyApi(Resource): args = CopyAppPayload.model_validate(console_ns.payload or {}) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) result = import_service.import_app( @@ -645,6 +657,13 @@ class AppCopyApi(Resource): icon=args.icon, icon_background=args.icon_background, ) + if result.status == ImportStatus.FAILED: + session.rollback() + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + session.rollback() + return result.model_dump(mode="json"), 202 + session.commit() # Inherit web app permission from original app if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: @@ -693,6 +712,32 @@ class AppExportApi(Resource): return payload.model_dump(mode="json") +@console_ns.route("/apps//publish-to-creators-platform") +class AppPublishToCreatorsPlatformApi(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=None) + @edit_permission_required + def post(self, app_model): + """Publish app to Creators Platform""" + from configs import dify_config + from core.helper.creators import get_redirect_url, upload_dsl + + if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: + return {"error": "Creators Platform features are not enabled"}, 403 + + current_user, _ = current_account_with_tenant() + + dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False) + dsl_bytes = dsl_content.encode("utf-8") + + claim_code = upload_dsl(dsl_bytes) + redirect_url = get_redirect_url(str(current_user.id), claim_code) + + return {"redirect_url": redirect_url} + + @console_ns.route("/apps//name") class AppNameApi(Resource): @console_ns.doc("check_app_name") @@ -731,7 +776,12 @@ class AppIconApi(Resource): args = AppIconPayload.model_validate(console_ns.payload or {}) app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") + app_model = app_service.update_app_icon( + app_model, + args.icon or "", + args.icon_background or "", + args.icon_type, + ) response_model = AppDetail.model_validate(app_model, from_attributes=True) return response_model.model_dump(mode="json") diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 12d6951a48..e91dc9cfe5 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,7 +1,8 @@ -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker +from 
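`AppCopyApi` above swaps `sessionmaker(...).begin()` for a plain `Session` so the endpoint can decide, based on the import status, whether to commit or roll back. A toy sketch of that discipline against in-memory SQLite; `import_app` and `should_fail` are invented for illustration only:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE apps (id INTEGER PRIMARY KEY, name TEXT)"))

def import_app(session: Session, name: str, should_fail: bool) -> str:
    session.execute(text("INSERT INTO apps (name) VALUES (:n)"), {"n": name})
    return "failed" if should_fail else "completed"

# A plain Session instead of sessionmaker(...).begin(): the caller inspects
# the result and decides whether the work is kept or discarded.
with Session(engine, expire_on_commit=False) as session:
    status = import_app(session, "copy-of-app", should_fail=True)
    if status == "failed":
        session.rollback()
    else:
        session.commit()

with Session(engine) as session:
    count = session.execute(text("SELECT COUNT(*) FROM apps")).scalar_one()
assert count == 0  # the failed import left nothing behind
```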
sqlalchemy.orm import Session +from controllers.common.schema import register_schema_models from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( account_initialization_required, @@ -10,35 +11,15 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.app_fields import ( - app_import_check_dependencies_fields, - app_import_fields, - leaked_dependency_fields, -) from libs.login import current_account_with_tenant, login_required from models.model import App -from services.app_dsl_service import AppDslService +from services.app_dsl_service import AppDslService, Import from services.enterprise.enterprise_service import EnterpriseService -from services.entities.dsl_entities import ImportStatus +from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus from services.feature_service import FeatureService from .. import console_ns -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base model first -leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields) - -app_import_model = console_ns.model("AppImport", app_import_fields) - -# For nested models, need to replace nested dict with registered model -app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy() -app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model)) -app_import_check_dependencies_model = console_ns.model( - "AppImportCheckDependencies", app_import_check_dependencies_fields_copy -) - -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AppImportPayload(BaseModel): mode: str = Field(..., description="Import mode") @@ -52,18 +33,18 @@ class AppImportPayload(BaseModel): app_id: str | None = Field(None) -console_ns.schema_model( - AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult) @console_ns.route("/apps/imports") class AppImportApi(Resource): @console_ns.expect(console_ns.models[AppImportPayload.__name__]) + @console_ns.response(200, "Import completed", console_ns.models[Import.__name__]) + @console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__]) + @console_ns.response(400, "Import failed", console_ns.models[Import.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(app_import_model) @cloud_edition_billing_resource_check("apps") @edit_permission_required def post(self): @@ -71,8 +52,9 @@ class AppImportApi(Resource): current_user, _ = current_account_with_tenant() args = AppImportPayload.model_validate(console_ns.payload) - # Create service with session - with sessionmaker(db.engine).begin() as session: + # AppDslService performs internal commits for some creation paths, so use a plain + # Session here instead of nesting it inside sessionmaker(...).begin(). 
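`register_schema_models` is imported above but its body is not part of this diff. A plausible sketch, consistent with the inline `schema_model(...)` calls that appear elsewhere in the diff, would be:

```python
from flask_restx import Namespace
from pydantic import BaseModel

REF_TEMPLATE = "#/definitions/{model}"  # the Swagger 2.0 ref style used in this diff

def register_schema_models(ns: Namespace, *models: type[BaseModel]) -> None:
    # Assumed helper shape: expose each Pydantic model's JSON schema under
    # its class name so console_ns.models[Model.__name__] resolves later.
    for model in models:
        ns.schema_model(model.__name__, model.model_json_schema(ref_template=REF_TEMPLATE))

class Import(BaseModel):
    id: str
    status: str

ns = Namespace("console")
register_schema_models(ns, Import)
assert "Import" in ns.models
```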
+ with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Import app account = current_user @@ -88,6 +70,10 @@ class AppImportApi(Resource): icon_background=args.icon_background, app_id=args.app_id, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") @@ -104,21 +90,25 @@ class AppImportApi(Resource): @console_ns.route("/apps/imports//confirm") class AppImportConfirmApi(Resource): + @console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__]) + @console_ns.response(400, "Import failed", console_ns.models[Import.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(app_import_model) @edit_permission_required def post(self, import_id): # Check user role first current_user, _ = current_account_with_tenant() - # Create service with session - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -128,14 +118,14 @@ class AppImportConfirmApi(Resource): @console_ns.route("/apps/imports//check-dependencies") class AppImportCheckDependenciesApi(Resource): + @console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__]) @setup_required @login_required @get_app_model @account_initialization_required - @marshal_with(app_import_check_dependencies_model) @edit_permission_required def get(self, app_model: App): - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) result = import_service.check_dependencies(app_model=app_model) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 78ddb904e1..91fbe4a85a 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -2,7 +2,6 @@ import logging from flask import request from flask_restx import Resource, fields -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -23,6 +22,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index d83925d173..fe274e4c9a 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -3,7 +3,6 @@ from typing import Any, Literal from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import 
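The import endpoints above map `ImportStatus` onto HTTP codes: 400 on failure, 202 while pending confirmation, 200 otherwise. A sketch of that mapping; only `PENDING` and `FAILED` appear in this diff, so the `COMPLETED` member and all string values below are assumptions:

```python
from enum import StrEnum

class ImportStatus(StrEnum):
    # Assumed values; the real enum lives in services.entities.dsl_entities.
    COMPLETED = "completed"
    PENDING = "pending"
    FAILED = "failed"

def http_status_for(status: ImportStatus) -> int:
    # 400 for failures, 202 while the import awaits confirmation, 200 otherwise,
    # matching the branches in AppImportApi / AppCopyApi above.
    if status == ImportStatus.FAILED:
        return 400
    if status == ImportStatus.PENDING:
        return 202
    return 200

assert [http_status_for(s) for s in ImportStatus] == [200, 202, 400]
```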
InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -27,6 +26,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index d329d22309..b2b1049f0c 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -2,20 +2,37 @@ from typing import Literal import sqlalchemy as sa from flask import abort, request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, or_ from sqlalchemy.orm import selectinload from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db -from fields.raws import FilesContainedField +from fields.conversation_fields import ( + Conversation as ConversationResponse, +) +from fields.conversation_fields import ( + ConversationDetail as ConversationDetailResponse, +) +from fields.conversation_fields import ( + ConversationMessageDetail as ConversationMessageDetailResponse, +) +from fields.conversation_fields import ( + ConversationPagination as ConversationPaginationResponse, +) +from fields.conversation_fields import ( + ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse, +) +from fields.conversation_fields import ( + ResultResponse, +) from libs.datetime_utils import naive_utc_now, parse_time_range -from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models import Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode @@ -62,267 +79,16 @@ console_ns.schema_model( ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), ) -# Register models for flask_restx to avoid dict type issues in Swagger -# Register in dependency order: base models first, then dependent models - -# Base models -simple_account_model = console_ns.model( - "SimpleAccount", - { - "id": fields.String, - "name": fields.String, - "email": fields.String, - }, -) - -feedback_stat_model = console_ns.model( - "FeedbackStat", - { - "like": fields.Integer, - "dislike": fields.Integer, - }, -) - -status_count_model = console_ns.model( - "StatusCount", - { - "success": fields.Integer, - "failed": fields.Integer, - "partial_success": fields.Integer, - "paused": fields.Integer, - }, -) - -message_file_model = console_ns.model( - "MessageFile", - { - "id": fields.String, - "filename": fields.String, - "type": fields.String, - "url": fields.String, - "mime_type": fields.String, - "size": fields.Integer, - "transfer_method": fields.String, - "belongs_to": fields.String(default="user"), - "upload_file_id": fields.String(default=None), - }, -) - -agent_thought_model = console_ns.model( - "AgentThought", - { - "id": fields.String, - "chain_id": 
fields.String, - "message_id": fields.String, - "position": fields.Integer, - "thought": fields.String, - "tool": fields.String, - "tool_labels": fields.Raw, - "tool_input": fields.String, - "created_at": TimestampField, - "observation": fields.String, - "files": fields.List(fields.String), - }, -) - -simple_model_config_model = console_ns.model( - "SimpleModelConfig", - { - "model": fields.Raw(attribute="model_dict"), - "pre_prompt": fields.String, - }, -) - -model_config_model = console_ns.model( - "ModelConfig", - { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "model": fields.Raw, - "user_input_form": fields.Raw, - "pre_prompt": fields.String, - "agent_mode": fields.Raw, - }, -) - -# Models that depend on simple_account_model -feedback_model = console_ns.model( - "Feedback", - { - "rating": fields.String, - "content": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account": fields.Nested(simple_account_model, allow_null=True), - }, -) - -annotation_model = console_ns.model( - "Annotation", - { - "id": fields.String, - "question": fields.String, - "content": fields.String, - "account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - -annotation_hit_history_model = console_ns.model( - "AnnotationHitHistory", - { - "annotation_id": fields.String(attribute="id"), - "annotation_create_account": fields.Nested(simple_account_model, allow_null=True), - "created_at": TimestampField, - }, -) - - -class MessageTextField(fields.Raw): - def format(self, value): - return value[0]["text"] if value else "" - - -# Simple message detail model -simple_message_detail_model = console_ns.model( - "SimpleMessageDetail", - { - "inputs": FilesContainedField, - "query": fields.String, - "message": MessageTextField, - "answer": fields.String, - }, -) - -# Message detail model that depends on multiple models -message_detail_model = console_ns.model( - "MessageDetail", - { - "id": fields.String, - "conversation_id": fields.String, - "inputs": FilesContainedField, - "query": fields.String, - "message": fields.Raw, - "message_tokens": fields.Integer, - "answer": fields.String(attribute="re_sign_file_url_answer"), - "answer_tokens": fields.Integer, - "provider_response_latency": fields.Float, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "feedbacks": fields.List(fields.Nested(feedback_model)), - "workflow_run_id": fields.String, - "annotation": fields.Nested(annotation_model, allow_null=True), - "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True), - "created_at": TimestampField, - "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), - "message_files": fields.List(fields.Nested(message_file_model)), - "metadata": fields.Raw(attribute="message_metadata_dict"), - "status": fields.String, - "error": fields.String, - "parent_message_id": fields.String, - }, -) - -# Conversation models -conversation_fields_model = console_ns.model( - "Conversation", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String(), - "from_account_id": fields.String, - "from_account_name": fields.String, - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotation": fields.Nested(annotation_model, allow_null=True), - "model_config": 
fields.Nested(simple_model_config_model), - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - "message": fields.Nested(simple_message_detail_model, attribute="first_message"), - }, -) - -conversation_pagination_model = console_ns.model( - "ConversationPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_fields_model), attribute="items"), - }, -) - -conversation_message_detail_model = console_ns.model( - "ConversationMessageDetail", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "model_config": fields.Nested(model_config_model), - "message": fields.Nested(message_detail_model, attribute="first_message"), - }, -) - -conversation_with_summary_model = console_ns.model( - "ConversationWithSummary", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_end_user_session_id": fields.String, - "from_account_id": fields.String, - "from_account_name": fields.String, - "name": fields.String, - "summary": fields.String(attribute="summary_or_query"), - "read_at": TimestampField, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "model_config": fields.Nested(simple_model_config_model), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - "status_count": fields.Nested(status_count_model), - }, -) - -conversation_with_summary_pagination_model = console_ns.model( - "ConversationWithSummaryPagination", - { - "page": fields.Integer, - "limit": fields.Integer(attribute="per_page"), - "total": fields.Integer, - "has_more": fields.Boolean(attribute="has_next"), - "data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"), - }, -) - -conversation_detail_model = console_ns.model( - "ConversationDetail", - { - "id": fields.String, - "status": fields.String, - "from_source": fields.String, - "from_end_user_id": fields.String, - "from_account_id": fields.String, - "created_at": TimestampField, - "updated_at": TimestampField, - "annotated": fields.Boolean, - "introduction": fields.String, - "model_config": fields.Nested(model_config_model), - "message_count": fields.Integer, - "user_feedback_stats": fields.Nested(feedback_stat_model), - "admin_feedback_stats": fields.Nested(feedback_stat_model), - }, +register_schema_models( + console_ns, + CompletionConversationQuery, + ChatConversationQuery, + ConversationResponse, + ConversationPaginationResponse, + ConversationMessageDetailResponse, + ConversationWithSummaryPaginationResponse, + ConversationDetailResponse, + ResultResponse, ) @@ -332,13 +98,12 @@ class CompletionConversationApi(Resource): @console_ns.doc(description="Get completion conversations with pagination and filtering") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[CompletionConversationQuery.__name__]) - @console_ns.response(200, "Success", conversation_pagination_model) + @console_ns.response(200, "Success", console_ns.models[ConversationPaginationResponse.__name__]) @console_ns.response(403, "Insufficient permissions") 
@setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - @marshal_with(conversation_pagination_model) @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() @@ -394,7 +159,9 @@ class CompletionConversationApi(Resource): conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) - return conversations + return ConversationPaginationResponse.model_validate(conversations, from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps//completion-conversations/") @@ -402,19 +169,19 @@ class CompletionConversationDetailApi(Resource): @console_ns.doc("get_completion_conversation") @console_ns.doc(description="Get completion conversation details with messages") @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) - @console_ns.response(200, "Success", conversation_message_detail_model) + @console_ns.response(200, "Success", console_ns.models[ConversationMessageDetailResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) - @marshal_with(conversation_message_detail_model) @edit_permission_required def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - - return _get_conversation(app_model, conversation_id) + return ConversationMessageDetailResponse.model_validate( + _get_conversation(app_model, conversation_id), from_attributes=True + ).model_dump(mode="json") @console_ns.doc("delete_completion_conversation") @console_ns.doc(description="Delete a completion conversation") @@ -436,7 +203,7 @@ class CompletionConversationDetailApi(Resource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 @console_ns.route("/apps//chat-conversations") @@ -445,13 +212,12 @@ class ChatConversationApi(Resource): @console_ns.doc(description="Get chat conversations with pagination, filtering and summary") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ChatConversationQuery.__name__]) - @console_ns.response(200, "Success", conversation_with_summary_pagination_model) + @console_ns.response(200, "Success", console_ns.models[ConversationWithSummaryPaginationResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - @marshal_with(conversation_with_summary_pagination_model) @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() @@ -546,7 +312,9 @@ class ChatConversationApi(Resource): conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) - return conversations + return ConversationWithSummaryPaginationResponse.model_validate(conversations, from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps//chat-conversations/") @@ -554,19 +322,19 @@ class ChatConversationDetailApi(Resource): @console_ns.doc("get_chat_conversation") @console_ns.doc(description="Get chat conversation details") @console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) - @console_ns.response(200, "Success", 
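The conversation list endpoints above validate the `db.paginate` result straight into a Pydantic response model instead of `marshal_with`. The removed marshal fields mapped `per_page` to `limit` and `has_next` to `has_more`, which suggests the replacement models use validation aliases; a hypothetical sketch of that shape (the real `ConversationPagination` lives in `fields/conversation_fields.py` and is not shown here):

```python
from pydantic import BaseModel, Field

class PaginationSketch(BaseModel):
    # Hypothetical shape inferred from the removed marshal attributes.
    page: int
    limit: int = Field(validation_alias="per_page")
    total: int
    has_more: bool = Field(validation_alias="has_next")

row = {"page": 1, "per_page": 20, "total": 3, "has_next": False}
resp = PaginationSketch.model_validate(row)
# model_dump emits the public field names, not the source attribute names.
assert resp.model_dump() == {"page": 1, "limit": 20, "total": 3, "has_more": False}
```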
conversation_detail_model) + @console_ns.response(200, "Success", console_ns.models[ConversationDetailResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Conversation not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) - @marshal_with(conversation_detail_model) @edit_permission_required def get(self, app_model, conversation_id): conversation_id = str(conversation_id) - - return _get_conversation(app_model, conversation_id) + return ConversationDetailResponse.model_validate( + _get_conversation(app_model, conversation_id), from_attributes=True + ).model_dump(mode="json") @console_ns.doc("delete_chat_conversation") @console_ns.doc(description="Delete a chat conversation") @@ -588,7 +356,7 @@ class ChatConversationDetailApi(Resource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 204 + return ResultResponse(result="success").model_dump(mode="json"), 204 def _get_conversation(app_model, conversation_id): diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 369c26a80c..9c8b095b9f 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -1,44 +1,86 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from fields.conversation_variable_fields import ( - conversation_variable_fields, - paginated_conversation_variable_fields, -) +from fields._value_type_serializer import serialize_value_type +from fields.base import ResponseModel from libs.login import login_required from models import ConversationVariable from models.model import AppMode -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ConversationVariablesQuery(BaseModel): conversation_id: str = Field(..., description="Conversation ID to filter variables") -console_ns.schema_model( - ConversationVariablesQuery.__name__, - ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -# Register models for flask_restx to avoid dict type issues in Swagger -# Register base model first -conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields) -# For nested models, need to replace nested dict with registered model -paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy() -paginated_conversation_variable_fields_copy["data"] = fields.List( - fields.Nested(conversation_variable_model), attribute="data" -) -paginated_conversation_variable_model = console_ns.model( - "PaginatedConversationVariable", 
paginated_conversation_variable_fields_copy +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def _normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type()) + if isinstance(value, str): + return value + try: + return serialize_value_type(value) + except Exception: + return serialize_value_type({"value_type": value}) + + @field_validator("value", mode="before") + @classmethod + def _normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class PaginatedConversationVariableResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[ConversationVariableResponse] + + +register_schema_models( + console_ns, + ConversationVariablesQuery, + ConversationVariableResponse, + PaginatedConversationVariableResponse, ) @@ -48,12 +90,15 @@ class ConversationVariablesApi(Resource): @console_ns.doc(description="Get conversation variables for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__]) - @console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model) + @console_ns.response( + 200, + "Conversation variables retrieved successfully", + console_ns.models[PaginatedConversationVariableResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.ADVANCED_CHAT) - @marshal_with(paginated_conversation_variable_model) def get(self, app_model): args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -72,17 +117,22 @@ class ConversationVariablesApi(Resource): with sessionmaker(db.engine, expire_on_commit=False).begin() as session: rows = session.scalars(stmt).all() - return { - "page": page, - "limit": page_size, - "total": len(rows), - "has_more": False, - "data": [ - { - "created_at": row.created_at, - "updated_at": row.updated_at, - **row.to_variable().model_dump(), - } - for row in rows - ], - } + response = PaginatedConversationVariableResponse.model_validate( + { + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ + ConversationVariableResponse.model_validate( + { + "created_at": row.created_at, + "updated_at": row.updated_at, + **row.to_variable().model_dump(), + } + ) + for row in rows + ], + } + ) + return response.model_dump(mode="json") diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 7101d5df7b..c720a5e074 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from controllers.console import console_ns @@ -20,6 +19,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP 
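`ConversationVariableResponse._normalize_value_type` above accepts either a plain string or an enum-like object exposing `exposed_type()`. A standalone sketch of such a tolerant before-validator, with a hypothetical `SegmentType` standing in for the real variable-type enum:

```python
from enum import Enum
from typing import Any
from pydantic import BaseModel, field_validator

class SegmentType(Enum):
    # Hypothetical stand-in; the real enum exposes a user-facing
    # name through an exposed_type() method.
    STRING = "string"

    def exposed_type(self) -> "SegmentType":
        return self

class VariableSketch(BaseModel):
    value_type: str

    @field_validator("value_type", mode="before")
    @classmethod
    def _normalize(cls, value: Any) -> str:
        # Accept plain strings, or enum-like objects carrying exposed_type().
        exposed = getattr(value, "exposed_type", None)
        if callable(exposed):
            value = exposed()
        if isinstance(value, Enum):
            return str(value.value)
        return str(value)

assert VariableSketch.model_validate({"value_type": SegmentType.STRING}).value_type == "string"
assert VariableSketch.model_validate({"value_type": "number"}).value_type == "number"
```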
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeError from libs.login import current_account_with_tenant, login_required from models import App from services.workflow_service import WorkflowService diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 412fc8795a..d517f695b8 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,39 +1,68 @@ import json +from datetime import datetime +from typing import Any -from flask_restx import Resource, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import NotFound +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from extensions.ext_database import db -from fields.app_fields import app_server_fields +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.enums import AppMCPServerStatus from models.model import AppMCPServer -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - -# Register model for flask_restx to avoid dict type issues in Swagger -app_server_model = console_ns.model("AppServer", app_server_fields) - class MCPServerCreatePayload(BaseModel): description: str | None = Field(default=None, description="Server description") - parameters: dict = Field(..., description="Server parameters configuration") + parameters: dict[str, Any] = Field(..., description="Server parameters configuration") class MCPServerUpdatePayload(BaseModel): id: str = Field(..., description="Server ID") description: str | None = Field(default=None, description="Server description") - parameters: dict = Field(..., description="Server parameters configuration") + parameters: dict[str, Any] = Field(..., description="Server parameters configuration") status: str | None = Field(default=None, description="Server status") -for model in (MCPServerCreatePayload, MCPServerUpdatePayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class AppMCPServerResponse(ResponseModel): + id: str + name: str + server_code: str + description: str + status: AppMCPServerStatus + parameters: dict[str, Any] | list[Any] | str + created_at: int | None = None + updated_at: int | None = None + + @field_validator("parameters", mode="before") + @classmethod + def _normalize_parameters(cls, value: Any) -> Any: + if isinstance(value, str): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + return value + return value + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse) 
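`AppMCPServerResponse._normalize_parameters` above decodes the JSON text stored in the `parameters` column and falls back to the raw string when decoding fails. The same validator in isolation:

```python
import json
from typing import Any
from pydantic import BaseModel, field_validator

class ServerSketch(BaseModel):
    parameters: dict[str, Any] | list[Any] | str

    @field_validator("parameters", mode="before")
    @classmethod
    def _normalize_parameters(cls, value: Any) -> Any:
        # The DB column stores JSON text; decode it when possible and
        # fall back to the raw string for legacy or invalid rows.
        if isinstance(value, str):
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError):
                return value
        return value

assert ServerSketch.model_validate({"parameters": '{"a": 1}'}).parameters == {"a": 1}
assert ServerSketch.model_validate({"parameters": "not json"}).parameters == "not json"
```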
@console_ns.route("/apps//server") @@ -41,27 +70,31 @@ class AppMCPServerController(Resource): @console_ns.doc("get_app_mcp_server") @console_ns.doc(description="Get MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model) + @console_ns.response( + 200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @login_required @account_initialization_required @setup_required @get_app_model - @marshal_with(app_server_model) def get(self, app_model): server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1)) - return server + if server is None: + return {} + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") @console_ns.doc("create_app_mcp_server") @console_ns.doc(description="Create MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__]) - @console_ns.response(201, "MCP server configuration created successfully", app_server_model) + @console_ns.response( + 201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @account_initialization_required @get_app_model @login_required @setup_required - @marshal_with(app_server_model) @edit_permission_required def post(self, app_model): _, current_tenant_id = current_account_with_tenant() @@ -82,20 +115,21 @@ class AppMCPServerController(Resource): ) db.session.add(server) db.session.commit() - return server + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201 @console_ns.doc("update_app_mcp_server") @console_ns.doc(description="Update MCP server configuration for an application") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__]) - @console_ns.response(200, "MCP server configuration updated successfully", app_server_model) + @console_ns.response( + 200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__] + ) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @get_app_model @login_required @setup_required @account_initialization_required - @marshal_with(app_server_model) @edit_permission_required def put(self, app_model): payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {}) @@ -118,7 +152,7 @@ class AppMCPServerController(Resource): except ValueError: raise ValueError("Invalid status") db.session.commit() - return server + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//server/refresh") @@ -126,13 +160,12 @@ class AppMCPServerRefreshController(Resource): @console_ns.doc("refresh_app_mcp_server") @console_ns.doc(description="Refresh MCP server configuration and regenerate server code") @console_ns.doc(params={"server_id": "Server ID"}) - @console_ns.response(200, "MCP server refreshed successfully", app_server_model) + @console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__]) @console_ns.response(403, "Insufficient permissions") @console_ns.response(404, "Server not found") @setup_required @login_required 
@account_initialization_required - @marshal_with(app_server_model) @edit_permission_required def get(self, server_id): _, current_tenant_id = current_account_with_tenant() @@ -145,4 +178,4 @@ class AppMCPServerRefreshController(Resource): raise NotFound() server.server_code = AppMCPServer.generate_server_code(16) db.session.commit() - return server + return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 5a19544eab..44e19b57db 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,9 +1,9 @@ import logging +from datetime import datetime from typing import Literal from flask import request -from flask_restx import Resource, fields, marshal_with -from graphon.model_runtime.errors.invoke import InvokeError +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound @@ -25,10 +25,22 @@ from controllers.console.wraps import ( setup_required, ) from core.app.entities.app_invoke_entities import InvokeFrom +from core.entities.execution_extra_content import ExecutionExtraContentDomainModel from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from extensions.ext_database import db -from fields.raws import FilesContainedField -from libs.helper import TimestampField, uuid_value +from fields.base import ResponseModel +from fields.conversation_fields import ( + AgentThought, + ConversationAnnotation, + ConversationAnnotationHitHistory, + Feedback, + JSONValue, + MessageFile, + format_files_contained, + to_timestamp, +) +from graphon.model_runtime.errors.invoke import InvokeError +from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required from models.enums import FeedbackFromSource, FeedbackRating @@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel): data: list[str] = Field(description="Suggested question") +class MessageDetailResponse(ResponseModel): + id: str + conversation_id: str + inputs: dict[str, JSONValue] + query: str + message: JSONValue | None = None + message_tokens: int | None = None + answer: str = Field(validation_alias="re_sign_file_url_answer") + answer_tokens: int | None = None + provider_response_latency: float | None = None + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + feedbacks: list[Feedback] = Field(default_factory=list) + workflow_run_id: str | None = None + annotation: ConversationAnnotation | None = None + annotation_hit_history: ConversationAnnotationHitHistory | None = None + created_at: int | None = None + agent_thoughts: list[AgentThought] = Field(default_factory=list) + message_files: list[MessageFile] = Field(default_factory=list) + extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list) + metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict") + status: str + error: str | None = None + parent_message_id: str | None = None + + @field_validator("inputs", mode="before") + @classmethod + def _normalize_inputs(cls, value: JSONValue) -> JSONValue: + return format_files_contained(value) + + @field_validator("created_at", mode="before") + @classmethod + def 
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index 5a19544eab..44e19b57db 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -1,9 +1,9 @@
 import logging
+from datetime import datetime
 from typing import Literal

 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from graphon.model_runtime.errors.invoke import InvokeError
+from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import exists, func, select
 from werkzeug.exceptions import InternalServerError, NotFound
@@ -25,10 +25,22 @@ from controllers.console.wraps import (
     setup_required,
 )
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from extensions.ext_database import db
-from fields.raws import FilesContainedField
-from libs.helper import TimestampField, uuid_value
+from fields.base import ResponseModel
+from fields.conversation_fields import (
+    AgentThought,
+    ConversationAnnotation,
+    ConversationAnnotationHitHistory,
+    Feedback,
+    JSONValue,
+    MessageFile,
+    format_files_contained,
+    to_timestamp,
+)
+from graphon.model_runtime.errors.invoke import InvokeError
+from libs.helper import uuid_value
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from libs.login import current_account_with_tenant, login_required
 from models.enums import FeedbackFromSource, FeedbackRating
@@ -98,6 +110,51 @@ class SuggestedQuestionsResponse(BaseModel):
     data: list[str] = Field(description="Suggested question")


+class MessageDetailResponse(ResponseModel):
+    id: str
+    conversation_id: str
+    inputs: dict[str, JSONValue]
+    query: str
+    message: JSONValue | None = None
+    message_tokens: int | None = None
+    answer: str = Field(validation_alias="re_sign_file_url_answer")
+    answer_tokens: int | None = None
+    provider_response_latency: float | None = None
+    from_source: str
+    from_end_user_id: str | None = None
+    from_account_id: str | None = None
+    feedbacks: list[Feedback] = Field(default_factory=list)
+    workflow_run_id: str | None = None
+    annotation: ConversationAnnotation | None = None
+    annotation_hit_history: ConversationAnnotationHitHistory | None = None
+    created_at: int | None = None
+    agent_thoughts: list[AgentThought] = Field(default_factory=list)
+    message_files: list[MessageFile] = Field(default_factory=list)
+    extra_contents: list[ExecutionExtraContentDomainModel] = Field(default_factory=list)
+    metadata: JSONValue | None = Field(default=None, validation_alias="message_metadata_dict")
+    status: str
+    error: str | None = None
+    parent_message_id: str | None = None
+
+    @field_validator("inputs", mode="before")
+    @classmethod
+    def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
+        return format_files_contained(value)
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return to_timestamp(value)
+        return value
+
+
+class MessageInfiniteScrollPaginationResponse(ResponseModel):
+    limit: int
+    has_more: bool
+    data: list[MessageDetailResponse]
+
+
 register_schema_models(
     console_ns,
     ChatMessagesQuery,
@@ -105,124 +162,8 @@ register_schema_models(
     FeedbackExportQuery,
     AnnotationCountResponse,
     SuggestedQuestionsResponse,
-)
-
-# Register models for flask_restx to avoid dict type issues in Swagger
-# Register in dependency order: base models first, then dependent models
-
-# Base models
-simple_account_model = console_ns.model(
-    "SimpleAccount",
-    {
-        "id": fields.String,
-        "name": fields.String,
-        "email": fields.String,
-    },
-)
-
-message_file_model = console_ns.model(
-    "MessageFile",
-    {
-        "id": fields.String,
-        "filename": fields.String,
-        "type": fields.String,
-        "url": fields.String,
-        "mime_type": fields.String,
-        "size": fields.Integer,
-        "transfer_method": fields.String,
-        "belongs_to": fields.String(default="user"),
-        "upload_file_id": fields.String(default=None),
-    },
-)
-
-agent_thought_model = console_ns.model(
-    "AgentThought",
-    {
-        "id": fields.String,
-        "chain_id": fields.String,
-        "message_id": fields.String,
-        "position": fields.Integer,
-        "thought": fields.String,
-        "tool": fields.String,
-        "tool_labels": fields.Raw,
-        "tool_input": fields.String,
-        "created_at": TimestampField,
-        "observation": fields.String,
-        "files": fields.List(fields.String),
-    },
-)
-
-# Models that depend on simple_account_model
-feedback_model = console_ns.model(
-    "Feedback",
-    {
-        "rating": fields.String,
-        "content": fields.String,
-        "from_source": fields.String,
-        "from_end_user_id": fields.String,
-        "from_account": fields.Nested(simple_account_model, allow_null=True),
-    },
-)
-
-annotation_model = console_ns.model(
-    "Annotation",
-    {
-        "id": fields.String,
-        "question": fields.String,
-        "content": fields.String,
-        "account": fields.Nested(simple_account_model, allow_null=True),
-        "created_at": TimestampField,
-    },
-)
-
-annotation_hit_history_model = console_ns.model(
-    "AnnotationHitHistory",
-    {
-        "annotation_id": fields.String(attribute="id"),
-        "annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
-        "created_at": TimestampField,
-    },
-)
-
-# Message detail model that depends on multiple models
-message_detail_model = console_ns.model(
-    "MessageDetail",
-    {
-        "id": fields.String,
-        "conversation_id": fields.String,
-        "inputs": FilesContainedField,
-        "query": fields.String,
-        "message": fields.Raw,
-        "message_tokens": fields.Integer,
-        "answer": fields.String(attribute="re_sign_file_url_answer"),
-        "answer_tokens": fields.Integer,
-        "provider_response_latency": fields.Float,
-        "from_source": fields.String,
-        "from_end_user_id": fields.String,
-        "from_account_id": fields.String,
-        "feedbacks": fields.List(fields.Nested(feedback_model)),
-        "workflow_run_id": fields.String,
-        "annotation": fields.Nested(annotation_model, allow_null=True),
-        "annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
-        "created_at": TimestampField,
-        "agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
-        "message_files": fields.List(fields.Nested(message_file_model)),
-        "extra_contents": fields.List(fields.Raw),
-        "metadata": fields.Raw(attribute="message_metadata_dict"),
-        "status": fields.String,
-        "error": fields.String,
-        "parent_message_id": fields.String,
-    },
-)
-
-# Message infinite scroll pagination model
-message_infinite_scroll_pagination_model = console_ns.model(
-    "MessageInfiniteScrollPagination",
-    {
-        "limit": fields.Integer,
-        "has_more": fields.Boolean,
-        "data": fields.List(fields.Nested(message_detail_model)),
-    },
+    MessageDetailResponse,
+    MessageInfiniteScrollPaginationResponse,
 )


@@ -232,13 +173,12 @@ class ChatMessageListApi(Resource):
     @console_ns.doc(description="Get chat messages for a conversation with pagination")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
-    @console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
+    @console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPaginationResponse.__name__])
     @console_ns.response(404, "Conversation not found")
     @login_required
     @account_initialization_required
     @setup_required
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
-    @marshal_with(message_infinite_scroll_pagination_model)
     @edit_permission_required
     def get(self, app_model):
         args = ChatMessagesQuery.model_validate(request.args.to_dict())
@@ -298,7 +238,10 @@ class ChatMessageListApi(Resource):
         history_messages = list(reversed(history_messages))
         attach_message_extra_contents(history_messages)

-        return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
+        return MessageInfiniteScrollPaginationResponse.model_validate(
+            InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more),
+            from_attributes=True,
+        ).model_dump(mode="json")


 @console_ns.route("/apps/<uuid:app_id>/feedbacks")
@@ -468,13 +411,12 @@ class MessageApi(Resource):
     @console_ns.doc("get_message")
     @console_ns.doc(description="Get message details by ID")
     @console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
-    @console_ns.response(200, "Message retrieved successfully", message_detail_model)
+    @console_ns.response(200, "Message retrieved successfully", console_ns.models[MessageDetailResponse.__name__])
     @console_ns.response(404, "Message not found")
     @get_app_model
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(message_detail_model)
     def get(self, app_model, message_id: str):
         message_id = str(message_id)
@@ -486,4 +428,4 @@
             raise NotFound("Message Not Exists.")

         attach_message_extra_contents([message])
-        return message
+        return MessageDetailResponse.model_validate(message, from_attributes=True).model_dump(mode="json")
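`MessageDetailResponse` above leans on "before" validators to coerce raw ORM values into wire types before normal field validation runs. A minimal standalone sketch of that idiom (names are illustrative, not from the repo):

```python
from datetime import datetime, timezone

from pydantic import BaseModel, field_validator


class TimestampedResponse(BaseModel):
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _to_epoch(cls, value: datetime | int | None) -> int | None:
        # mode="before" sees the raw input, so a datetime can be rewritten
        # into an int before the `int | None` annotation is enforced.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value


row = TimestampedResponse.model_validate({"created_at": datetime(2024, 1, 1, tzinfo=timezone.utc)})
assert row.created_at == 1704067200
```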
description="Model name") + configs: dict[str, Any] | None = Field(default=None, description="Model configuration parameters") + opening_statement: str | None = Field(default=None, description="Opening statement") + suggested_questions: list[str] | None = Field(default=None, description="Suggested questions") + more_like_this: dict[str, Any] | None = Field(default=None, description="More like this configuration") + speech_to_text: dict[str, Any] | None = Field(default=None, description="Speech to text configuration") + text_to_speech: dict[str, Any] | None = Field(default=None, description="Text to speech configuration") + retrieval_model: dict[str, Any] | None = Field(default=None, description="Retrieval model configuration") + tools: list[dict[str, Any]] | None = Field(default=None, description="Available tools") + dataset_configs: dict[str, Any] | None = Field(default=None, description="Dataset configurations") + agent_mode: dict[str, Any] | None = Field(default=None, description="Agent mode configuration") + + +register_schema_models(console_ns, ModelConfigRequest) + + @console_ns.route("/apps//model-config") class ModelConfigResource(Resource): @console_ns.doc("update_app_model_config") @console_ns.doc(description="Update application model configuration") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.expect( - console_ns.model( - "ModelConfigRequest", - { - "provider": fields.String(description="Model provider"), - "model": fields.String(description="Model name"), - "configs": fields.Raw(description="Model configuration parameters"), - "opening_statement": fields.String(description="Opening statement"), - "suggested_questions": fields.List(fields.String(), description="Suggested questions"), - "more_like_this": fields.Raw(description="More like this configuration"), - "speech_to_text": fields.Raw(description="Speech to text configuration"), - "text_to_speech": fields.Raw(description="Text to speech configuration"), - "retrieval_model": fields.Raw(description="Retrieval model configuration"), - "tools": fields.List(fields.Raw(), description="Available tools"), - "dataset_configs": fields.Raw(description="Dataset configurations"), - "agent_mode": fields.Raw(description="Agent mode configuration"), - }, - ) - ) + @console_ns.expect(console_ns.models[ModelConfigRequest.__name__]) @console_ns.response(200, "Model configuration updated successfully") @console_ns.response(400, "Invalid configuration") @console_ns.response(404, "App not found") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 7f44a99ff1..9991d78d94 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,11 +1,12 @@ from typing import Literal -from flask_restx import Resource, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import NotFound from constants.languages import supported_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( @@ -15,13 +16,11 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db -from fields.app_fields import app_site_fields +from fields.base import ResponseModel from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import 
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 7f44a99ff1..9991d78d94 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,11 +1,12 @@
 from typing import Literal

-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from werkzeug.exceptions import NotFound

 from constants.languages import supported_language
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import (
@@ -15,13 +16,11 @@ from controllers.console.wraps import (
     setup_required,
 )
 from extensions.ext_database import db
-from fields.app_fields import app_site_fields
+from fields.base import ResponseModel
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import Site

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-

 class AppSiteUpdatePayload(BaseModel):
     title: str | None = Field(default=None)
@@ -49,13 +48,26 @@ class AppSiteUpdatePayload(BaseModel):
         return supported_language(value)


-console_ns.schema_model(
-    AppSiteUpdatePayload.__name__,
-    AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
+class AppSiteResponse(ResponseModel):
+    app_id: str
+    access_token: str | None = Field(default=None, validation_alias="code")
+    code: str | None = None
+    title: str
+    icon: str | None = None
+    icon_background: str | None = None
+    description: str | None = None
+    default_language: str
+    customize_domain: str | None = None
+    copyright: str | None = None
+    privacy_policy: str | None = None
+    custom_disclaimer: str | None = None
+    customize_token_strategy: str
+    prompt_public: bool
+    show_workflow_steps: bool
+    use_icon_as_answer_icon: bool

-# Register model for flask_restx to avoid dict type issues in Swagger
-app_site_model = console_ns.model("AppSite", app_site_fields)
+
+register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse)


 @console_ns.route("/apps/<uuid:app_id>/site")
@@ -64,7 +76,7 @@ class AppSite(Resource):
     @console_ns.doc(description="Update application site configuration")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
-    @console_ns.response(200, "Site configuration updated successfully", app_site_model)
+    @console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__])
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(404, "App not found")
     @setup_required
@@ -72,7 +84,6 @@ class AppSite(Resource):
     @login_required
     @edit_permission_required
     @account_initialization_required
     @get_app_model
-    @marshal_with(app_site_model)
     def post(self, app_model):
         args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
         current_user, _ = current_account_with_tenant()
@@ -106,7 +117,7 @@ class AppSite(Resource):
             site.updated_at = naive_utc_now()
         db.session.commit()

-        return site
+        return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")


 @console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
@@ -114,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
     @console_ns.doc("reset_app_site_access_token")
     @console_ns.doc(description="Reset access token for application site")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.response(200, "Access token reset successfully", app_site_model)
+    @console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__])
     @console_ns.response(403, "Insufficient permissions (admin/owner required)")
     @console_ns.response(404, "App or site not found")
     @setup_required
@@ -122,7 +133,6 @@ class AppSiteAccessTokenReset(Resource):
     @login_required
     @is_admin_or_owner_required
     @account_initialization_required
     @get_app_model
-    @marshal_with(app_site_model)
     def post(self, app_model):
         current_user, _ = current_account_with_tenant()
         site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
@@ -135,4 +145,4 @@
             site.updated_at = naive_utc_now()
         db.session.commit()

-        return site
+        return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
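`AppSiteResponse` above exposes the single ORM column `code` under two response fields, using a validation alias for the renamed one. A minimal standalone sketch of that aliasing (class names are illustrative):

```python
from pydantic import BaseModel, ConfigDict, Field


class SiteView(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    # validation_alias pulls this field's value from the `code` attribute,
    # while the plain `code` field reads the same attribute under its own name.
    access_token: str | None = Field(default=None, validation_alias="code")
    code: str | None = None


class SiteRow:
    code = "abc123"


view = SiteView.model_validate(SiteRow())
assert view.access_token == "abc123" and view.code == "abc123"
```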
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index da8d25c2eb..68dd8b7a8d 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -4,11 +4,7 @@ from collections.abc import Sequence
 from typing import Any

 from flask import abort, request
-from flask_restx import Resource, fields, marshal_with
-from graphon.enums import NodeType
-from graphon.file import File
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.utils.encoders import jsonable_encoder
+from flask_restx import Resource, fields, marshal, marshal_with
 from pydantic import BaseModel, Field, ValidationError, field_validator
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
@@ -39,7 +35,13 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from factories import file_factory, variable_factory
 from fields.member_fields import simple_account_fields
+from fields.online_user_fields import online_user_list_fields
 from fields.workflow_fields import workflow_fields, workflow_pagination_fields
+from graphon.enums import NodeType
+from graphon.file import File
+from graphon.file import helpers as file_helpers
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.helper import TimestampField, uuid_value
@@ -47,6 +49,7 @@ from libs.login import current_account_with_tenant, login_required
 from models import App
 from models.model import AppMode
 from models.workflow import Workflow
+from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
 from services.app_generate_service import AppGenerateService
 from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
 from services.errors.llm import InvokeRateLimitError
@@ -57,6 +60,8 @@ _file_access_controller = DatabaseFileAccessController()
 LISTENING_RETRY_IN = 2000
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
 RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
+MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
+WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register in dependency order: base models first, then dependent models
@@ -150,6 +155,19 @@ class ConvertToWorkflowPayload(BaseModel):
     icon_background: str | None = None


+class WorkflowFeaturesPayload(BaseModel):
+    features: dict[str, Any] = Field(..., description="Workflow feature configuration")
+
+
+class WorkflowOnlineUsersPayload(BaseModel):
+    app_ids: list[str] = Field(default_factory=list, description="App IDs")
+
+    @field_validator("app_ids")
+    @classmethod
+    def normalize_app_ids(cls, app_ids: list[str]) -> list[str]:
+        return list(dict.fromkeys(app_id.strip() for app_id in app_ids if app_id.strip()))
+
+
 class DraftWorkflowTriggerRunPayload(BaseModel):
     node_id: str
@@ -173,6 +191,8 @@
 reg(DefaultBlockConfigQuery)
 reg(ConvertToWorkflowPayload)
 reg(WorkflowListQuery)
 reg(WorkflowUpdatePayload)
+reg(WorkflowFeaturesPayload)
+reg(WorkflowOnlineUsersPayload)
 reg(DraftWorkflowTriggerRunPayload)
 reg(DraftWorkflowTriggerRunAllPayload)
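The `normalize_app_ids` validator above uses `dict.fromkeys` to deduplicate while keeping first-seen order, which a plain `set` would not. A standalone illustration of the same expression:

```python
# Trim, drop empties, and deduplicate while preserving order.
ids = [" a ", "b", "a", "", "b"]
normalized = list(dict.fromkeys(i.strip() for i in ids if i.strip()))
assert normalized == ["a", "b"]
```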
@console_ns.doc("update_workflow_features") + @console_ns.doc(description="Update draft workflow features") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Workflow features updated successfully") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App): + current_user, _ = current_account_with_tenant() + + args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {}) + features = args.features + + workflow_service = WorkflowService() + workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user) + + return {"result": "success"} + + @console_ns.route("/apps//workflows") class PublishedAllWorkflowApi(Resource): @console_ns.expect(console_ns.models[WorkflowListQuery.__name__]) @@ -942,7 +988,6 @@ class PublishedAllWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_pagination_model) @edit_permission_required def get(self, app_model: App): """ @@ -970,9 +1015,10 @@ class PublishedAllWorkflowApi(Resource): user_id=user_id, named_only=named_only, ) + serialized_workflows = marshal(workflows, workflow_fields_copy) return { - "items": workflows, + "items": serialized_workflows, "page": page, "limit": limit, "has_more": has_more, @@ -1340,3 +1386,73 @@ class DraftWorkflowTriggerRunAllApi(Resource): "status": "error", } ), 400 + + +@console_ns.route("/apps/workflows/online-users") +class WorkflowOnlineUsersApi(Resource): + @console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__]) + @console_ns.doc("get_workflow_online_users") + @console_ns.doc(description="Get workflow online users") + @setup_required + @login_required + @account_initialization_required + @marshal_with(online_user_list_fields) + def post(self): + args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {}) + + app_ids = args.app_ids + if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS: + raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS} app_ids are allowed per request.") + + if not app_ids: + return {"data": []} + + _, current_tenant_id = current_account_with_tenant() + workflow_service = WorkflowService() + accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id) + ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids] + + users_json_by_app_id: dict[str, Any] = {} + for start_index in range(0, len(ordered_accessible_app_ids), WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE): + app_id_batch = ordered_accessible_app_ids[ + start_index : start_index + WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE + ] + pipe = redis_client.pipeline(transaction=False) + for app_id in app_id_batch: + pipe.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}") + + users_json_batch = pipe.execute() + for app_id, users_json in zip(app_id_batch, users_json_batch): + users_json_by_app_id[app_id] = users_json + + results = [] + for app_id in ordered_accessible_app_ids: + users_json = users_json_by_app_id.get(app_id, {}) + + users = [] + for _, user_info_json in users_json.items(): + try: + user_info = json.loads(user_info_json) + except Exception: + continue + + if not isinstance(user_info, dict): + continue + + avatar = user_info.get("avatar") + if isinstance(avatar, str) and avatar and not avatar.startswith(("http://", 
"https://")): + try: + user_info["avatar"] = file_helpers.get_signed_file_url(avatar) + except Exception as exc: + logger.warning( + "Failed to sign workflow online user avatar; using original value. " + "app_id=%s avatar=%s error=%s", + app_id, + avatar, + exc, + ) + + users.append(user_info) + results.append({"app_id": app_id, "users": users}) + + return {"data": results} diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 3b24c2a402..4b39590235 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -1,27 +1,26 @@ from datetime import datetime +from typing import Any from dateutil.parser import isoparse from flask import request -from flask_restx import Resource, marshal_with -from graphon.enums import WorkflowExecutionStatus +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from fields.workflow_app_log_fields import ( - build_workflow_app_log_pagination_model, - build_workflow_archived_log_pagination_model, -) +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount +from graphon.enums import WorkflowExecutionStatus from libs.login import login_required from models import App from models.model import AppMode from services.workflow_app_service import WorkflowAppService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class WorkflowAppLogQuery(BaseModel): keyword: str | None = Field(default=None, description="Search keyword for filtering logs") @@ -58,13 +57,113 @@ class WorkflowAppLogQuery(BaseModel): raise ValueError("Invalid boolean value for detail") -console_ns.schema_model( - WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None -# Register model for flask_restx to avoid dict type issues in Swagger -workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns) -workflow_archived_log_pagination_model = build_workflow_archived_log_pagination_model(console_ns) + @field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowRunForArchivedLogResponse(ResponseModel): + id: str + status: str | None = None + triggered_from: str | None = None + elapsed_time: float | None = None + total_tokens: int | None = None + + 
@field_validator("status", mode="before") + @classmethod + def _normalize_status(cls, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: Any = None + created_from: str | None = None + created_by_role: str | None = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowArchivedLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForArchivedLogResponse | None = None + trigger_metadata: Any = None + created_by_account: SimpleAccount | None = None + created_by_end_user: SimpleEndUser | None = None + created_at: int | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class WorkflowAppLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowAppLogPartialResponse] + + +class WorkflowArchivedLogPaginationResponse(ResponseModel): + page: int + limit: int + total: int + has_more: bool + data: list[WorkflowArchivedLogPartialResponse] + + +register_schema_models( + console_ns, + WorkflowAppLogQuery, + WorkflowRunForLogResponse, + WorkflowRunForArchivedLogResponse, + WorkflowAppLogPartialResponse, + WorkflowArchivedLogPartialResponse, + WorkflowAppLogPaginationResponse, + WorkflowArchivedLogPaginationResponse, +) @console_ns.route("/apps//workflow-app-logs") @@ -73,12 +172,15 @@ class WorkflowAppLogApi(Resource): @console_ns.doc(description="Get workflow application execution logs") @console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) - @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model) + @console_ns.response( + 200, + "Workflow app logs retrieved successfully", + console_ns.models[WorkflowAppLogPaginationResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) - @marshal_with(workflow_app_log_pagination_model) def get(self, app_model: App): """ Get workflow app logs @@ -87,7 +189,7 @@ class WorkflowAppLogApi(Resource): # get paginate workflow app logs workflow_app_service = WorkflowAppService() - with sessionmaker(db.engine).begin() as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( session=session, app_model=app_model, @@ -102,7 +204,9 @@ class WorkflowAppLogApi(Resource): created_by_account=args.created_by_account, ) - return workflow_app_log_pagination + return WorkflowAppLogPaginationResponse.model_validate( + workflow_app_log_pagination, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/apps//workflow-archived-logs") @@ -111,12 +215,15 @@ class WorkflowArchivedLogApi(Resource): @console_ns.doc(description="Get workflow archived execution logs") 
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) - @console_ns.response(200, "Workflow archived logs retrieved successfully", workflow_archived_log_pagination_model) + @console_ns.response( + 200, + "Workflow archived logs retrieved successfully", + console_ns.models[WorkflowArchivedLogPaginationResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) - @marshal_with(workflow_archived_log_pagination_model) def get(self, app_model: App): """ Get workflow archived logs @@ -124,7 +231,7 @@ class WorkflowArchivedLogApi(Resource): args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore workflow_app_service = WorkflowAppService() - with sessionmaker(db.engine).begin() as session: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs( session=session, app_model=app_model, @@ -132,4 +239,6 @@ class WorkflowArchivedLogApi(Resource): limit=args.limit, ) - return workflow_app_log_pagination + return WorkflowArchivedLogPaginationResponse.model_validate( + workflow_app_log_pagination, from_attributes=True + ).model_dump(mode="json") diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py new file mode 100644 index 0000000000..e7c3e982a6 --- /dev/null +++ b/api/controllers/console/app/workflow_comment.py @@ -0,0 +1,335 @@ +import logging + +from flask_restx import Resource, marshal_with +from pydantic import BaseModel, Field, TypeAdapter + +from controllers.common.schema import register_schema_models +from controllers.console import console_ns +from controllers.console.app.wraps import get_app_model +from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.member_fields import AccountWithRole +from fields.workflow_comment_fields import ( + workflow_comment_basic_fields, + workflow_comment_create_fields, + workflow_comment_detail_fields, + workflow_comment_reply_create_fields, + workflow_comment_reply_update_fields, + workflow_comment_resolve_fields, + workflow_comment_update_fields, +) +from libs.login import current_user, login_required +from models import App +from services.account_service import TenantService +from services.workflow_comment_service import WorkflowCommentService + +logger = logging.getLogger(__name__) +DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + + +class WorkflowCommentCreatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float = Field(..., description="Comment X position") + position_y: float = Field(..., description="Comment Y position") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentUpdatePayload(BaseModel): + content: str = Field(..., description="Comment content") + position_x: float | None = Field(default=None, description="Comment X position") + position_y: float | None = Field(default=None, description="Comment Y position") + mentioned_user_ids: list[str] | None = Field( + default=None, + description="Mentioned user IDs. 
Omit to keep existing mentions.", + ) + + +class WorkflowCommentReplyPayload(BaseModel): + content: str = Field(..., description="Reply content") + mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs") + + +class WorkflowCommentMentionUsersPayload(BaseModel): + users: list[AccountWithRole] + + +for model in ( + WorkflowCommentCreatePayload, + WorkflowCommentUpdatePayload, + WorkflowCommentReplyPayload, +): + console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload) + +workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields) +workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields) +workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields) +workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields) +workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields) +workflow_comment_reply_create_model = console_ns.model( + "WorkflowCommentReplyCreate", workflow_comment_reply_create_fields +) +workflow_comment_reply_update_model = console_ns.model( + "WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields +) + + +@console_ns.route("/apps//workflow/comments") +class WorkflowCommentListApi(Resource): + """API for listing and creating workflow comments.""" + + @console_ns.doc("list_workflow_comments") + @console_ns.doc(description="Get all comments for a workflow") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_basic_model, envelope="data") + def get(self, app_model: App): + """Get all comments for a workflow.""" + comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id) + + return comments + + @console_ns.doc("create_workflow_comment") + @console_ns.doc(description="Create a new workflow comment") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__]) + @console_ns.response(201, "Comment created successfully", workflow_comment_create_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_create_model) + @edit_permission_required + def post(self, app_model: App): + """Create a new workflow comment.""" + payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {}) + + result = WorkflowCommentService.create_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + created_by=current_user.id, + content=payload.content, + position_x=payload.position_x, + position_y=payload.position_y, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return result, 201 + + +@console_ns.route("/apps//workflow/comments/") +class WorkflowCommentDetailApi(Resource): + """API for managing individual workflow comments.""" + + @console_ns.doc("get_workflow_comment") + @console_ns.doc(description="Get a specific workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + 
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_detail_model) + def get(self, app_model: App, comment_id: str): + """Get a specific workflow comment.""" + comment = WorkflowCommentService.get_comment( + tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id + ) + + return comment + + @console_ns.doc("update_workflow_comment") + @console_ns.doc(description="Update a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__]) + @console_ns.response(200, "Comment updated successfully", workflow_comment_update_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_update_model) + @edit_permission_required + def put(self, app_model: App, comment_id: str): + """Update a workflow comment.""" + payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {}) + + result = WorkflowCommentService.update_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + content=payload.content, + position_x=payload.position_x, + position_y=payload.position_y, + mentioned_user_ids=payload.mentioned_user_ids, + ) + + return result + + @console_ns.doc("delete_workflow_comment") + @console_ns.doc(description="Delete a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.response(204, "Comment deleted successfully") + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @edit_permission_required + def delete(self, app_model: App, comment_id: str): + """Delete a workflow comment.""" + WorkflowCommentService.delete_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + ) + + return {"result": "success"}, 204 + + +@console_ns.route("/apps//workflow/comments//resolve") +class WorkflowCommentResolveApi(Resource): + """API for resolving and reopening workflow comments.""" + + @console_ns.doc("resolve_workflow_comment") + @console_ns.doc(description="Resolve a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model) + @login_required + @setup_required + @account_initialization_required + @get_app_model() + @marshal_with(workflow_comment_resolve_model) + @edit_permission_required + def post(self, app_model: App, comment_id: str): + """Resolve a workflow comment.""" + comment = WorkflowCommentService.resolve_comment( + tenant_id=current_user.current_tenant_id, + app_id=app_model.id, + comment_id=comment_id, + user_id=current_user.id, + ) + + return comment + + +@console_ns.route("/apps//workflow/comments//replies") +class WorkflowCommentReplyApi(Resource): + """API for managing comment replies.""" + + @console_ns.doc("create_workflow_comment_reply") + @console_ns.doc(description="Add a reply to a workflow comment") + @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"}) + @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__]) + @console_ns.response(201, "Reply created successfully", 
+    @console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_reply_create_model)
+    @edit_permission_required
+    def post(self, app_model: App, comment_id: str):
+        """Add a reply to a workflow comment."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
+
+        result = WorkflowCommentService.create_reply(
+            comment_id=comment_id,
+            content=payload.content,
+            created_by=current_user.id,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return result, 201
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
+class WorkflowCommentReplyDetailApi(Resource):
+    """API for managing individual comment replies."""
+
+    @console_ns.doc("update_workflow_comment_reply")
+    @console_ns.doc(description="Update a comment reply")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
+    @console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
+    @console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @marshal_with(workflow_comment_reply_update_model)
+    @edit_permission_required
+    def put(self, app_model: App, comment_id: str, reply_id: str):
+        """Update a comment reply."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
+
+        reply = WorkflowCommentService.update_reply(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            reply_id=reply_id,
+            user_id=current_user.id,
+            content=payload.content,
+            mentioned_user_ids=payload.mentioned_user_ids,
+        )
+
+        return reply
+
+    @console_ns.doc("delete_workflow_comment_reply")
+    @console_ns.doc(description="Delete a comment reply")
+    @console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
+    @console_ns.response(204, "Reply deleted successfully")
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    @edit_permission_required
+    def delete(self, app_model: App, comment_id: str, reply_id: str):
+        """Delete a comment reply."""
+        # Validate comment access first
+        WorkflowCommentService.validate_comment_access(
+            comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
+        )
+
+        WorkflowCommentService.delete_reply(
+            tenant_id=current_user.current_tenant_id,
+            app_id=app_model.id,
+            comment_id=comment_id,
+            reply_id=reply_id,
+            user_id=current_user.id,
+        )
+
+        return {"result": "success"}, 204
+
+
+@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
+class WorkflowCommentMentionUsersApi(Resource):
+    """API for getting mentionable users for workflow comments."""
+
+    @console_ns.doc("workflow_comment_mention_users")
+    @console_ns.doc(description="Get all users in current tenant for mentions")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.response(
+        200, "Mentionable users retrieved successfully", console_ns.models[WorkflowCommentMentionUsersPayload.__name__]
+    )
+    @login_required
+    @setup_required
+    @account_initialization_required
+    @get_app_model()
+    def get(self, app_model: App):
+        """Get all users in current tenant for mentions."""
+        members = TenantService.get_tenant_members(current_user.current_tenant)
+        users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
+        response = WorkflowCommentMentionUsersPayload(users=users)
+        return response.model_dump(mode="json"), 200
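The mention-users endpoint above validates a whole list of ORM-ish member objects in one call via `TypeAdapter` rather than looping `model_validate`. A minimal standalone sketch (the model and row classes here are stand-ins, not the repo's):

```python
from pydantic import BaseModel, TypeAdapter


class AccountWithRole(BaseModel):
    id: str
    role: str


class MemberRow:
    """Stand-in for a tenant-member join row."""

    def __init__(self, id: str, role: str) -> None:
        self.id = id
        self.role = role


members = [MemberRow("u1", "owner"), MemberRow("u2", "editor")]
# from_attributes=True lets the adapter read attributes off plain objects.
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
assert [u.role for u in users] == ["owner", "editor"]
```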
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index 9771d6f1e5..c688a69074 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -1,14 +1,10 @@
 import logging
 from collections.abc import Callable
 from functools import wraps
-from typing import Any
+from typing import Any, TypedDict

 from flask import Response, request
 from flask_restx import Resource, fields, marshal, marshal_with
-from graphon.file import helpers as file_helpers
-from graphon.variables.segment_group import SegmentGroup
-from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
-from graphon.variables.types import SegmentType
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import sessionmaker
@@ -22,8 +18,13 @@ from controllers.web.error import InvalidArgumentError, NotFoundError
 from core.app.file_access import DatabaseFileAccessController
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from extensions.ext_database import db
+from factories import variable_factory
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
+from graphon.file import helpers as file_helpers
+from graphon.variables.segment_group import SegmentGroup
+from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
+from graphon.variables.types import SegmentType
 from libs.login import current_user, login_required
 from models import App, AppMode
 from models.workflow import WorkflowDraftVariable
@@ -45,6 +46,16 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
     value: Any | None = Field(default=None, description="Variable value")


+class ConversationVariableUpdatePayload(BaseModel):
+    conversation_variables: list[dict[str, Any]] = Field(
+        ..., description="Conversation variables for the draft workflow"
+    )
+
+
+class EnvironmentVariableUpdatePayload(BaseModel):
+    environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
+
+
 console_ns.schema_model(
     WorkflowDraftVariableListQuery.__name__,
     WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
@@ -53,17 +64,26 @@ console_ns.schema_model(
     WorkflowDraftVariableUpdatePayload.__name__,
     WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
 )
+console_ns.schema_model(
+    ConversationVariableUpdatePayload.__name__,
+    ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+    EnvironmentVariableUpdatePayload.__name__,
+    EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)


 def _convert_values_to_json_serializable_object(value: Segment):
-    if isinstance(value, FileSegment):
-        return value.value.model_dump()
-    elif isinstance(value, ArrayFileSegment):
-        return [i.model_dump() for i in value.value]
-    elif isinstance(value, SegmentGroup):
-        return [_convert_values_to_json_serializable_object(i) for i in value.value]
-    else:
-        return value.value
+    match value:
+        case FileSegment():
+            return value.value.model_dump()
+        case ArrayFileSegment():
+            return [i.model_dump() for i in value.value]
+        case SegmentGroup():
+            return [_convert_values_to_json_serializable_object(i) for i in value.value]
+        case _:
+            return value.value


 def _serialize_var_value(variable: WorkflowDraftVariable):
@@ -83,10 +103,17 @@
 def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
     value_type = workflow_draft_var.value_type
-    return value_type.exposed_type().value
+    return str(value_type.exposed_type())


-def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
+class FullContentDict(TypedDict):
+    size_bytes: int | None
+    value_type: str
+    length: int | None
+    download_url: str
+
+
+def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None:
     """Serialize full_content information for large variables."""
     if not variable.is_truncated():
         return None
@@ -94,12 +121,13 @@
     variable_file = variable.variable_file
     assert variable_file is not None

-    return {
+    result: FullContentDict = {
         "size_bytes": variable_file.size,
-        "value_type": variable_file.value_type.exposed_type().value,
+        "value_type": str(variable_file.value_type.exposed_type()),
         "length": variable_file.length,
        "download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
     }
+    return result


 def _ensure_variable_access(
@@ -502,6 +530,34 @@ class ConversationVariableCollectionApi(Resource):
         db.session.commit()
         return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)

+    @console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
+    @console_ns.doc("update_conversation_variables")
+    @console_ns.doc(description="Update conversation variables for workflow draft")
+    @console_ns.doc(params={"app_id": "Application ID"})
+    @console_ns.response(200, "Conversation variables updated successfully")
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @edit_permission_required
+    @get_app_model(mode=AppMode.ADVANCED_CHAT)
+    def post(self, app_model: App):
+        payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
+
+        workflow_service = WorkflowService()
+
+        conversation_variables_list = payload.conversation_variables
+        conversation_variables = [
+            variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
+        ]
+
+        workflow_service.update_draft_workflow_conversation_variables(
+            app_model=app_model,
+            account=current_user,
+            conversation_variables=conversation_variables,
+        )
+
+        return {"result": "success"}
+

 @console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
 class SystemVariableCollectionApi(Resource):
@@ -543,7 +599,7 @@ class EnvironmentVariableCollectionApi(Resource):
                     "name": v.name,
                     "description": v.description,
                     "selector": v.selector,
-                    "value_type": v.value_type.exposed_type().value,
+                    "value_type": str(v.value_type.exposed_type()),
                     "value": v.value,
                     # Do not track edited for env vars.
"edited": False, @@ -553,3 +609,31 @@ class EnvironmentVariableCollectionApi(Resource): ) return {"items": env_vars_list} + + @console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__]) + @console_ns.doc("update_environment_variables") + @console_ns.doc(description="Update environment variables for workflow draft") + @console_ns.doc(params={"app_id": "Application ID"}) + @console_ns.response(200, "Environment variables updated successfully") + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def post(self, app_model: App): + payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {}) + + workflow_service = WorkflowService() + + environment_variables_list = payload.environment_variables + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + + workflow_service.update_draft_workflow_environment_variables( + app_model=app_model, + account=current_user, + environment_variables=environment_variables, + ) + + return {"result": "success"} diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 83e8bedc11..6748d95d6b 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -3,8 +3,6 @@ from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -28,6 +26,8 @@ from fields.workflow_run_fields import ( workflow_run_node_execution_list_fields, workflow_run_pagination_fields, ) +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage from libs.custom_inputs import time_duration from libs.helper import uuid_value @@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME -from services.workflow_run_service import WorkflowRunService +from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService def _build_backstage_input_url(form_token: str | None) -> str | None: @@ -214,7 +214,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource): Get advanced chat app workflow run list """ args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - args = args_model.model_dump(exclude_none=True) + args: WorkflowRunListArgs = {"limit": args_model.limit} + if args_model.last_id is not None: + args["last_id"] = args_model.last_id + if args_model.status is not None: + args["status"] = args_model.status # Default to DEBUGGING if not specified triggered_from = ( @@ -356,7 +360,11 @@ class WorkflowRunListApi(Resource): Get workflow run list """ args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - args = args_model.model_dump(exclude_none=True) + args: WorkflowRunListArgs = {"limit": args_model.limit} + if args_model.last_id is not 
None: + args["last_id"] = args_model.last_id + if args_model.status is not None: + args["status"] = args_model.status # Default to DEBUGGING for workflow if not specified (backward compatibility) triggered_from = ( diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index e4a6afae1e..a6715fa200 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -1,16 +1,17 @@ import logging +from datetime import datetime from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel +from flask_restx import Resource +from pydantic import BaseModel, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound from configs import dify_config -from controllers.common.schema import get_or_create_model +from controllers.common.schema import register_schema_models from extensions.ext_database import db -from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields +from fields.base import ResponseModel from libs.login import current_user, login_required from models.enums import AppTriggerStatus from models.model import Account, App, AppMode @@ -21,15 +22,6 @@ from ..app.wraps import get_app_model from ..wraps import account_initialization_required, edit_permission_required, setup_required logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - -trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields) - -triggers_list_fields_copy = triggers_list_fields.copy() -triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model)) -triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy) - -webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields) class Parser(BaseModel): @@ -41,10 +33,52 @@ class ParserEnable(BaseModel): enable_trigger: bool -console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +class WorkflowTriggerResponse(ResponseModel): + id: str + trigger_type: str + title: str + node_id: str + provider_name: str + icon: str + status: str + created_at: datetime | None = None + updated_at: datetime | None = None -console_ns.schema_model( - ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) + @field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before") + @classmethod + def _normalize_string_fields(cls, value: object) -> str: + if isinstance(value, str): + return value + return str(value) + + +class WorkflowTriggerListResponse(ResponseModel): + data: list[WorkflowTriggerResponse] + + +class WebhookTriggerResponse(ResponseModel): + id: str + webhook_id: str + webhook_url: str + webhook_debug_url: str + node_id: str + created_at: datetime | None = None + + @field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before") + @classmethod + def _normalize_string_fields(cls, value: object) -> str: + if isinstance(value, str): + return value + return str(value) + + +register_schema_models( + console_ns, + Parser, + ParserEnable, + WorkflowTriggerResponse, + WorkflowTriggerListResponse, + WebhookTriggerResponse, ) @@ -57,14 +91,14 @@ class WebhookTriggerApi(Resource): @login_required @account_initialization_required 
diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py
index e4a6afae1e..a6715fa200 100644
--- a/api/controllers/console/app/workflow_trigger.py
+++ b/api/controllers/console/app/workflow_trigger.py
@@ -1,16 +1,17 @@
 import logging
+from datetime import datetime

 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel
+from flask_restx import Resource
+from pydantic import BaseModel, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import NotFound

 from configs import dify_config
-from controllers.common.schema import get_or_create_model
+from controllers.common.schema import register_schema_models
 from extensions.ext_database import db
-from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
+from fields.base import ResponseModel
 from libs.login import current_user, login_required
 from models.enums import AppTriggerStatus
 from models.model import Account, App, AppMode
@@ -21,15 +22,6 @@ from ..app.wraps import get_app_model
 from ..wraps import account_initialization_required, edit_permission_required, setup_required

 logger = logging.getLogger(__name__)
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-trigger_model = get_or_create_model("WorkflowTrigger", trigger_fields)
-
-triggers_list_fields_copy = triggers_list_fields.copy()
-triggers_list_fields_copy["data"] = fields.List(fields.Nested(trigger_model))
-triggers_list_model = get_or_create_model("WorkflowTriggerList", triggers_list_fields_copy)
-
-webhook_trigger_model = get_or_create_model("WebhookTrigger", webhook_trigger_fields)


 class Parser(BaseModel):
@@ -41,10 +33,52 @@ class ParserEnable(BaseModel):
     enable_trigger: bool


-console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+class WorkflowTriggerResponse(ResponseModel):
+    id: str
+    trigger_type: str
+    title: str
+    node_id: str
+    provider_name: str
+    icon: str
+    status: str
+    created_at: datetime | None = None
+    updated_at: datetime | None = None

-console_ns.schema_model(
-    ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+    @field_validator("id", "trigger_type", "title", "node_id", "provider_name", "icon", "status", mode="before")
+    @classmethod
+    def _normalize_string_fields(cls, value: object) -> str:
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+
+class WorkflowTriggerListResponse(ResponseModel):
+    data: list[WorkflowTriggerResponse]
+
+
+class WebhookTriggerResponse(ResponseModel):
+    id: str
+    webhook_id: str
+    webhook_url: str
+    webhook_debug_url: str
+    node_id: str
+    created_at: datetime | None = None
+
+    @field_validator("id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", mode="before")
+    @classmethod
+    def _normalize_string_fields(cls, value: object) -> str:
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+
+register_schema_models(
+    console_ns,
+    Parser,
+    ParserEnable,
+    WorkflowTriggerResponse,
+    WorkflowTriggerListResponse,
+    WebhookTriggerResponse,
 )
@@ -57,14 +91,14 @@ class WebhookTriggerApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(webhook_trigger_model)
+    @console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__])
     def get(self, app_model: App):
         """Get webhook trigger for a node"""
         args = Parser.model_validate(request.args.to_dict(flat=True))  # type: ignore
         node_id = args.node_id

-        with sessionmaker(db.engine).begin() as session:
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             # Get webhook trigger for this app and node
             webhook_trigger = session.scalar(
                 select(WorkflowWebhookTrigger)
@@ -78,7 +112,7 @@
         if not webhook_trigger:
             raise NotFound("Webhook trigger not found for this node")

-        return webhook_trigger
+        return WebhookTriggerResponse.model_validate(webhook_trigger, from_attributes=True).model_dump(mode="json")


 @console_ns.route("/apps/<uuid:app_id>/triggers")
@@ -89,13 +123,13 @@ class AppTriggersApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(triggers_list_model)
+    @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
     def get(self, app_model: App):
         """Get app triggers list"""
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None

-        with sessionmaker(db.engine).begin() as session:
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
             # Get all triggers for this app using select API
             triggers = (
                 session.execute(
@@ -118,7 +152,9 @@
             else:
                 trigger.icon = ""  # type: ignore

-        return {"data": triggers}
+        return WorkflowTriggerListResponse.model_validate({"data": triggers}, from_attributes=True).model_dump(
+            mode="json"
+        )


 @console_ns.route("/apps/<uuid:app_id>/trigger-enable")
@@ -129,7 +165,7 @@ class AppTriggerEnableApi(Resource):
     @account_initialization_required
     @edit_permission_required
     @get_app_model(mode=AppMode.WORKFLOW)
-    @marshal_with(trigger_model)
+    @console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
     def post(self, app_model: App):
         """Update app trigger (enable/disable)"""
         args = ParserEnable.model_validate(console_ns.payload)
@@ -160,4 +196,4 @@
         else:
             trigger.icon = ""  # type: ignore

-        return trigger
+        return WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(mode="json")
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index f741107b87..f7061f820f 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,8 +1,11 @@
+from typing import Any
+
 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator

 from constants.languages import supported_language
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
@@ -11,8 +14,6 @@ from libs.helper import EmailStr, timezone
 from models import AccountStatus
 from services.account_service import RegisterService

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


 class ActivateCheckQuery(BaseModel):
     workspace_id: str | None = Field(default=None)
@@ -39,8 +40,16 @@ class ActivatePayload(BaseModel):
         return timezone(value)


-for model in (ActivateCheckQuery, ActivatePayload):
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index f741107b87..f7061f820f 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,8 +1,11 @@
+from typing import Any
+
 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
 
 from constants.languages import supported_language
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
@@ -11,8 +14,6 @@
 from libs.helper import EmailStr, timezone
 from models import AccountStatus
 from services.account_service import RegisterService
 
-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
 
 class ActivateCheckQuery(BaseModel):
     workspace_id: str | None = Field(default=None)
@@ -39,8 +40,16 @@ class ActivatePayload(BaseModel):
         return timezone(value)
 
 
-for model in (ActivateCheckQuery, ActivatePayload):
-    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+class ActivationCheckResponse(BaseModel):
+    is_valid: bool = Field(description="Whether token is valid")
+    data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
+
+
+class ActivationResponse(BaseModel):
+    result: str = Field(description="Operation result")
+
+
+register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
 
 
 @console_ns.route("/activate/check")
@@ -51,13 +60,7 @@ class ActivateCheckApi(Resource):
     @console_ns.response(
         200,
         "Success",
-        console_ns.model(
-            "ActivationCheckResponse",
-            {
-                "is_valid": fields.Boolean(description="Whether token is valid"),
-                "data": fields.Raw(description="Activation data if valid"),
-            },
-        ),
+        console_ns.models[ActivationCheckResponse.__name__],
     )
     def get(self):
         args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
@@ -95,12 +98,7 @@ class ActivateApi(Resource):
     @console_ns.response(
         200,
         "Account activated successfully",
-        console_ns.model(
-            "ActivationResponse",
-            {
-                "result": fields.String(description="Operation result"),
-            },
-        ),
+        console_ns.models[ActivationResponse.__name__],
     )
     @console_ns.response(400, "Already activated or invalid token")
     def post(self):
diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py
index 9e7faa09c5..1fd781b4fc 100644
--- a/api/controllers/console/auth/email_register.py
+++ b/api/controllers/console/auth/email_register.py
@@ -1,7 +1,6 @@
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
-from sqlalchemy.orm import sessionmaker
 
 from configs import dify_config
 from constants.languages import languages
@@ -14,7 +13,6 @@
 from controllers.console.auth.error import (
     InvalidTokenError,
     PasswordMismatchError,
 )
-from extensions.ext_database import db
 from libs.helper import EmailStr, extract_remote_ip
 from libs.password import valid_password
 from models import Account
@@ -73,8 +71,7 @@ class EmailRegisterSendEmailApi(Resource):
         if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
             raise AccountInFreezeError()
 
-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(args.email)
 
         token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
 
         return {"result": "success", "data": token}
@@ -145,17 +142,16 @@ class EmailRegisterResetApi(Resource):
         email = register_data.get("email", "")
         normalized_email = email.lower()
 
-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(email)
 
-            if account:
-                raise EmailAlreadyInUseError()
-            else:
-                account = self._create_new_account(normalized_email, args.password_confirm)
-                if not account:
-                    raise AccountNotFoundError()
-            token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
-            AccountService.reset_login_error_rate_limit(normalized_email)
+        if account:
+            raise EmailAlreadyInUseError()
+        else:
+            account = self._create_new_account(normalized_email, args.password_confirm)
+            if not account:
+                raise AccountNotFoundError()
+        token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
+        AccountService.reset_login_error_rate_limit(normalized_email)
 
         return {"result": "success", "data": token_pair.model_dump()}
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index 63bc98b53f..ed390a5f89 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -4,7 +4,6 @@
 import secrets
 
 from flask import request
 from flask_restx import Resource
 from pydantic import BaseModel, Field
-from sqlalchemy.orm import sessionmaker
 
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
@@ -85,8 +84,7 @@ class ForgotPasswordSendEmailApi(Resource):
         else:
             language = "en-US"
 
-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(args.email)
 
         token = AccountService.send_reset_password_email(
             account=account,
@@ -184,17 +182,18 @@ class ForgotPasswordResetApi(Resource):
         password_hashed = hash_password(args.new_password, salt)
         email = reset_data.get("email", "")
 
-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(email)
 
-            if account:
-                self._update_existing_account(account, password_hashed, salt, session)
-            else:
-                raise AccountNotFound()
+        if account:
+            account = db.session.merge(account)
+            self._update_existing_account(account, password_hashed, salt)
+            db.session.commit()
+        else:
+            raise AccountNotFound()
 
         return {"result": "success"}
 
-    def _update_existing_account(self, account, password_hashed, salt, session):
+    def _update_existing_account(self, account, password_hashed, salt):
         # Update existing account credentials
         account.password = base64.b64encode(password_hashed).decode()
         account.password_salt = base64.b64encode(salt).decode()
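In `forgot_password.py`, the update now runs on the shared `db.session` instead of a locally scoped sessionmaker. Because `get_account_by_email_with_case_fallback` no longer receives the caller's session, the returned `Account` may be detached, which is why the code merges it first. A minimal sketch of that shape (names from this diff; `encoded_password` and `encoded_salt` are hypothetical locals):

```python
# Sketch, not the repository's exact code: re-attach a possibly-detached
# ORM instance before mutating it, then commit on the shared session.
account = AccountService.get_account_by_email_with_case_fallback(email)
if account is None:
    raise AccountNotFound()

account = db.session.merge(account)   # returns an instance bound to db.session
account.password = encoded_password   # hypothetical: base64-encoded hash
account.password_salt = encoded_salt  # hypothetical: base64-encoded salt
db.session.commit()
```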
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 962cc83b0e..8216b3d0da 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,7 +1,10 @@
+import logging
+
 import flask_login
 from flask import make_response, request
 from flask_restx import Resource
 from pydantic import BaseModel, Field
+from werkzeug.exceptions import Unauthorized
 
 import services
 from configs import dify_config
@@ -42,12 +45,13 @@ from libs.token import (
 )
 from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
 from services.billing_service import BillingService
-from services.entities.auth_entities import LoginPayloadBase
+from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
 from services.errors.account import AccountRegisterError
 from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
 from services.feature_service import FeatureService
 
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+logger = logging.getLogger(__name__)
 
 
 class LoginPayload(LoginPayloadBase):
@@ -91,10 +95,12 @@ class LoginApi(Resource):
         normalized_email = request_email.lower()
 
         if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
+            _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
             raise AccountInFreezeError()
 
         is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
         if is_login_error_rate_limit:
+            _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
             raise EmailPasswordLoginLimitError()
 
         invite_token = args.invite_token
@@ -110,14 +116,20 @@ class LoginApi(Resource):
             invitee_email = data.get("email") if data else None
             invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
             if invitee_email_normalized != normalized_email:
+                _log_console_login_failure(
+                    email=normalized_email,
+                    reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
+                )
                 raise InvalidEmailError()
 
             account = _authenticate_account_with_case_fallback(
                 request_email, normalized_email, args.password, invite_token
             )
         except services.errors.account.AccountLoginError:
+            _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
             raise AccountBannedError()
         except services.errors.account.AccountPasswordError as exc:
             AccountService.add_login_error_rate_limit(normalized_email)
+            _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
             raise AuthenticationFailedError() from exc
 
         # SELF_HOSTED only have one workspace
         tenants = TenantService.get_join_tenants(account)
@@ -240,20 +252,27 @@ class EmailCodeLoginApi(Resource):
         token_data = AccountService.get_email_code_login_data(args.token)
         if token_data is None:
+            _log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
             raise InvalidTokenError()
 
         token_email = token_data.get("email")
         normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
         if normalized_token_email != user_email:
+            _log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
             raise InvalidEmailError()
 
         if token_data["code"] != args.code:
+            _log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
             raise EmailCodeError()
 
         AccountService.revoke_email_code_login_token(args.token)
         try:
             account = _get_account_with_case_fallback(original_email)
+        except Unauthorized as exc:
+            _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
+            raise AccountBannedError() from exc
         except AccountRegisterError:
+            _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
             raise AccountInFreezeError()
         if account:
             tenants = TenantService.get_join_tenants(account)
@@ -279,6 +298,7 @@ class EmailCodeLoginApi(Resource):
         except WorkSpaceNotAllowedCreateError:
             raise NotAllowedCreateWorkspace()
         except AccountRegisterError:
+            _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
             raise AccountInFreezeError()
         except WorkspacesLimitExceededError:
             raise WorkspacesLimitExceeded()
@@ -336,3 +356,12 @@ def _authenticate_account_with_case_fallback(
         if original_email == normalized_email:
             raise
         return AccountService.authenticate(normalized_email, password, invite_token)
+
+
+def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
+    logger.warning(
+        "Console login failed: email=%s reason=%s ip_address=%s",
+        email,
+        reason,
+        extract_remote_ip(request),
+    )
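Every failure path in `login.py` now emits one structured warning before raising. The enum members below are exactly those referenced by the call sites in this diff; the enum itself lives in `services.entities.auth_entities` and is not shown here, so both its base class and its string values are assumptions:

```python
from enum import StrEnum


class LoginFailureReason(StrEnum):
    # Members inferred from the call sites in this diff; the values are assumed.
    ACCOUNT_IN_FREEZE = "account_in_freeze"
    LOGIN_RATE_LIMITED = "login_rate_limited"
    INVALID_INVITATION_EMAIL = "invalid_invitation_email"
    ACCOUNT_BANNED = "account_banned"
    INVALID_CREDENTIALS = "invalid_credentials"
    INVALID_EMAIL_CODE_TOKEN = "invalid_email_code_token"
    EMAIL_CODE_EMAIL_MISMATCH = "email_code_email_mismatch"
    INVALID_EMAIL_CODE = "invalid_email_code"
```

With a `str`-based enum, the `%s` placeholder in `logger.warning` renders the member's value directly, so the log line stays grep-friendly.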
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 5c7011fd22..d31fb4a46c 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -4,7 +4,6 @@
 import urllib.parse
 
 import httpx
 from flask import current_app, redirect, request
 from flask_restx import Resource
-from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import Unauthorized
 
 from configs import dify_config
@@ -180,8 +179,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
     account: Account | None = Account.get_by_openid(provider, user_info.id)
 
     if not account:
-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
 
     return account
diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py
index b55cda4244..727428c8e7 100644
--- a/api/controllers/console/auth/oauth_server.py
+++ b/api/controllers/console/auth/oauth_server.py
@@ -5,11 +5,11 @@
 from typing import Concatenate
 
 from flask import jsonify, request
 from flask.typing import ResponseReturnValue
 from flask_restx import Resource
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel
 from werkzeug.exceptions import BadRequest, NotFound
 
 from controllers.console.wraps import account_initialization_required, setup_required
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models import Account
 from models.model import OAuthProviderApp
synced to partner successfully") @console_ns.response(400, "Invalid partner information") @setup_required diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index e623722b23..ed3c1a59d4 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -162,7 +162,9 @@ class DataSourceApi(Resource): binding_id = str(binding_id) with sessionmaker(db.engine, expire_on_commit=False).begin() as session: data_source_binding = session.execute( - select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id) + select(DataSourceOauthBinding).where( + DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id + ) ).scalar_one_or_none() if data_source_binding is None: raise NotFound("Data source binding not found.") @@ -222,11 +224,11 @@ class DataSourceNotionListApi(Resource): raise ValueError("Dataset is not notion type.") documents = session.scalars( - select(Document).filter_by( - dataset_id=query.dataset_id, - tenant_id=current_tenant_id, - data_source_type="notion_import", - enabled=True, + select(Document).where( + Document.dataset_id == query.dataset_id, + Document.tenant_id == current_tenant_id, + Document.data_source_type == "notion_import", + Document.enabled.is_(True), ) ).all() if documents: diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f23c7eb431..d001dfba64 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -2,7 +2,6 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound @@ -11,10 +10,7 @@ import services from configs import dify_config from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns -from controllers.console.apikey import ( - api_key_item_model, - api_key_list_model, -) +from controllers.console.apikey import ApiKeyItem, ApiKeyList from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from controllers.console.wraps import ( @@ -52,7 +48,9 @@ from fields.dataset_fields import ( weighted_score_fields, ) from fields.document_fields import document_status_fields +from graphon.model_runtime.entities.model_entities import ModelType from libs.login import current_account_with_tenant, login_required +from libs.url_utils import normalize_api_base_url from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile from models.dataset import DatasetPermission, DatasetPermissionEnum from models.enums import ApiTokenType, SegmentStatus @@ -785,23 +783,23 @@ class DatasetApiKeyApi(Resource): @console_ns.doc("get_dataset_api_keys") @console_ns.doc(description="Get dataset API keys") - @console_ns.response(200, "API keys retrieved successfully", api_key_list_model) + @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_key_list_model) def get(self): _, current_tenant_id = 
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index f23c7eb431..d001dfba64 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -2,7 +2,6 @@
 from typing import Any, cast
 
 from flask import request
 from flask_restx import Resource, fields, marshal, marshal_with
-from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import func, select
 from werkzeug.exceptions import Forbidden, NotFound
@@ -11,10 +10,7 @@
 import services
 from configs import dify_config
 from controllers.common.schema import get_or_create_model, register_schema_models
 from controllers.console import console_ns
-from controllers.console.apikey import (
-    api_key_item_model,
-    api_key_list_model,
-)
+from controllers.console.apikey import ApiKeyItem, ApiKeyList
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
 from controllers.console.wraps import (
@@ -52,7 +48,9 @@
 from fields.dataset_fields import (
     weighted_score_fields,
 )
 from fields.document_fields import document_status_fields
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.login import current_account_with_tenant, login_required
+from libs.url_utils import normalize_api_base_url
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermission, DatasetPermissionEnum
 from models.enums import ApiTokenType, SegmentStatus
@@ -785,23 +783,23 @@ class DatasetApiKeyApi(Resource):
     @console_ns.doc("get_dataset_api_keys")
     @console_ns.doc(description="Get dataset API keys")
-    @console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
+    @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(api_key_list_model)
     def get(self):
         _, current_tenant_id = current_account_with_tenant()
         keys = db.session.scalars(
             select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
         ).all()
-        return {"items": keys}
+        return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
 
+    @console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
+    @console_ns.response(400, "Maximum keys exceeded")
     @setup_required
     @login_required
     @is_admin_or_owner_required
     @account_initialization_required
-    @marshal_with(api_key_item_model)
     def post(self):
         _, current_tenant_id = current_account_with_tenant()
@@ -828,7 +826,7 @@ class DatasetApiKeyApi(Resource):
         api_token.type = self.resource_type
         db.session.add(api_token)
         db.session.commit()
-        return api_token, 200
+        return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
 
 
 @console_ns.route("/datasets/api-keys/")
@@ -892,7 +890,8 @@ class DatasetApiBaseUrlApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
+        base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
+        return {"api_base_url": normalize_api_base_url(base)}
 
 
 @console_ns.route("/datasets/retrieval-setting")
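The old `DatasetApiBaseUrlApi` appended `/v1` inline; the new code delegates to a helper. `libs.url_utils.normalize_api_base_url` is not shown in this diff, so the sketch below is only an assumption about its contract (trim trailing slashes and avoid double-appending the version prefix):

```python
def normalize_api_base_url(base: str) -> str:
    # Assumed contract: return "<base>/v1" exactly once, however
    # SERVICE_API_URL or request.host_url happened to be written.
    base = base.rstrip("/")
    if base.endswith("/v1"):
        return base
    return f"{base}/v1"
```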
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index ab367d8483..3372a967d9 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -3,20 +3,19 @@
 import logging
 from argparse import ArgumentTypeError
 from collections.abc import Sequence
 from contextlib import ExitStack
+from datetime import datetime
 from typing import Any, Literal, cast
-from uuid import UUID
 
 import sqlalchemy as sa
 from flask import request, send_file
-from flask_restx import Resource, fields, marshal, marshal_with
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from pydantic import BaseModel, Field
+from flask_restx import Resource, marshal
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import asc, desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
-from controllers.common.schema import get_or_create_model, register_schema_models
+from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from core.errors.error import (
     LLMBadRequestError,
@@ -31,14 +30,14 @@
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
-from fields.dataset_fields import dataset_fields
+from fields.base import ResponseModel
 from fields.document_fields import (
-    dataset_and_document_fields,
     document_fields,
-    document_metadata_fields,
     document_status_fields,
     document_with_segments_fields,
 )
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
@@ -71,31 +70,101 @@
 from ..wraps import (
 )
 
 logger = logging.getLogger(__name__)
 
-# NOTE: Keep constants near the top of the module for discoverability.
-DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
+
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
 
-# Register models for flask_restx to avoid dict type issues in Swagger
-dataset_model = get_or_create_model("Dataset", dataset_fields)
+def _normalize_enum(value: Any) -> Any:
+    if isinstance(value, str) or value is None:
+        return value
+    return getattr(value, "value", value)
 
-document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
 
-document_fields_copy = document_fields.copy()
-document_fields_copy["doc_metadata"] = fields.List(
-    fields.Nested(document_metadata_model), attribute="doc_metadata_details"
-)
-document_model = get_or_create_model("Document", document_fields_copy)
+class DatasetResponse(ResponseModel):
+    id: str
+    name: str
+    description: str | None = None
+    permission: str | None = None
+    data_source_type: str | None = None
+    indexing_technique: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
 
-document_with_segments_fields_copy = document_with_segments_fields.copy()
-document_with_segments_fields_copy["doc_metadata"] = fields.List(
-    fields.Nested(document_metadata_model), attribute="doc_metadata_details"
-)
-document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
+    @field_validator("data_source_type", "indexing_technique", mode="before")
+    @classmethod
+    def _normalize_enum_fields(cls, value: Any) -> Any:
+        return _normalize_enum(value)
 
-dataset_and_document_fields_copy = dataset_and_document_fields.copy()
-dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
-dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
-dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class DocumentMetadataResponse(ResponseModel):
+    id: str
+    name: str
+    type: str
+    value: str | None = None
+
+
+class DocumentResponse(ResponseModel):
+    id: str
+    position: int | None = None
+    data_source_type: str | None = None
+    data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
+    data_source_detail_dict: Any = None
+    dataset_process_rule_id: str | None = None
+    name: str
+    created_from: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
+    tokens: int | None = None
+    indexing_status: str | None = None
+    error: str | None = None
+    enabled: bool | None = None
+    disabled_at: int | None = None
+    disabled_by: str | None = None
+    archived: bool | None = None
+    display_status: str | None = None
+    word_count: int | None = None
+    hit_count: int | None = None
+    doc_form: str | None = None
+    doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
+    summary_index_status: str | None = None
+    need_summary: bool | None = None
+
+    @field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
+    @classmethod
+    def _normalize_enum_fields(cls, value: Any) -> Any:
+        return _normalize_enum(value)
+
+    @field_validator("doc_metadata", mode="before")
+    @classmethod
+    def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
+        if value is None:
+            return []
+        return value
+
+    @field_validator("created_at", "disabled_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class DocumentWithSegmentsResponse(DocumentResponse):
+    process_rule_dict: Any = None
+    completed_segments: int | None = None
+    total_segments: int | None = None
+
+
+class DatasetAndDocumentResponse(ResponseModel):
+    dataset: DatasetResponse
+    documents: list[DocumentResponse]
+    batch: str
 
 
 class DocumentRetryPayload(BaseModel):
@@ -110,10 +179,9 @@ class GenerateSummaryPayload(BaseModel):
     document_list: list[str]
 
 
-class DocumentBatchDownloadZipPayload(BaseModel):
-    """Request payload for bulk downloading documents as a zip archive."""
-
-    document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
+class DocumentMetadataUpdatePayload(BaseModel):
+    doc_type: str | None = None
+    doc_metadata: Any = None
 
 
 class DocumentDatasetListParam(BaseModel):
@@ -133,7 +201,13 @@ register_schema_models(
     DocumentRetryPayload,
     DocumentRenamePayload,
     GenerateSummaryPayload,
+    DocumentMetadataUpdatePayload,
     DocumentBatchDownloadZipPayload,
+    DatasetResponse,
+    DocumentMetadataResponse,
+    DocumentResponse,
+    DocumentWithSegmentsResponse,
+    DatasetAndDocumentResponse,
 )
 
@@ -280,7 +354,7 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
+        query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
 
         if status:
             query = DocumentService.apply_display_status_filter(query, status)
@@ -366,10 +440,10 @@ class DatasetDocumentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(dataset_and_document_model)
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
     @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
+    @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
     def post(self, dataset_id):
         current_user, _ = current_account_with_tenant()
         dataset_id = str(dataset_id)
@@ -407,7 +481,9 @@ class DatasetDocumentListApi(Resource):
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
 
-        return {"dataset": dataset, "documents": documents, "batch": batch}
+        return DatasetAndDocumentResponse.model_validate(
+            {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
+        ).model_dump(mode="json")
 
     @setup_required
     @login_required
@@ -435,12 +511,13 @@ class DatasetInitApi(Resource):
     @console_ns.doc("init_dataset")
     @console_ns.doc(description="Initialize dataset with documents")
     @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
-    @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
+    @console_ns.response(
+        201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
+    )
    @console_ns.response(400, "Invalid request parameters")
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(dataset_and_document_model)
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
@@ -488,9 +565,9 @@ class DatasetInitApi(Resource):
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
 
-        response = {"dataset": dataset, "documents": documents, "batch": batch}
-
-        return response
+        return DatasetAndDocumentResponse.model_validate(
+            {"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
+        ).model_dump(mode="json")
 
 
 @console_ns.route("/datasets//documents//indexing-estimate")
@@ -997,15 +1074,7 @@ class DocumentMetadataApi(DocumentResource):
     @console_ns.doc("update_document_metadata")
     @console_ns.doc(description="Update document metadata")
     @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "UpdateDocumentMetadataRequest",
-            {
-                "doc_type": fields.String(description="Document type"),
-                "doc_metadata": fields.Raw(description="Document metadata"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
     @console_ns.response(200, "Document metadata updated successfully")
     @console_ns.response(404, "Document not found")
     @console_ns.response(403, "Permission denied")
@@ -1018,10 +1087,10 @@ class DocumentMetadataApi(DocumentResource):
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)
 
-        req_data = request.get_json()
+        req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
 
-        doc_type = req_data.get("doc_type")
-        doc_metadata = req_data.get("doc_metadata")
+        doc_type = req_data.doc_type
+        doc_metadata = req_data.doc_metadata
 
         # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
         if not current_user.is_dataset_editor:
@@ -1035,7 +1104,7 @@ class DocumentMetadataApi(DocumentResource):
         if not isinstance(doc_metadata, dict):
             raise ValueError("doc_metadata must be a dictionary.")
 
-        metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
+        metadata_schema: dict[str, Any] = cast(dict[str, Any], DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
 
         document.doc_metadata = {}
 
         if doc_type == "others":
@@ -1203,7 +1272,7 @@ class DocumentRenameApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(document_model)
+    @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
     @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
     def post(self, dataset_id, document_id):
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
@@ -1221,7 +1290,7 @@ class DocumentRenameApi(DocumentResource):
         except services.errors.document.DocumentIndexingError:
             raise DocumentIndexingError("Cannot delete document during indexing.")
 
-        return document
+        return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
 
 
 @console_ns.route("/datasets//documents//website-sync")
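With `@marshal_with` removed, serialization is explicit at each return site: validate the ORM object by attribute access, then dump JSON-safe types. Roughly, for the response models defined above (the object shape here is a hypothetical example, not code from this diff):

```python
# Hypothetical usage of the DocumentResponse model defined in this hunk.
doc = session.get(Document, document_id)  # ORM instance with datetime columns

payload = DocumentResponse.model_validate(doc, from_attributes=True)
body = payload.model_dump(mode="json")

# By this point the mode="before" validators have already normalized awkward
# inputs: datetimes became epoch ints via _to_timestamp, enum values became
# their .value, and a None doc_metadata_details collapsed to [].
```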
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index c5f4e3a6e2..2647bb1f5a 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -2,7 +2,6 @@
 import uuid
 
 from flask import request
 from flask_restx import Resource, marshal
-from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field
 from sqlalchemy import String, cast, func, or_, select
 from sqlalchemy.dialects.postgresql import JSONB
@@ -10,6 +9,7 @@
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
 from configs import dify_config
+from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import ProviderNotInitializeError
@@ -31,6 +31,7 @@
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from fields.segment_fields import child_chunk_fields, segment_fields
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.helper import escape_like_pattern
 from libs.login import current_account_with_tenant, login_required
 from models.dataset import ChildChunk, DocumentSegment
@@ -82,14 +83,6 @@ class BatchImportPayload(BaseModel):
     upload_file_id: str
 
 
-class ChildChunkCreatePayload(BaseModel):
-    content: str
-
-
-class ChildChunkUpdatePayload(BaseModel):
-    content: str
-
-
 class ChildChunkBatchUpdatePayload(BaseModel):
     chunks: list[ChildChunkUpdateArgs]
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index e62be13c2f..36a7a4bb0e 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,13 +1,13 @@
-from flask_restx import Resource, fields
+from __future__ import annotations
 
-from controllers.common.schema import register_schema_model
-from fields.hit_testing_fields import (
-    child_chunk_fields,
-    document_fields,
-    files_fields,
-    hit_testing_record_fields,
-    segment_fields,
-)
+from datetime import datetime
+from typing import Any
+
+from flask_restx import Resource
+from pydantic import Field, field_validator
+
+from controllers.common.schema import register_schema_models
+from fields.base import ResponseModel
 from libs.login import login_required
 
 from .. import console_ns
@@ -18,39 +18,92 @@
 from ..wraps import (
     setup_required,
 )
 
-register_schema_model(console_ns, HitTestingPayload)
+
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
 
-def _get_or_create_model(model_name: str, field_def):
-    """Get or create a flask_restx model to avoid dict type issues in Swagger."""
-    existing = console_ns.models.get(model_name)
-    if existing is None:
-        existing = console_ns.model(model_name, field_def)
-    return existing
+class HitTestingDocument(ResponseModel):
+    id: str | None = None
+    data_source_type: str | None = None
+    name: str | None = None
+    doc_type: str | None = None
+    doc_metadata: Any | None = None
 
-# Register models for flask_restx to avoid dict type issues in Swagger
-document_model = _get_or_create_model("HitTestingDocument", document_fields)
+class HitTestingSegment(ResponseModel):
+    id: str | None = None
+    position: int | None = None
+    document_id: str | None = None
+    content: str | None = None
+    sign_content: str | None = None
+    answer: str | None = None
+    word_count: int | None = None
+    tokens: int | None = None
+    keywords: list[str] = Field(default_factory=list)
+    index_node_id: str | None = None
+    index_node_hash: str | None = None
+    hit_count: int | None = None
+    enabled: bool | None = None
+    disabled_at: int | None = None
+    disabled_by: str | None = None
+    status: str | None = None
+    created_by: str | None = None
+    created_at: int | None = None
+    indexing_at: int | None = None
+    completed_at: int | None = None
+    error: str | None = None
+    stopped_at: int | None = None
+    document: HitTestingDocument | None = None
 
-segment_fields_copy = segment_fields.copy()
-segment_fields_copy["document"] = fields.Nested(document_model)
-segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
+    @field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
 
-child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
-files_model = _get_or_create_model("HitTestingFile", files_fields)
+class HitTestingChildChunk(ResponseModel):
+    id: str | None = None
+    content: str | None = None
+    position: int | None = None
+    score: float | None = None
 
-hit_testing_record_fields_copy = hit_testing_record_fields.copy()
-hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
-hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
-hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
-hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
+
+class HitTestingFile(ResponseModel):
+    id: str | None = None
+    name: str | None = None
+    size: int | None = None
+    extension: str | None = None
+    mime_type: str | None = None
+    source_url: str | None = None
+
+
+class HitTestingRecord(ResponseModel):
+    segment: HitTestingSegment | None = None
+    child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
+    score: float | None = None
+    tsne_position: Any | None = None
+    files: list[HitTestingFile] = Field(default_factory=list)
+    summary: str | None = None
 
-# Response model for hit testing API
-hit_testing_response_fields = {
-    "query": fields.String,
-    "records": fields.List(fields.Nested(hit_testing_record_model)),
-}
-hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
+
+class HitTestingResponse(ResponseModel):
+    query: str
+    records: list[HitTestingRecord] = Field(default_factory=list)
+
+
+register_schema_models(
+    console_ns,
+    HitTestingPayload,
+    HitTestingDocument,
+    HitTestingSegment,
+    HitTestingChildChunk,
+    HitTestingFile,
+    HitTestingRecord,
+    HitTestingResponse,
+)
 
 
 @console_ns.route("/datasets//hit-testing")
@@ -59,7 +112,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
     @console_ns.doc(description="Test dataset knowledge retrieval")
     @console_ns.doc(params={"dataset_id": "Dataset ID"})
     @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
-    @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
+    @console_ns.response(
+        200,
+        "Hit testing completed successfully",
+        model=console_ns.models[HitTestingResponse.__name__],
+    )
     @console_ns.response(404, "Dataset not found")
     @console_ns.response(400, "Invalid parameters")
     @setup_required
@@ -74,4 +131,4 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
         args = payload.model_dump(exclude_none=True)
 
         self.hit_testing_args_check(args)
 
-        return self.perform_hit_testing(dataset, args)
+        return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
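Note that `default_factory=list` only applies when a key is *absent*; an explicit `None` for `keywords`, `child_chunks`, or `files` would fail validation. That is why the base class (in the next file's hunks) normalizes records before this model sees them. A hypothetical round-trip for the models above:

```python
# Hypothetical input: perform_hit_testing returns a plain dict, which the
# endpoint validates and dumps as JSON-safe output.
payload = HitTestingResponse.model_validate(
    {
        "query": "refund policy",
        "records": [
            {
                "segment": {"id": "seg-1", "content": "Refunds are ...", "keywords": []},
                "child_chunks": [],  # must be a list; an explicit None would be rejected
                "files": [],
                "score": 0.87,
            }
        ],
    }
)
body = payload.model_dump(mode="json")
```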
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index 8fb3699849..71ab1513ed 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -2,7 +2,6 @@
 import logging
 from typing import Any
 
 from flask_restx import marshal
-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -21,6 +20,7 @@
 from core.errors.error import (
     QuotaExceededError,
 )
 from fields.hit_testing_fields import hit_testing_record_fields
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs.login import current_user
 from models.account import Account
 from services.dataset_service import DatasetService
@@ -38,6 +38,48 @@ class HitTestingPayload(BaseModel):
 
 
 class DatasetsHitTestingBase:
+    @staticmethod
+    def _normalize_hit_testing_query(query: Any) -> str:
+        """Return the user-visible query string from legacy and current response shapes."""
+        if isinstance(query, str):
+            return query
+
+        if isinstance(query, dict):
+            content = query.get("content")
+            if isinstance(content, str):
+                return content
+
+        raise ValueError("Invalid hit testing query response")
+
+    @staticmethod
+    def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
+        """Coerce nullable collection fields into lists before response validation."""
+        if not isinstance(records, list):
+            return []
+
+        normalized_records: list[dict[str, Any]] = []
+        for record in records:
+            if not isinstance(record, dict):
+                continue
+
+            normalized_record = dict(record)
+            segment = normalized_record.get("segment")
+            if isinstance(segment, dict):
+                normalized_segment = dict(segment)
+                if normalized_segment.get("keywords") is None:
+                    normalized_segment["keywords"] = []
+                normalized_record["segment"] = normalized_segment
+
+            if normalized_record.get("child_chunks") is None:
+                normalized_record["child_chunks"] = []
+
+            if normalized_record.get("files") is None:
+                normalized_record["files"] = []
+
+            normalized_records.append(normalized_record)
+
+        return normalized_records
+
     @staticmethod
     def get_and_validate_dataset(dataset_id: str):
         assert isinstance(current_user, Account)
@@ -75,7 +117,12 @@ class DatasetsHitTestingBase:
                 attachment_ids=args.get("attachment_ids"),
                 limit=10,
             )
-            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
+            return {
+                "query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
+                "records": DatasetsHitTestingBase._normalize_hit_testing_records(
+                    marshal(response.get("records", []), hit_testing_record_fields)
+                ),
+            }
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
         except ProviderTokenNotInitError as ex:
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index 2e69ddc5ab..d966e1629e 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -1,9 +1,9 @@
 from typing import Literal
 
 from flask_restx import Resource, marshal_with
-from pydantic import BaseModel
 from werkzeug.exceptions import NotFound
 
+from controllers.common.controller_schemas import MetadataUpdatePayload
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
@@ -18,11 +18,6 @@
 from services.entities.knowledge_entities.knowledge_entities import (
     MetadataArgs,
     MetadataOperationData,
 )
 from services.metadata_service import MetadataService
 
-
-class MetadataUpdatePayload(BaseModel):
-    name: str
-
-
 register_schema_models(
     console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
 )
diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index bdf83b991e..fd0a8b33bc 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -2,8 +2,6 @@
 from typing import Any
 
 from flask import make_response, redirect, request
 from flask_restx import Resource
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, NotFound
@@ -12,6 +10,8 @@
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from core.plugin.impl.oauth import OAuthHandler
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import current_account_with_tenant, login_required
 from models.provider_ids import DatasourceProviderID
 from services.datasource_provider_service import DatasourceProviderService
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 3549f9542d..b31d73f27d 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -4,7 +4,6 @@
 from typing import Any, NoReturn
 
 from flask import Response, request
 from flask_restx import Resource, marshal, marshal_with
-from graphon.variables.types import SegmentType
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import Forbidden
@@ -28,6 +27,7 @@
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from extensions.ext_database import db
 from factories.file_factory import build_from_mapping, build_from_mappings
 from factories.variable_factory import build_segment_with_type
+from graphon.variables.types import SegmentType
 from libs.login import current_user, login_required
 from models import Account
 from models.dataset import Pipeline
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index a8077d9eb0..ee146e8287 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -4,7 +4,6 @@
 from typing import Any, Literal, cast
 
 from flask import abort, request
 from flask_restx import Resource, marshal_with  # type: ignore
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
@@ -41,6 +40,7 @@
 from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from factories import variable_factory
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from libs import helper
 from libs.helper import TimestampField, UUIDStrOrEmpty
 from libs.login import current_account_with_tenant, current_user, login_required
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index a37077af42..ab660d9dc3 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -1,7 +1,6 @@
 import logging
 
 from flask import request
-from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import InternalServerError
 
 import services
@@ -20,6 +19,7 @@
 from controllers.console.app.error import (
 )
 from controllers.console.explore.wraps import InstalledAppResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from graphon.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index eacd7332fe..ccdccceaa6 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -2,7 +2,6 @@
 from typing import Any, Literal
 from uuid import UUID
 
-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound
@@ -26,6 +25,7 @@
 from core.errors.error import (
     QuotaExceededError,
 )
 from extensions.ext_database import db
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_user
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 0740dd0e24..2d9a997fbf 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,21 +1,24 @@
 import logging
+from datetime import datetime
 from typing import Any
 
 from flask import request
-from flask_restx import Resource, fields, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, computed_field, field_validator
 from sqlalchemy import and_, select
 from werkzeug.exceptions import BadRequest, Forbidden, NotFound
 
-from controllers.common.schema import get_or_create_model
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
 from extensions.ext_database import db
-from fields.installed_app_fields import app_fields, installed_app_fields, installed_app_list_fields
+from fields.base import ResponseModel
+from graphon.file import helpers as file_helpers
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import App, InstalledApp, RecommendedApp
+from models.model import IconType
 from services.account_service import TenantService
 from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService
@@ -36,22 +39,97 @@ class InstalledAppsListQuery(BaseModel):
 
 logger = logging.getLogger(__name__)
 
-app_model = get_or_create_model("InstalledAppInfo", app_fields)
+def _build_icon_url(icon_type: str | IconType | None, icon: str | None) -> str | None:
+    if icon is None or icon_type is None:
+        return None
+    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
+    if icon_type_value.lower() != IconType.IMAGE:
+        return None
+    return file_helpers.get_signed_file_url(icon)
 
-installed_app_fields_copy = installed_app_fields.copy()
-installed_app_fields_copy["app"] = fields.Nested(app_model)
-installed_app_model = get_or_create_model("InstalledApp", installed_app_fields_copy)
 
-installed_app_list_fields_copy = installed_app_list_fields.copy()
-installed_app_list_fields_copy["installed_apps"] = fields.List(fields.Nested(installed_app_model))
-installed_app_list_model = get_or_create_model("InstalledAppList", installed_app_list_fields_copy)
+def _safe_primitive(value: Any) -> Any:
+    if value is None or isinstance(value, (str, int, float, bool, datetime)):
+        return value
+    return None
+
+
+class InstalledAppInfoResponse(ResponseModel):
+    id: str
+    name: str | None = None
+    mode: str | None = None
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+    use_icon_as_answer_icon: bool | None = None
+
+    @field_validator("mode", "icon_type", mode="before")
+    @classmethod
+    def _normalize_enum_like(cls, value: Any) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @computed_field(return_type=str | None)  # type: ignore[prop-decorator]
+    @property
+    def icon_url(self) -> str | None:
+        return _build_icon_url(self.icon_type, self.icon)
+
+
+class InstalledAppResponse(ResponseModel):
+    id: str
+    app: InstalledAppInfoResponse
+    app_owner_tenant_id: str
+    is_pinned: bool
+    last_used_at: int | None = None
+    editable: bool
+    uninstallable: bool
+
+    @field_validator("app", mode="before")
+    @classmethod
+    def _normalize_app(cls, value: Any) -> Any:
+        if isinstance(value, dict):
+            return value
+        return {
+            "id": _safe_primitive(getattr(value, "id", "")) or "",
+            "name": _safe_primitive(getattr(value, "name", None)),
+            "mode": _safe_primitive(getattr(value, "mode", None)),
+            "icon_type": _safe_primitive(getattr(value, "icon_type", None)),
+            "icon": _safe_primitive(getattr(value, "icon", None)),
+            "icon_background": _safe_primitive(getattr(value, "icon_background", None)),
+            "use_icon_as_answer_icon": _safe_primitive(getattr(value, "use_icon_as_answer_icon", None)),
+        }
+
+    @field_validator("last_used_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        if isinstance(value, datetime):
+            return int(value.timestamp())
+        return value
+
+
+class InstalledAppListResponse(ResponseModel):
+    installed_apps: list[InstalledAppResponse]
+
+
+register_schema_models(
+    console_ns,
+    InstalledAppCreatePayload,
+    InstalledAppUpdatePayload,
+    InstalledAppsListQuery,
+    InstalledAppInfoResponse,
+    InstalledAppResponse,
+    InstalledAppListResponse,
+)
 
 
 @console_ns.route("/installed-apps")
 class InstalledAppsListApi(Resource):
     @login_required
     @account_initialization_required
-    @marshal_with(installed_app_list_model)
+    @console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
     def get(self):
         query = InstalledAppsListQuery.model_validate(request.args.to_dict())
         current_user, current_tenant_id = current_account_with_tenant()
@@ -125,7 +203,9 @@ class InstalledAppsListApi(Resource):
             )
         )
 
-        return {"installed_apps": installed_app_list}
+        return InstalledAppListResponse.model_validate(
+            {"installed_apps": installed_app_list}, from_attributes=True
+        ).model_dump(mode="json")
 
     @login_required
     @account_initialization_required
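`icon_url` moves from a custom flask-restx field (`AppIconUrlField`) to a Pydantic `computed_field`, derived after validation and serialized alongside the declared fields. A trimmed sketch of the idea, assuming `get_signed_file_url` returns a signed URL for stored image icons (only the class below is hypothetical; the helper comes from this diff):

```python
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, computed_field


class IconAware(BaseModel):
    icon_type: str | None = None
    icon: str | None = None

    @computed_field(return_type=str | None)  # included in model_dump() output
    @property
    def icon_url(self) -> str | None:
        # Only image-type icons resolve to a signed file URL; emoji icons stay None.
        if self.icon is None or (self.icon_type or "").lower() != "image":
            return None
        return file_helpers.get_signed_file_url(self.icon)
```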
console_ns from controllers.console.wraps import account_initialization_required -from libs.helper import AppIconUrlField +from fields.base import ResponseModel +from libs.helper import build_icon_url from libs.login import current_user, login_required from services.recommended_app_service import RecommendedAppService -app_fields = { - "id": fields.String, - "name": fields.String, - "mode": fields.String, - "icon": fields.String, - "icon_type": fields.String, - "icon_url": AppIconUrlField, - "icon_background": fields.String, -} - -app_model = get_or_create_model("RecommendedAppInfo", app_fields) - -recommended_app_fields = { - "app": fields.Nested(app_model, attribute="app"), - "app_id": fields.String, - "description": fields.String(attribute="description"), - "copyright": fields.String, - "privacy_policy": fields.String, - "custom_disclaimer": fields.String, - "category": fields.String, - "position": fields.Integer, - "is_listed": fields.Boolean, - "can_trial": fields.Boolean, -} - -recommended_app_model = get_or_create_model("RecommendedApp", recommended_app_fields) - -recommended_app_list_fields = { - "recommended_apps": fields.List(fields.Nested(recommended_app_model)), - "categories": fields.List(fields.String), -} - -recommended_app_list_model = get_or_create_model("RecommendedAppList", recommended_app_list_fields) - class RecommendedAppsQuery(BaseModel): language: str | None = Field(default=None) -console_ns.schema_model( - RecommendedAppsQuery.__name__, - RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"), +class RecommendedAppInfoResponse(ResponseModel): + id: str + name: str | None = None + mode: str | None = None + icon: str | None = None + icon_type: str | None = None + icon_background: str | None = None + + @staticmethod + def _normalize_enum_like(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("mode", "icon_type", mode="before") + @classmethod + def _normalize_enum_fields(cls, value: Any) -> str | None: + return cls._normalize_enum_like(value) + + @computed_field(return_type=str | None) # type: ignore[prop-decorator] + @property + def icon_url(self) -> str | None: + return build_icon_url(self.icon_type, self.icon) + + +class RecommendedAppResponse(ResponseModel): + app: RecommendedAppInfoResponse | None = None + app_id: str + description: str | None = None + copyright: str | None = None + privacy_policy: str | None = None + custom_disclaimer: str | None = None + category: str | None = None + position: int | None = None + is_listed: bool | None = None + can_trial: bool | None = None + + +class RecommendedAppListResponse(ResponseModel): + recommended_apps: list[RecommendedAppResponse] + categories: list[str] + + +register_schema_models( + console_ns, + RecommendedAppsQuery, + RecommendedAppInfoResponse, + RecommendedAppResponse, + RecommendedAppListResponse, ) @console_ns.route("/explore/apps") class RecommendedAppListApi(Resource): @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__]) + @console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__]) @login_required @account_initialization_required - @marshal_with(recommended_app_list_model) def get(self): # language args args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore @@ -72,7 +89,10 @@ class RecommendedAppListApi(Resource): else: language_prefix = languages[0] - return 
RecommendedAppService.get_recommended_apps_and_categories(language_prefix) + return RecommendedAppListResponse.model_validate( + RecommendedAppService.get_recommended_apps_and_categories(language_prefix), + from_attributes=True, + ).model_dump(mode="json") @console_ns.route("/explore/apps/<uuid:app_id>") diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index e432574434..1456301a24 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -3,8 +3,6 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -61,6 +59,8 @@ from fields.workflow_fields import ( workflow_fields, workflow_partial_fields, ) +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user @@ -169,6 +169,7 @@ console_ns.schema_model( class TrialAppWorkflowRunApi(TrialAppResource): + @trial_feature_enable @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__]) def post(self, trial_app): """ @@ -210,6 +211,7 @@ class TrialAppWorkflowRunApi(TrialAppResource): class TrialAppWorkflowTaskStopApi(TrialAppResource): + @trial_feature_enable def post(self, trial_app, task_id: str): """ Stop workflow task @@ -290,7 +292,6 @@ class TrialChatApi(TrialAppResource): class TrialMessageSuggestedQuestionApi(TrialAppResource): - @trial_feature_enable def get(self, trial_app, message_id): app_model = trial_app app_mode = AppMode.value_of(app_model.mode) @@ -470,7 +471,6 @@ class TrialCompletionApi(TrialAppResource): class TrialSitApi(Resource): """Resource for trial app sites.""" - @trial_feature_enable @get_app_model_with_trial(None) def get(self, app_model): """Retrieve app site info.
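In the hunks above, `@trial_feature_enable` moves off the read-only trial endpoints (suggested questions, site info, parameters, app detail, workflow, datasets) and onto the workflow run/stop endpoints, so the flag now gates the mutating paths rather than plain reads. A minimal sketch of what such a gate could look like, assuming a boolean config flag (`ENABLE_TRIAL_APP` is hypothetical; the real decorator lives elsewhere in this codebase and may consult billing or a feature service instead):

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from werkzeug.exceptions import Forbidden

from configs import dify_config

P = ParamSpec("P")
R = TypeVar("R")


def trial_feature_enable(view: Callable[P, R]) -> Callable[P, R]:
    """Reject the request unless the trial-app feature is switched on."""

    @wraps(view)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # Hypothetical flag name; the real check may read different state.
        if not getattr(dify_config, "ENABLE_TRIAL_APP", False):
            raise Forbidden("Trial app feature is not enabled.")
        return view(*args, **kwargs)

    return wrapper
```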
@@ -492,7 +492,6 @@ class TrialSitApi(Resource): class TrialAppParameterApi(Resource): """Resource for app variables.""" - @trial_feature_enable @get_app_model_with_trial(None) def get(self, app_model): """Retrieve app parameters.""" @@ -521,7 +520,6 @@ class TrialAppParameterApi(Resource): class AppApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) @marshal_with(app_detail_with_site_model) def get(self, app_model): @@ -534,7 +532,6 @@ class AppApi(Resource): class AppWorkflowApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) @marshal_with(workflow_model) def get(self, app_model): @@ -547,7 +544,6 @@ class AppWorkflowApi(Resource): class DatasetListApi(Resource): - @trial_feature_enable @get_app_model_with_trial(None) def get(self, app_model): page = request.args.get("page", default=1, type=int) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index da88de6776..438cce4fd8 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,7 +1,5 @@ import logging -from graphon.graph_engine.manager import GraphEngineManager -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError from controllers.common.controller_schemas import WorkflowRunPayload @@ -23,6 +21,8 @@ from core.errors.error import ( QuotaExceededError, ) from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.login import current_account_with_tenant from models.model import AppMode, InstalledApp diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index efa46c9779..7a6356d052 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,15 +1,18 @@ +from datetime import datetime +from typing import Any + from flask import request -from flask_restx import Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, TypeAdapter, field_validator from constants import HIDDEN_VALUE -from fields.api_based_extension_fields import api_based_extension_fields +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService -from ..common.schema import register_schema_models +from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models from . 
import console_ns from .wraps import account_initialization_required, setup_required @@ -24,12 +27,52 @@ class APIBasedExtensionPayload(BaseModel): api_key: str = Field(description="API key for authentication") -register_schema_models(console_ns, APIBasedExtensionPayload) +class CodeBasedExtensionResponse(ResponseModel): + module: str = Field(description="Module name") + data: Any = Field(description="Extension data") -api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields) +def _mask_api_key(api_key: str) -> str: + if not api_key: + return api_key + if len(api_key) <= 8: + return api_key[0] + "******" + api_key[-1] + return api_key[:3] + "******" + api_key[-3:] -api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model)) + +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class APIBasedExtensionResponse(ResponseModel): + id: str + name: str + api_endpoint: str + api_key: str + created_at: int | None = None + + @field_validator("api_key", mode="before") + @classmethod + def _normalize_api_key(cls, value: str) -> str: + return _mask_api_key(value) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse) +console_ns.schema_model( + "APIBasedExtensionListResponse", + TypeAdapter(list[APIBasedExtensionResponse]).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + + +def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, Any]: + return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json") @console_ns.route("/code-based-extension") @@ -40,10 +83,7 @@ class CodeBasedExtensionAPI(Resource): @console_ns.response( 200, "Success", - console_ns.model( - "CodeBasedExtensionResponse", - {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, - ), + console_ns.models[CodeBasedExtensionResponse.__name__], ) @setup_required @login_required @@ -51,30 +91,34 @@ class CodeBasedExtensionAPI(Resource): def get(self): query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore - return {"module": query.module, "data": CodeBasedExtensionService.get_code_based_extension(query.module)} + return CodeBasedExtensionResponse( + module=query.module, + data=CodeBasedExtensionService.get_code_based_extension(query.module), + ).model_dump(mode="json") @console_ns.route("/api-based-extension") class APIBasedExtensionAPI(Resource): @console_ns.doc("get_api_based_extensions") @console_ns.doc(description="Get all API-based extensions for current tenant") - @console_ns.response(200, "Success", api_based_extension_list_model) + @console_ns.response(200, "Success", console_ns.models["APIBasedExtensionListResponse"]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def get(self): _, tenant_id = current_account_with_tenant() - return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + return [ + _serialize_api_based_extension(extension) + for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id) + ] @console_ns.doc("create_api_based_extension") @console_ns.doc(description="Create a new 
API-based extension") @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) - @console_ns.response(201, "Extension created successfully", api_based_extension_model) + @console_ns.response(201, "Extension created successfully", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def post(self): payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) _, current_tenant_id = current_account_with_tenant() @@ -86,7 +130,7 @@ class APIBasedExtensionAPI(Resource): api_key=payload.api_key, ) - return APIBasedExtensionService.save(extension_data) + return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data)) @console_ns.route("/api-based-extension/<uuid:id>") @@ -94,26 +138,26 @@ class APIBasedExtensionDetailAPI(Resource): @console_ns.doc("get_api_based_extension") @console_ns.doc(description="Get API-based extension by ID") @console_ns.doc(params={"id": "Extension ID"}) - @console_ns.response(200, "Success", api_based_extension_model) + @console_ns.response(200, "Success", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def get(self, id): api_based_extension_id = str(id) _, tenant_id = current_account_with_tenant() - return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + return _serialize_api_based_extension( + APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) + ) @console_ns.doc("update_api_based_extension") @console_ns.doc(description="Update API-based extension") @console_ns.doc(params={"id": "Extension ID"}) @console_ns.expect(console_ns.models[APIBasedExtensionPayload.__name__]) - @console_ns.response(200, "Extension updated successfully", api_based_extension_model) + @console_ns.response(200, "Extension updated successfully", console_ns.models[APIBasedExtensionResponse.__name__]) @setup_required @login_required @account_initialization_required - @marshal_with(api_based_extension_model) def post(self, id): api_based_extension_id = str(id) _, current_tenant_id = current_account_with_tenant() @@ -128,7 +172,7 @@ class APIBasedExtensionDetailAPI(Resource): if payload.api_key != HIDDEN_VALUE: extension_data_from_db.api_key = payload.api_key - return APIBasedExtensionService.save(extension_data_from_db) + return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db)) @console_ns.doc("delete_api_based_extension") @console_ns.doc(description="Delete API-based extension") diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 845af37365..79b3e6cc9f 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -8,10 +8,10 @@ from collections.abc import Generator from flask import Response, jsonify, request from flask_restx import Resource -from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from controllers.common.human_input import HumanInputFormSubmitPayload from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError @@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator from
core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.apps.message_generator import MessageGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator +from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App from models.enums import CreatorUserRole -from models.human_input import RecipientType from models.model import AppMode from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory @@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream logger = logging.getLogger(__name__) -class HumanInputFormSubmitPayload(BaseModel): - inputs: dict - action: str - - def _jsonify_form_definition(form: Form) -> Response: payload = form.get_definition().model_dump() payload["expiration_time"] = int(form.expiration_time.timestamp()) @@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource): if form.tenant_id != current_tenant_id: raise NotFoundError("App not found") + @staticmethod + def _ensure_console_recipient_type(form: Form) -> None: + if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE): + raise NotFoundError("form not found") + @setup_required @login_required @account_initialization_required @@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource): raise NotFoundError(f"form not found, token={form_token}") self._ensure_console_access(form) - + self._ensure_console_recipient_type(form) recipient_type = form.recipient_type - if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}: - raise NotFoundError(f"form not found, token={form_token}") # The type checker is not smart enought to validate the following invariant. # So we need to assert it manually. assert recipient_type is not None, "recipient_type cannot be None here." 
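The `human_input_form.py` hunk above swaps the inline `RecipientType` set check for `is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE)`, so per-surface recipient rules live in one policy module instead of each controller. A minimal sketch of such a helper, inferred from the removed check; everything beyond the `{CONSOLE, BACKSTAGE}` pair is an assumption, and the real `core.workflow.human_input_policy` may differ:

```python
from enum import StrEnum


class RecipientType(StrEnum):
    # Mirrors models.human_input.RecipientType in spirit; WEBAPP is assumed.
    CONSOLE = "console"
    BACKSTAGE = "backstage"
    WEBAPP = "webapp"


class HumanInputSurface(StrEnum):
    CONSOLE = "console"
    WEB = "web"


# Each surface whitelists the recipient types it may serve; the console entry
# keeps the {CONSOLE, BACKSTAGE} pair from the removed inline check.
_ALLOWED: dict[HumanInputSurface, frozenset[RecipientType]] = {
    HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
    HumanInputSurface.WEB: frozenset({RecipientType.WEBAPP}),
}


def is_recipient_type_allowed_for_surface(
    recipient_type: RecipientType | None, surface: HumanInputSurface
) -> bool:
    # An unset recipient type is allowed on no surface.
    return recipient_type is not None and recipient_type in _ALLOWED.get(surface, frozenset())
```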
diff --git a/api/controllers/console/notification.py b/api/controllers/console/notification.py index 180167402a..5d46470173 100644 --- a/api/controllers/console/notification.py +++ b/api/controllers/console/notification.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import TypedDict from flask import request @@ -13,6 +14,14 @@ from services.billing_service import BillingService _FALLBACK_LANG = "en-US" +class NotificationLangContent(TypedDict, total=False): + lang: str + title: str + subtitle: str + body: str + titlePicUrl: str + + class NotificationItemDict(TypedDict): notification_id: str | None frequency: str | None @@ -28,9 +37,11 @@ class NotificationResponseDict(TypedDict): notifications: list[NotificationItemDict] -def _pick_lang_content(contents: dict, lang: str) -> dict: +def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent: """Return the single LangContent for *lang*, falling back to English.""" - return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {}) + return ( + contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent()) + ) class DismissNotificationPayload(BaseModel): @@ -71,7 +82,7 @@ class NotificationApi(Resource): notifications: list[NotificationItemDict] = [] for notification in result.get("notifications") or []: - contents: dict = notification.get("contents") or {} + contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {} lang_content = _pick_lang_content(contents, lang) item: NotificationItemDict = { "notification_id": notification.get("notificationId"), diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 551c86fd82..2a46d2250a 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,7 +2,6 @@ import urllib.parse import httpx from flask_restx import Resource -from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -16,6 +15,7 @@ from controllers.console import console_ns from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo +from graphon.file import helpers as file_helpers from libs.login import current_account_with_tenant, login_required from services.file_service import FileService diff --git a/api/controllers/console/socketio/__init__.py b/api/controllers/console/socketio/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/controllers/console/socketio/__init__.py @@ -0,0 +1 @@ + diff --git a/api/controllers/console/socketio/workflow.py b/api/controllers/console/socketio/workflow.py new file mode 100644 index 0000000000..b4f03593fd --- /dev/null +++ b/api/controllers/console/socketio/workflow.py @@ -0,0 +1,108 @@ +import logging +from collections.abc import Callable +from typing import cast + +from flask import Request as FlaskRequest + +from extensions.ext_socketio import sio +from libs.passport import PassportService +from libs.token import extract_access_token +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository +from services.account_service import AccountService +from services.workflow_collaboration_service import WorkflowCollaborationService + +repository = WorkflowCollaborationRepository() +collaboration_service = WorkflowCollaborationService(repository, sio) + + +def 
_sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]: + return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event)) + + +@_sio_on("connect") +def socket_connect(sid, environ, auth): + """ + WebSocket connect event; authentication happens here. + """ + try: + request_environ = FlaskRequest(environ) + token = extract_access_token(request_environ) + except Exception: + logging.exception("Failed to extract token") + token = None + + if not token: + logging.warning("Socket connect rejected: missing token (sid=%s)", sid) + return False + + try: + decoded = PassportService().verify(token) + user_id = decoded.get("user_id") + if not user_id: + logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid) + return False + + with sio.app.app_context(): + user = AccountService.load_logged_in_account(account_id=user_id) + if not user: + logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid) + return False + if not user.has_edit_permission: + logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid) + return False + + collaboration_service.save_socket_identity(sid, user) + return True + + except Exception: + logging.exception("Socket authentication failed") + return False + + +@_sio_on("user_connect") +def handle_user_connect(sid, data): + """ + Handle user connect event. Each session (tab) is treated as an independent collaborator. + """ + workflow_id = data.get("workflow_id") + if not workflow_id: + return {"msg": "workflow_id is required"}, 400 + + result = collaboration_service.authorize_and_join_workflow_room(workflow_id, sid) + if not result: + return {"msg": "unauthorized"}, 401 + + user_id, is_leader = result + return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader} + + +@_sio_on("disconnect") +def handle_disconnect(sid): + """ + Handle session disconnect event. Remove the specific session from online users. + """ + collaboration_service.disconnect_session(sid) + + +@_sio_on("collaboration_event") +def handle_collaboration_event(sid, data): + """ + Handle general collaboration events, including: + 1. mouse_move + 2. vars_and_features_update + 3. sync_request (ask leader to update graph) + 4. app_state_update + 5. mcp_server_update + 6. workflow_update + 7. comments_update + 8. node_panel_presence + """ + return collaboration_service.relay_collaboration_event(sid, data) + + +@_sio_on("graph_event") +def handle_graph_event(sid, data): + """ + Handle graph events - simple broadcast relay.
+ """ + return collaboration_service.relay_graph_event(sid, data) diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 39b84d3869..b9e876c906 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -1,13 +1,14 @@ from typing import Literal from flask import request -from flask_restx import Namespace, Resource, fields, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource +from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from fields.base import ResponseModel from libs.login import current_account_with_tenant, login_required from models.enums import TagType from services.tag_service import ( @@ -18,17 +19,6 @@ from services.tag_service import ( UpdateTagPayload, ) -dataset_tag_fields = { - "id": fields.String, - "name": fields.String, - "type": fields.String, - "binding_count": fields.String, -} - - -def build_dataset_tag_fields(api_or_ns: Namespace): - return api_or_ns.model("DataSetTag", dataset_tag_fields) - class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) @@ -42,7 +32,7 @@ class TagBindingPayload(BaseModel): class TagBindingRemovePayload(BaseModel): - tag_id: str = Field(description="Tag ID to remove") + tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1) target_id: str = Field(description="Target ID to unbind tag from") type: TagType = Field(description="Tag type") @@ -52,12 +42,36 @@ class TagListQueryParam(BaseModel): keyword: str | None = Field(None, description="Search keyword") +class TagResponse(ResponseModel): + id: str + name: str + type: str | None = None + binding_count: str | None = None + + @field_validator("type", mode="before") + @classmethod + def normalize_type(cls, value: TagType | str | None) -> str | None: + if value is None: + return None + if isinstance(value, TagType): + return value.value + return value + + @field_validator("binding_count", mode="before") + @classmethod + def normalize_binding_count(cls, value: int | str | None) -> str | None: + if value is None: + return None + return str(value) + + register_schema_models( console_ns, TagBasePayload, TagBindingPayload, TagBindingRemovePayload, TagListQueryParam, + TagResponse, ) @@ -69,14 +83,18 @@ class TagListApi(Resource): @console_ns.doc( params={"type": 'Tag type filter. 
Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} ) - @marshal_with(dataset_tag_fields) + @console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])}) def get(self): _, current_tenant_id = current_account_with_tenant() raw_args = request.args.to_dict() param = TagListQueryParam.model_validate(raw_args) tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) - return tags, 200 + serialized_tags = [ + TagResponse.model_validate(tag, from_attributes=True).model_dump(mode="json") for tag in tags + ] + + return serialized_tags, 200 @console_ns.expect(console_ns.models[TagBasePayload.__name__]) @setup_required @@ -91,7 +109,9 @@ class TagListApi(Resource): payload = TagBasePayload.model_validate(console_ns.payload or {}) tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type)) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + response = TagResponse.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + ).model_dump(mode="json") return response, 200 @@ -114,7 +134,9 @@ class TagUpdateDeleteApi(Resource): binding_count = TagService.get_tag_binding_count(tag_id) - response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + response = TagResponse.model_validate( + {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + ).model_dump(mode="json") return response, 200 @@ -130,41 +152,68 @@ class TagUpdateDeleteApi(Resource): return "", 204 -@console_ns.route("/tag-bindings/create") -class TagBindingCreateApi(Resource): +def _require_tag_binding_edit_permission() -> None: + """ + Ensure the current account can edit tag bindings. + + Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant. 
+ """ + current_user, _ = current_account_with_tenant() + # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + if not (current_user.has_edit_permission or current_user.is_dataset_editor): + raise Forbidden() + + +def _create_tag_bindings() -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission() + + payload = TagBindingPayload.model_validate(console_ns.payload or {}) + TagService.save_tag_binding( + TagBindingCreatePayload( + tag_ids=payload.tag_ids, + target_id=payload.target_id, + type=payload.type, + ) + ) + return {"result": "success"}, 200 + + +def _remove_tag_bindings() -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission() + + payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) + TagService.delete_tag_binding( + TagBindingDeletePayload( + tag_ids=payload.tag_ids, + target_id=payload.target_id, + type=payload.type, + ) + ) + return {"result": "success"}, 200 + + +@console_ns.route("/tag-bindings") +class TagBindingCollectionApi(Resource): + """Canonical collection resource for tag binding creation.""" + + @console_ns.doc("create_tag_binding") @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() - - payload = TagBindingPayload.model_validate(console_ns.payload or {}) - TagService.save_tag_binding( - TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type) - ) - - return {"result": "success"}, 200 + return _create_tag_bindings() @console_ns.route("/tag-bindings/remove") -class TagBindingDeleteApi(Resource): +class TagBindingRemoveApi(Resource): + """Batch resource for tag binding deletion.""" + + @console_ns.doc("remove_tag_bindings") + @console_ns.doc(description="Remove one or more tag bindings from a target.") @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - if not (current_user.has_edit_permission or current_user.is_dataset_editor): - raise Forbidden() - - payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) - TagService.delete_tag_binding( - TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type) - ) - - return {"result": "success"}, 200 + return _remove_tag_bindings() diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py index 60f712e476..59dd29fdac 100644 --- a/api/controllers/console/workspace/__init__.py +++ b/api/controllers/console/workspace/__init__.py @@ -35,22 +35,24 @@ def plugin_permission_required( return view(*args, **kwargs) if install_required: - if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY: - raise Forbidden() - if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS: - if not user.is_admin_or_owner: + match permission.install_permission: + case TenantPluginPermission.InstallPermission.NOBODY: raise Forbidden() - if permission.install_permission == 
TenantPluginPermission.InstallPermission.EVERYONE: - pass + case TenantPluginPermission.InstallPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + case TenantPluginPermission.InstallPermission.EVERYONE: + pass if debug_required: - if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY: - raise Forbidden() - if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS: - if not user.is_admin_or_owner: + match permission.debug_permission: + case TenantPluginPermission.DebugPermission.NOBODY: raise Forbidden() - if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE: - pass + case TenantPluginPermission.DebugPermission.ADMINS: + if not user.is_admin_or_owner: + raise Forbidden() + case TenantPluginPermission.DebugPermission.EVERYONE: + pass return view(*args, **kwargs) diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 626d330e9d..d69a59ecb7 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,14 +1,14 @@ from __future__ import annotations from datetime import datetime -from typing import Literal +from typing import Any, Literal import pytz from flask import request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound from configs import dify_config from constants.languages import supported_language @@ -38,12 +38,16 @@ from controllers.console.wraps import ( setup_required, ) from extensions.ext_database import db +from fields.base import ResponseModel from fields.member_fields import Account as AccountResponse +from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now -from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone +from libs.helper import EmailStr, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import AccountIntegrate, InvitationCode from models.account import AccountStatus, InvitationCodeStatus +from models.enums import CreatorUserRole +from models.model import UploadFile from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -75,6 +79,10 @@ class AccountAvatarPayload(BaseModel): avatar: str +class AccountAvatarQuery(BaseModel): + avatar: str = Field(..., description="Avatar file ID") + + class AccountInterfaceLanguagePayload(BaseModel): interface_language: str @@ -160,6 +168,7 @@ def reg(cls: type[BaseModel]): reg(AccountInitPayload) reg(AccountNamePayload) reg(AccountAvatarPayload) +reg(AccountAvatarQuery) reg(AccountInterfaceLanguagePayload) reg(AccountInterfaceThemePayload) reg(AccountTimezonePayload) @@ -175,21 +184,61 @@ reg(CheckEmailUniquePayload) register_schema_models(console_ns, AccountResponse) -def _serialize_account(account) -> dict: +def _serialize_account(account) -> dict[str, Any]: return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json") -integrate_fields = { - "provider": fields.String, - "created_at": TimestampField, - "is_bound": fields.Boolean, - "link": fields.String, -} +def _to_timestamp(value: datetime | int | None) -> 
int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value -integrate_model = console_ns.model("AccountIntegrate", integrate_fields) -integrate_list_model = console_ns.model( - "AccountIntegrateList", - {"data": fields.List(fields.Nested(integrate_model))}, + +class AccountIntegrateResponse(ResponseModel): + provider: str + created_at: int | None = None + is_bound: bool + link: str | None = None + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class AccountIntegrateListResponse(ResponseModel): + data: list[AccountIntegrateResponse] + + +class EducationVerifyResponse(ResponseModel): + token: str | None = None + + +class EducationStatusResponse(ResponseModel): + result: bool | None = None + is_student: bool | None = None + expire_at: int | None = None + allow_refresh: bool | None = None + + @field_validator("expire_at", mode="before") + @classmethod + def _normalize_expire_at(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class EducationAutocompleteResponse(ResponseModel): + data: list[str] = Field(default_factory=list) + curr_page: int | None = None + has_next: bool | None = None + + +register_schema_models( + console_ns, + AccountIntegrateResponse, + AccountIntegrateListResponse, + EducationVerifyResponse, + EducationStatusResponse, + EducationAutocompleteResponse, ) @@ -269,6 +318,33 @@ class AccountNameApi(Resource): @console_ns.route("/account/avatar") class AccountAvatarApi(Resource): + @console_ns.expect(console_ns.models[AccountAvatarQuery.__name__]) + @console_ns.doc("get_account_avatar") + @console_ns.doc(description="Get account avatar url") + @setup_required + @login_required + @account_initialization_required + def get(self): + current_user, current_tenant_id = current_account_with_tenant() + args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + avatar = args.avatar + + if avatar.startswith(("http://", "https://")): + return {"avatar_url": avatar} + + upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1)) + if upload_file is None: + raise NotFound("Avatar file not found") + + if upload_file.tenant_id != current_tenant_id: + raise NotFound("Avatar file not found") + + if upload_file.created_by_role != CreatorUserRole.ACCOUNT or upload_file.created_by != current_user.id: + raise NotFound("Avatar file not found") + + avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) + return {"avatar_url": avatar_url} + @console_ns.expect(console_ns.models[AccountAvatarPayload.__name__]) @setup_required @login_required @@ -360,7 +436,7 @@ class AccountIntegrateApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(integrate_list_model) + @console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__]) def get(self): account, _ = current_account_with_tenant() @@ -396,7 +472,9 @@ class AccountIntegrateApi(Resource): } ) - return {"data": integrate_data} + return AccountIntegrateListResponse( + data=[AccountIntegrateResponse.model_validate(item) for item in integrate_data] + ).model_dump(mode="json") @console_ns.route("/account/delete/verify") @@ -448,31 +526,22 @@ class AccountDeleteUpdateFeedbackApi(Resource): @console_ns.route("/account/education/verify") class EducationVerifyApi(Resource): - verify_fields = { - "token": 
fields.String, - } - @setup_required @login_required @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(verify_fields) + @console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__]) def get(self): account, _ = current_account_with_tenant() - return BillingService.EducationIdentity.verify(account.id, account.email) + return EducationVerifyResponse.model_validate( + BillingService.EducationIdentity.verify(account.id, account.email) or {} + ).model_dump(mode="json") @console_ns.route("/account/education") class EducationApi(Resource): - status_fields = { - "result": fields.Boolean, - "is_student": fields.Boolean, - "expire_at": TimestampField, - "allow_refresh": fields.Boolean, - } - @console_ns.expect(console_ns.models[EducationActivatePayload.__name__]) @setup_required @login_required @@ -492,37 +561,33 @@ class EducationApi(Resource): @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(status_fields) + @console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__]) def get(self): account, _ = current_account_with_tenant() - res = BillingService.EducationIdentity.status(account.id) + res = BillingService.EducationIdentity.status(account.id) or {} # convert expire_at to UTC timestamp from isoformat if res and "expire_at" in res: res["expire_at"] = datetime.fromisoformat(res["expire_at"]).astimezone(pytz.utc) - return res + return EducationStatusResponse.model_validate(res).model_dump(mode="json") @console_ns.route("/account/education/autocomplete") class EducationAutoCompleteApi(Resource): - data_fields = { - "data": fields.List(fields.String), - "curr_page": fields.Integer, - "has_next": fields.Boolean, - } - @console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__]) @setup_required @login_required @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - @marshal_with(data_fields) + @console_ns.response(200, "Success", console_ns.models[EducationAutocompleteResponse.__name__]) def get(self): payload = request.args.to_dict(flat=True) args = EducationAutocompleteQuery.model_validate(payload) - return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) + return EducationAutocompleteResponse.model_validate( + BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) or {} + ).model_dump(mode="json") @console_ns.route("/account/change-email") @@ -548,13 +613,25 @@ class ChangeEmailSendEmailApi(Resource): account = None user_email = None email_for_sending = args.email.lower() - if args.phase is not None and args.phase == "new_email": + # Default to the initial phase; any legacy/unexpected client input is + # coerced back to `old_email` so we never trust the caller to declare + # later phases without a verified predecessor token. + send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD + if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW: + send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW if args.token is None: raise InvalidTokenError() reset_data = AccountService.get_change_email_data(args.token) if reset_data is None: raise InvalidTokenError() + + # The token used to request a new-email code must come from the + # old-email verification step. This prevents the bypass described + # in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here. 
+ token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED: + raise InvalidTokenError() user_email = reset_data.get("email", "") if user_email.lower() != current_user.email.lower(): @@ -562,8 +639,7 @@ class ChangeEmailSendEmailApi(Resource): user_email = current_user.email else: - with sessionmaker(db.engine).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(args.email) if account is None: raise AccountNotFound() email_for_sending = account.email @@ -574,7 +650,7 @@ class ChangeEmailSendEmailApi(Resource): email=email_for_sending, old_email=user_email, language=language, - phase=args.phase, + phase=send_phase, ) return {"result": "success", "data": token} @@ -609,12 +685,31 @@ class ChangeEmailCheckApi(Resource): AccountService.add_change_email_error_rate_limit(user_email) raise EmailCodeError() + # Only advance tokens that were minted by the matching send-code step; + # refuse tokens that have already progressed or lack a phase marker so + # the chain `old_email -> old_email_verified -> new_email -> new_email_verified` + # is strictly enforced. + phase_transitions = { + AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } + token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if not isinstance(token_phase, str): + raise InvalidTokenError() + refreshed_phase = phase_transitions.get(token_phase) + if refreshed_phase is None: + raise InvalidTokenError() + # Verified, revoke the first token AccountService.revoke_change_email_token(args.token) - # Refresh token data by generating a new token + # Refresh token data by generating a new token that carries the + # upgraded phase so later steps can check it. _, new_token = AccountService.generate_change_email_token( - user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} + user_email, + code=args.code, + old_email=token_data.get("old_email"), + additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase}, ) AccountService.reset_change_email_error_rate_limit(user_email) @@ -644,13 +739,29 @@ class ChangeEmailResetApi(Resource): if not reset_data: raise InvalidTokenError() - AccountService.revoke_change_email_token(args.token) + # Only tokens that completed both verification phases may be used to + # change the email. This closes GHSA-4q3w-q5mc-45rq where a token from + # the initial send-code step could be replayed directly here. + token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY) + if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED: + raise InvalidTokenError() + + # Bind the new email to the token that was mailed and verified, so a + # verified token cannot be reused with a different `new_email` value. + token_email = reset_data.get("email") + normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email + if normalized_token_email != normalized_new_email: + raise InvalidTokenError() old_email = reset_data.get("old_email", "") current_user, _ = current_account_with_tenant() if current_user.email.lower() != old_email.lower(): raise AccountNotFound() + # Revoke only after all checks pass so failed attempts don't burn a + # legitimately verified token. 
+ AccountService.revoke_change_email_token(args.token) + updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) AccountService.send_change_email_completed_notify_email( diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 3fdcbc4710..764f488755 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index b6b9deb1f9..d4be07382a 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -1,14 +1,22 @@ +"""Console workspace endpoint controllers. + +This module exposes workspace-scoped plugin endpoint management APIs. The +canonical write routes follow resource-oriented paths, while the historical +verb-based aliases stay available as deprecated resources so OpenAPI metadata +marks only the legacy paths as deprecated. +""" + from typing import Any from flask import request from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService @@ -25,7 +33,12 @@ class EndpointIdPayload(BaseModel): endpoint_id: str -class EndpointUpdatePayload(EndpointIdPayload): +class EndpointUpdatePayload(BaseModel): + settings: dict[str, Any] + name: str = Field(min_length=1) + + +class LegacyEndpointUpdatePayload(EndpointIdPayload): settings: dict[str, Any] name: str = Field(min_length=1) @@ -76,6 +89,7 @@ register_schema_models( EndpointCreatePayload, EndpointIdPayload, EndpointUpdatePayload, + LegacyEndpointUpdatePayload, EndpointListQuery, EndpointListForPluginQuery, EndpointCreateResponse, @@ -88,8 +102,60 @@ register_schema_models( ) -@console_ns.route("/workspaces/current/endpoints/create") -class EndpointCreateApi(Resource): +def _create_endpoint() -> dict[str, bool]: + """Create a plugin endpoint for the current workspace.""" + user, tenant_id = current_account_with_tenant() + + args = EndpointCreatePayload.model_validate(console_ns.payload) + + try: + return { + "success": EndpointService.create_endpoint( + tenant_id=tenant_id, + user_id=user.id, + plugin_unique_identifier=args.plugin_unique_identifier, + name=args.name, + settings=args.settings, + ) + } + except PluginPermissionDeniedError as e: + raise ValueError(e.description) from e + + +def _update_endpoint(endpoint_id: str) -> dict[str, bool]: + """Update a plugin endpoint identified by the canonical path parameter.""" + user, tenant_id = 
current_account_with_tenant() + + args = EndpointUpdatePayload.model_validate(console_ns.payload) + + return { + "success": EndpointService.update_endpoint( + tenant_id=tenant_id, + user_id=user.id, + endpoint_id=endpoint_id, + name=args.name, + settings=args.settings, + ) + } + + +def _delete_endpoint(endpoint_id: str) -> dict[str, bool]: + """Delete a plugin endpoint identified by the canonical path parameter.""" + user, tenant_id = current_account_with_tenant() + + return { + "success": EndpointService.delete_endpoint( + tenant_id=tenant_id, + user_id=user.id, + endpoint_id=endpoint_id, + ) + } + + +@console_ns.route("/workspaces/current/endpoints") +class EndpointCollectionApi(Resource): + """Canonical collection resource for endpoint creation.""" + @console_ns.doc("create_endpoint") @console_ns.doc(description="Create a new plugin endpoint") @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__]) @@ -104,22 +170,33 @@ class EndpointCreateApi(Resource): @is_admin_or_owner_required @account_initialization_required def post(self): - user, tenant_id = current_account_with_tenant() + return _create_endpoint() - args = EndpointCreatePayload.model_validate(console_ns.payload) - try: - return { - "success": EndpointService.create_endpoint( - tenant_id=tenant_id, - user_id=user.id, - plugin_unique_identifier=args.plugin_unique_identifier, - name=args.name, - settings=args.settings, - ) - } - except PluginPermissionDeniedError as e: - raise ValueError(e.description) from e +@console_ns.route("/workspaces/current/endpoints/create") +class DeprecatedEndpointCreateApi(Resource): + """Deprecated verb-based alias for endpoint creation.""" + + @console_ns.doc("create_endpoint_deprecated") + @console_ns.doc(deprecated=True) + @console_ns.doc( + description=( + "Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead." 
+ ) + ) + @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__]) + @console_ns.response( + 200, + "Endpoint created successfully", + console_ns.models[EndpointCreateResponse.__name__], + ) + @console_ns.response(403, "Admin privileges required") + @setup_required + @login_required + @is_admin_or_owner_required + @account_initialization_required + def post(self): + return _create_endpoint() @console_ns.route("/workspaces/current/endpoints/list") @@ -190,10 +267,56 @@ class EndpointListForSinglePluginApi(Resource): ) -@console_ns.route("/workspaces/current/endpoints/delete") -class EndpointDeleteApi(Resource): +@console_ns.route("/workspaces/current/endpoints/<string:id>") +class EndpointItemApi(Resource): + """Canonical item resource for endpoint updates and deletion.""" + @console_ns.doc("delete_endpoint") @console_ns.doc(description="Delete a plugin endpoint") + @console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}}) + @console_ns.response( + 200, + "Endpoint deleted successfully", + console_ns.models[EndpointDeleteResponse.__name__], + ) + @console_ns.response(403, "Admin privileges required") + @setup_required + @login_required + @is_admin_or_owner_required + @account_initialization_required + def delete(self, id: str): + return _delete_endpoint(endpoint_id=id) + + @console_ns.doc("update_endpoint") + @console_ns.doc(description="Update a plugin endpoint") + @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__]) + @console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}}) + @console_ns.response( + 200, + "Endpoint updated successfully", + console_ns.models[EndpointUpdateResponse.__name__], + ) + @console_ns.response(403, "Admin privileges required") + @setup_required + @login_required + @is_admin_or_owner_required + @account_initialization_required + def patch(self, id: str): + return _update_endpoint(endpoint_id=id) + + +@console_ns.route("/workspaces/current/endpoints/delete") +class DeprecatedEndpointDeleteApi(Resource): + """Deprecated verb-based alias for endpoint deletion.""" + + @console_ns.doc("delete_endpoint_deprecated") + @console_ns.doc(deprecated=True) + @console_ns.doc( + description=( + "Deprecated legacy alias for deleting a plugin endpoint. " + "Use DELETE /workspaces/current/endpoints/{id} instead." + ) + ) @console_ns.expect(console_ns.models[EndpointIdPayload.__name__]) @console_ns.response( 200, @@ -206,22 +329,23 @@ class EndpointDeleteApi(Resource): @is_admin_or_owner_required @account_initialization_required def post(self): - user, tenant_id = current_account_with_tenant() - args = EndpointIdPayload.model_validate(console_ns.payload) - - return { - "success": EndpointService.delete_endpoint( - tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id - ) - } + args = EndpointIdPayload.model_validate(console_ns.payload) + return _delete_endpoint(endpoint_id=args.endpoint_id) @console_ns.route("/workspaces/current/endpoints/update") -class EndpointUpdateApi(Resource): - @console_ns.doc("update_endpoint") - @console_ns.doc(description="Update a plugin endpoint") - @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__]) +class DeprecatedEndpointUpdateApi(Resource): + """Deprecated verb-based alias for endpoint updates.""" + + @console_ns.doc("update_endpoint_deprecated") + @console_ns.doc(deprecated=True) + @console_ns.doc( + description=( + "Deprecated legacy alias for updating a plugin endpoint. " + "Use PATCH /workspaces/current/endpoints/{id} instead."
+ ) + ) + @console_ns.expect(console_ns.models[LegacyEndpointUpdatePayload.__name__]) @console_ns.response( 200, "Endpoint updated successfully", @@ -233,19 +357,8 @@ class EndpointUpdateApi(Resource): @is_admin_or_owner_required @account_initialization_required def post(self): - user, tenant_id = current_account_with_tenant() - - args = EndpointUpdatePayload.model_validate(console_ns.payload) - - return { - "success": EndpointService.update_endpoint( - tenant_id=tenant_id, - user_id=user.id, - endpoint_id=args.endpoint_id, - name=args.name, - settings=args.settings, - ) - } + args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload) + return _update_endpoint(endpoint_id=args.endpoint_id) @console_ns.route("/workspaces/current/endpoints/enable") diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index e4cfca9fa4..2a6f37aec8 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,12 +1,12 @@ from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index cbb9677309..4b10561fdb 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -3,13 +3,13 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 9182dbb510..b2d07ff8f9 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -3,14 +3,14 @@ from typing import Any, cast from flask import request from flask_restx import Resource -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.validate import 
CredentialsValidateFailedError -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService @@ -465,7 +465,7 @@ class ModelProviderModelDisableApi(Resource): class ParserValidate(BaseModel): model: str model_type: ModelType - credentials: dict + credentials: dict[str, Any] console_ns.schema_model( diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index aa674a63b3..b3e344ccea 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -4,7 +4,6 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -15,6 +14,7 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c9956501e2..34c9534de8 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -5,7 +5,6 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Forbidden @@ -28,6 +27,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required from models.provider_ids import ToolProviderID @@ -1131,6 +1131,14 @@ class ToolMCPAuthApi(Resource): with sessionmaker(db.engine).begin() as session: service = MCPToolManageService(session=session) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) + parsed = 
urlparse(server_url) + sanitized_url = f"{parsed.scheme}://{parsed.hostname}{parsed.path}" + logger.warning( + "MCP authorization failed for provider %s (url=%s)", + provider_id, + sanitized_url, + exc_info=True, + ) raise ValueError(f"Failed to connect to MCP server: {e}") from e diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 7a28a09861..d11b66244f 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -3,7 +3,6 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, model_validator from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import BadRequest, Forbidden @@ -16,6 +15,7 @@ from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_user, login_required from models.account import Account from models.provider_ids import TriggerProviderID diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 42874e6033..565099db61 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,8 +1,9 @@ import logging +from datetime import datetime from flask import request -from flask_restx import Resource, fields, marshal, marshal_with -from pydantic import BaseModel, Field +from flask_restx import Resource, fields, marshal +from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from werkzeug.exceptions import Unauthorized @@ -26,6 +27,7 @@ from controllers.console.wraps import ( ) from enums.cloud_plan import CloudPlan from extensions.ext_database import db +from fields.base import ResponseModel from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.account import Tenant, TenantCustomConfigDict, TenantStatus @@ -58,6 +60,37 @@ class WorkspaceInfoPayload(BaseModel): name: str +class TenantInfoResponse(ResponseModel): + id: str + name: str | None = None + plan: str | None = None + status: str | None = None + created_at: int | None = None + role: str | None = None + in_trial: bool | None = None + trial_end_reason: str | None = None + custom_config: dict | None = None + trial_credits: int | None = None + trial_credits_used: int | None = None + next_credit_reset_date: int | None = None + + @field_validator("plan", "status", "trial_end_reason", mode="before") + @classmethod + def _normalize_enum_like(cls, value): + if value is None: + return None + if isinstance(value, str): + return value + return str(getattr(value, "value", value)) + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None): + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + def reg(cls: type[BaseModel]): console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) @@ -66,6 +99,7 @@ reg(WorkspaceListQuery) reg(SwitchWorkspacePayload) reg(WorkspaceCustomConfigPayload) reg(WorkspaceInfoPayload) +reg(TenantInfoResponse) 
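# A minimal, self-contained sketch (not part of the diff above) of the
# mode="before" validator pattern that TenantInfoResponse relies on: enum-like
# values are unwrapped to plain strings and datetimes become epoch seconds
# before normal field validation runs. The PlanExample enum and the demo values
# are illustrative assumptions, not code from this change.
from datetime import datetime
from enum import Enum

from pydantic import BaseModel, field_validator


class PlanExample(Enum):
    SANDBOX = "sandbox"


class TenantInfoSketch(BaseModel):
    plan: str | None = None
    created_at: int | None = None

    @field_validator("plan", mode="before")
    @classmethod
    def _normalize_enum_like(cls, value):
        # Pass strings and None through; unwrap Enum members to their .value.
        if value is None or isinstance(value, str):
            return value
        return str(getattr(value, "value", value))

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value):
        # ORM rows carry datetime objects; the API contract wants integers.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value


# TenantInfoSketch.model_validate({"plan": PlanExample.SANDBOX, "created_at": datetime(2024, 1, 1)})
# yields plan="sandbox" and an integer created_at (the exact value depends on the local timezone).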
provider_fields = { "provider_name": fields.String, @@ -180,7 +214,7 @@ class TenantApi(Resource): @setup_required @login_required @account_initialization_required - @marshal_with(tenant_fields) + @console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__]) def post(self): if request.path == "/info": logger.warning("Deprecated URL /info was used.") @@ -200,7 +234,13 @@ class TenantApi(Resource): else: raise Unauthorized("workspace is archived") - return WorkspaceService.get_tenant_info(tenant), 200 + return ( + TenantInfoResponse.model_validate( + WorkspaceService.get_tenant_info(tenant), + from_attributes=True, + ).model_dump(mode="json"), + 200, + ) @console_ns.route("/workspaces/switch") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 4b5fb7ca5b..ef2931ce9b 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -20,7 +20,7 @@ from models.account import AccountStatus from models.dataset import RateLimitLog from models.model import DifySetup from services.feature_service import FeatureService, LicenseStatus -from services.operation_service import OperationService +from services.operation_service import OperationService, UtmInfo from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout @@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]: utm_info = request.cookies.get("utm_info") if utm_info: - utm_info_dict: dict = json.loads(utm_info) + utm_info_dict: UtmInfo = json.loads(utm_info) OperationService.record_utm(current_tenant_id, utm_info_dict) return view(*args, **kwargs) diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py index 6c15f9aa8b..915a11dcdd 100644 --- a/api/controllers/inner_api/app/dsl.py +++ b/api/controllers/inner_api/app/dsl.py @@ -9,7 +9,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from controllers.common.schema import register_schema_model from controllers.console.wraps import setup_required @@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource): account.set_tenant_id(workspace_id) - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: dsl_service = AppDslService(session) result = dsl_service.import_app( account=account, @@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource): name=args.name, description=args.description, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 83c8fa02fe..72cab3de73 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,5 +1,4 @@ from flask_restx import Resource -from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns @@ -30,6 +29,7 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.signature import get_signed_file_url_for_plugin +from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import 
Account, Tenant from models.model import EndUser diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 1d378c754c..2f309262cb 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -20,10 +20,13 @@ class TenantUserPayload(BaseModel): def get_user(tenant_id: str, user_id: str | None) -> EndUser: """ - Get current user + Get current user. NOTE: user_id is not trusted, it could be maliciously set to any value. - As a result, it could only be considered as an end user id. + As a result, it could only be considered as an end user id. Even when a + concrete end-user ID is supplied, lookups must stay tenant-scoped so one + tenant cannot bind another tenant's user record into the plugin request + context. """ if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID @@ -42,7 +45,14 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: .limit(1) ) else: - user_model = session.get(EndUser, user_id) + user_model = session.scalar( + select(EndUser) + .where( + EndUser.id == user_id, + EndUser.tenant_id == tenant_id, + ) + .limit(1) + ) if not user_model: user_model = EndUser( @@ -94,10 +104,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]: def plugin_data[**P, R]( - view: Callable[P, R] | None = None, *, payload_type: type[BaseModel], -) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: +) -> Callable[[Callable[P, R]], Callable[P, R]]: def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: @@ -116,7 +125,4 @@ def plugin_data[**P, R]( return decorated_view - if view is None: - return decorator - else: - return decorator(view) + return decorator diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index d2ce0ea543..f652bbc581 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,7 +2,6 @@ from typing import Any, Union from flask import Response from flask_restx import Resource -from graphon.variables.input_entities import VariableEntity from pydantic import BaseModel, Field, ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -12,6 +11,7 @@ from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request from extensions.ext_database import db +from graphon.variables.input_entities import VariableEntity, VariableEntityType from libs import helper from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser @@ -158,14 +158,20 @@ class MCPAppApi(Resource): except ValidationError as e: raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}") - def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]: + def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]: """Convert raw user input form to VariableEntity objects""" return [self._create_variable_entity(item) for item in raw_form] - def _create_variable_entity(self, item: dict) -> VariableEntity: + def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity: """Create a single VariableEntity from raw form item""" - variable_type = item.get("type", "") or list(item.keys())[0] - variable = item[variable_type] + variable_type_raw: str = item.get("type", "") or list(item.keys())[0] + try: + variable_type = 
VariableEntityType(variable_type_raw) + except ValueError as e: + raise MCPRequestError( + mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}" + ) from e + variable = item[variable_type_raw] return VariableEntity( type=variable_type, @@ -178,7 +184,7 @@ class MCPAppApi(Resource): json_schema=variable.get("json_schema"), ) - def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification: + def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification: """Parse and validate MCP request""" try: return mcp_types.ClientRequest.model_validate(args) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 4f7f7d9a98..182631e8f5 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -23,9 +23,11 @@ from .app import ( conversation, file, file_preview, + human_input_form, message, site, workflow, + workflow_events, ) from .dataset import ( dataset, @@ -50,6 +52,7 @@ __all__ = [ "file", "file_preview", "hit_testing", + "human_input_form", "index", "message", "metadata", @@ -58,6 +61,7 @@ __all__ = [ "segment", "site", "workflow", + "workflow_events", ] api.add_namespace(service_api_ns) diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index c22190cbc9..00bb9aa463 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import Annotation, AnnotationList from models.model import App -from services.annotation_service import AppAnnotationService +from services.annotation_service import ( + AppAnnotationService, + EnableAnnotationArgs, + InsertAnnotationArgs, + UpdateAnnotationArgs, +) class AnnotationCreatePayload(BaseModel): @@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource): @validate_app_token def post(self, app_model: App, action: Literal["enable", "disable"]): """Enable or disable annotation reply feature.""" - args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump() + payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}) match action: case "enable": - result = AppAnnotationService.enable_app_annotation(args, app_model.id) + enable_args: EnableAnnotationArgs = { + "score_threshold": payload.score_threshold, + "embedding_provider_name": payload.embedding_provider_name, + "embedding_model_name": payload.embedding_model_name, + } + result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id) case "disable": result = AppAnnotationService.disable_app_annotation(app_model.id) return result, 200 @@ -135,8 +145,9 @@ class AnnotationListApi(Resource): @validate_app_token def post(self, app_model: App): """Create a new annotation.""" - args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() - annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) + payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) + insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer} + annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id) response = Annotation.model_validate(annotation, from_attributes=True) 
return response.model_dump(mode="json"), HTTPStatus.CREATED @@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource): @edit_permission_required def put(self, app_model: App, annotation_id: str): """Update an existing annotation.""" - args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump() - annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) + payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) + update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer} + annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 6228cfc25b..e818573b8f 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,11 +2,10 @@ import logging from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError -from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError import services +from controllers.common.controller_schemas import TextToAudioPayload from controllers.common.schema import register_schema_model from controllers.service_api import service_api_ns from controllers.service_api.app.error import ( @@ -22,6 +21,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( @@ -86,13 +86,6 @@ class AudioApi(Resource): raise InternalServerError() -class TextToAudioPayload(BaseModel): - message_id: str | None = Field(default=None, description="Message ID") - voice: str | None = Field(default=None, description="Voice to use for TTS") - text: str | None = Field(default=None, description="Text to convert to audio") - streaming: bool | None = Field(default=None, description="Enable streaming response") - - register_schema_model(service_api_ns, TextToAudioPayload) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 3142e5118e..31f2797d66 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,7 +4,6 @@ from uuid import UUID from flask import request from flask_restx import Resource -from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -29,6 +28,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 1ec289e2a2..ca4b18cb5e 100644 --- a/api/controllers/service_api/app/conversation.py +++ 
b/api/controllers/service_api/app/conversation.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Any, Literal from flask import request @@ -14,14 +15,13 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db +from fields._value_type_serializer import serialize_value_type +from fields.base import ResponseModel from fields.conversation_fields import ( ConversationInfiniteScrollPagination, SimpleConversation, ) -from fields.conversation_variable_fields import ( - build_conversation_variable_infinite_scroll_pagination_model, - build_conversation_variable_model, -) +from graphon.variables.types import SegmentType from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService @@ -70,12 +70,70 @@ class ConversationVariableUpdatePayload(BaseModel): value: Any +class ConversationVariableResponse(ResponseModel): + id: str + name: str + value_type: str + value: str | None = None + description: str | None = None + created_at: int | None = None + updated_at: int | None = None + + @field_validator("value_type", mode="before") + @classmethod + def normalize_value_type(cls, value: Any) -> str: + exposed_type = getattr(value, "exposed_type", None) + if callable(exposed_type): + return str(exposed_type()) + if isinstance(value, str): + try: + return str(SegmentType(value).exposed_type()) + except ValueError: + return value + try: + return serialize_value_type(value) + except (AttributeError, TypeError, ValueError): + pass + + try: + return serialize_value_type({"value_type": value}) + except (AttributeError, TypeError, ValueError): + value_attr = getattr(value, "value", None) + if value_attr is not None: + return str(value_attr) + return str(value) + + @field_validator("value", mode="before") + @classmethod + def normalize_value(cls, value: Any | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel): + limit: int + has_more: bool + data: list[ConversationVariableResponse] + + register_schema_models( service_api_ns, ConversationListQuery, ConversationRenamePayload, ConversationVariablesQuery, ConversationVariableUpdatePayload, + ConversationVariableResponse, + ConversationVariableInfiniteScrollPaginationResponse, ) @@ -204,8 +262,12 @@ class ConversationVariablesApi(Resource): 404: "Conversation not found", } ) + @service_api_ns.response( + 200, + "Variables retrieved successfully", + service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__], + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns)) def get(self, app_model: App, end_user: EndUser, c_id): """List all variables for a conversation. 
@@ -222,9 +284,12 @@ class ConversationVariablesApi(Resource): last_id = str(query_args.last_id) if query_args.last_id else None try: - return ConversationService.get_conversational_variable( + pagination = ConversationService.get_conversational_variable( app_model, conversation_id, end_user, query_args.limit, last_id, query_args.variable_name ) + return ConversationVariableInfiniteScrollPaginationResponse.model_validate( + pagination, from_attributes=True + ).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -243,8 +308,12 @@ class ConversationVariableDetailApi(Resource): 404: "Conversation or variable not found", } ) + @service_api_ns.response( + 200, + "Variable updated successfully", + service_api_ns.models[ConversationVariableResponse.__name__], + ) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns)) def put(self, app_model: App, end_user: EndUser, c_id, variable_id): """Update a conversation variable's value. @@ -261,9 +330,10 @@ class ConversationVariableDetailApi(Resource): payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {}) try: - return ConversationService.update_conversation_variable( + variable = ConversationService.update_conversation_variable( app_model, conversation_id, variable_id, end_user, payload.value ) + return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json") except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.conversation.ConversationVariableNotExistsError: diff --git a/api/controllers/service_api/app/human_input_form.py b/api/controllers/service_api/app/human_input_form.py new file mode 100644 index 0000000000..8e5003dbbf --- /dev/null +++ b/api/controllers/service_api/app/human_input_form.py @@ -0,0 +1,137 @@ +""" +Service API human input form endpoints. + +This module exposes app-token authenticated APIs for fetching and submitting +paused human input forms in workflow/chatflow runs. 
+""" + +import json +import logging +from datetime import datetime + +from flask import Response +from flask_restx import Resource +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.common.human_input import HumanInputFormSubmitPayload +from controllers.common.schema import register_schema_models +from controllers.service_api import service_api_ns +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface +from extensions.ext_database import db +from models.model import App, EndUser +from services.human_input_service import Form, FormNotFoundError, HumanInputService + +logger = logging.getLogger(__name__) + + +register_schema_models(service_api_ns, HumanInputFormSubmitPayload) + + +def _stringify_default_values(values: dict[str, object]) -> dict[str, str]: + result: dict[str, str] = {} + for key, value in values.items(): + if value is None: + result[key] = "" + elif isinstance(value, (dict, list)): + result[key] = json.dumps(value, ensure_ascii=False) + else: + result[key] = str(value) + return result + + +def _to_timestamp(value: datetime) -> int: + return int(value.timestamp()) + + +def _jsonify_form_definition(form: Form) -> Response: + definition_payload = form.get_definition().model_dump() + payload = { + "form_content": definition_payload["rendered_content"], + "inputs": definition_payload["inputs"], + "resolved_default_values": _stringify_default_values(definition_payload["default_values"]), + "user_actions": definition_payload["user_actions"], + "expiration_time": _to_timestamp(form.expiration_time), + } + return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json") + + +def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None: + if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id: + raise NotFound("Form not found") + + +def _ensure_form_is_allowed_for_service_api(form: Form) -> None: + # Keep app-token callers scoped to the public web-form surface; internal HITL + # routes must continue to flow through console-only authentication. 
+    if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
+        raise NotFound("Form not found")
+
+
+@service_api_ns.route("/form/human_input/<string:form_token>")
+class WorkflowHumanInputFormApi(Resource):
+    @service_api_ns.doc("get_human_input_form")
+    @service_api_ns.doc(description="Get a paused human input form by token")
+    @service_api_ns.doc(params={"form_token": "Human input form token"})
+    @service_api_ns.doc(
+        responses={
+            200: "Form retrieved successfully",
+            401: "Unauthorized - invalid API token",
+            404: "Form not found",
+            412: "Form already submitted or expired",
+        }
+    )
+    @validate_app_token
+    def get(self, app_model: App, form_token: str):
+        service = HumanInputService(db.engine)
+        form = service.get_form_by_token(form_token)
+        if form is None:
+            raise NotFound("Form not found")
+
+        _ensure_form_belongs_to_app(form, app_model)
+        _ensure_form_is_allowed_for_service_api(form)
+        service.ensure_form_active(form)
+        return _jsonify_form_definition(form)
+
+    @service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
+    @service_api_ns.doc("submit_human_input_form")
+    @service_api_ns.doc(description="Submit a paused human input form by token")
+    @service_api_ns.doc(params={"form_token": "Human input form token"})
+    @service_api_ns.doc(
+        responses={
+            200: "Form submitted successfully",
+            400: "Bad request - invalid submission data",
+            401: "Unauthorized - invalid API token",
+            404: "Form not found",
+            412: "Form already submitted or expired",
+        }
+    )
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, form_token: str):
+        payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
+
+        service = HumanInputService(db.engine)
+        form = service.get_form_by_token(form_token)
+        if form is None:
+            raise NotFound("Form not found")
+
+        _ensure_form_belongs_to_app(form, app_model)
+        _ensure_form_is_allowed_for_service_api(form)
+
+        recipient_type = form.recipient_type
+        if recipient_type is None:
+            logger.warning("Recipient type is None for form, form_id=%s", form.id)
+            raise BadRequest("Form recipient type is invalid")
+
+        try:
+            service.submit_form_by_token(
+                recipient_type=recipient_type,
+                form_token=form_token,
+                selected_action_id=payload.action,
+                form_data=payload.inputs,
+                submission_end_user_id=end_user.id,
+            )
+        except FormNotFoundError:
+            raise NotFound("Form not found")
+
+        return {}, 200
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index e0a64ffe26..cc763fa89c 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -1,13 +1,12 @@
 import logging
+from collections.abc import Mapping
+from datetime import datetime
 from typing import Literal
 from dateutil.parser import isoparse
 from flask import request
-from flask_restx import Namespace, Resource, fields
-from graphon.enums import WorkflowExecutionStatus
-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.errors.invoke import InvokeError
-from pydantic import BaseModel, Field
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@@ -33,9 +32,13 @@ from core.errors.error import (
 from core.helper.trace_id_helper import get_external_trace_id
 from
extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model +from fields.base import ResponseModel +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.helper import OptionalTimestampField, TimestampField from models.model import App, AppMode, EndUser from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory @@ -65,38 +68,142 @@ class WorkflowLogQuery(BaseModel): register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery) +def _to_timestamp(value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return int(value.timestamp()) + return value + + +def _enum_value(value): + return getattr(value, "value", value) + + class WorkflowRunStatusField(fields.Raw): def output(self, key, obj: WorkflowRun, **kwargs): - return obj.status.value + return _enum_value(obj.status) class WorkflowRunOutputsField(fields.Raw): def output(self, key, obj: WorkflowRun, **kwargs): - if obj.status == WorkflowExecutionStatus.PAUSED: + status = _enum_value(obj.status) + if status == WorkflowExecutionStatus.PAUSED.value: return {} outputs = obj.outputs_dict return outputs or {} -workflow_run_fields = { - "id": fields.String, - "workflow_id": fields.String, - "status": WorkflowRunStatusField, - "inputs": fields.Raw, - "outputs": WorkflowRunOutputsField, - "error": fields.String, - "total_steps": fields.Integer, - "total_tokens": fields.Integer, - "created_at": TimestampField, - "finished_at": OptionalTimestampField, - "elapsed_time": fields.Float, -} +class WorkflowRunResponse(ResponseModel): + id: str + workflow_id: str + status: str + inputs: dict | list | str | int | float | bool | None = None + outputs: dict = Field(default_factory=dict) + error: str | None = None + total_steps: int | None = None + total_tokens: int | None = None + created_at: int | None = None + finished_at: int | None = None + elapsed_time: float | int | None = None + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) -def build_workflow_run_model(api_or_ns: Namespace): - """Build the workflow run model for the API or Namespace.""" - return api_or_ns.model("WorkflowRun", workflow_run_fields) +class WorkflowRunForLogResponse(ResponseModel): + id: str + version: str | None = None + status: str | None = None + triggered_from: str | None = None + error: str | None = None + elapsed_time: float | int | None = None + total_tokens: int | None = None + total_steps: int | None = None + created_at: int | None = None + finished_at: int | None = None + exceptions_count: int | None = None + + @field_validator("status", "triggered_from", mode="before") + @classmethod + def _normalize_enum(cls, value): + return _enum_value(value) + + @field_validator("created_at", "finished_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + return _to_timestamp(value) + + +class WorkflowAppLogPartialResponse(ResponseModel): + id: str + workflow_run: WorkflowRunForLogResponse | None = None + details: dict | list | str | int | float | bool | None = None + 
created_from: str | None = None
+    created_by_role: str | None = None
+    created_by_account: SimpleAccount | None = None
+    created_by_end_user: SimpleEndUser | None = None
+    created_at: int | None = None
+
+    @field_validator("created_from", "created_by_role", mode="before")
+    @classmethod
+    def _normalize_enum(cls, value):
+        return _enum_value(value)
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class WorkflowAppLogPaginationResponse(ResponseModel):
+    page: int
+    limit: int
+    total: int
+    has_more: bool
+    data: list[WorkflowAppLogPartialResponse]
+
+
+register_schema_models(
+    service_api_ns,
+    WorkflowRunResponse,
+    WorkflowRunForLogResponse,
+    WorkflowAppLogPartialResponse,
+    WorkflowAppLogPaginationResponse,
+)
+
+
+def _serialize_workflow_run(workflow_run: WorkflowRun) -> dict:
+    status = _enum_value(workflow_run.status)
+    raw_outputs = workflow_run.outputs_dict
+    if status == WorkflowExecutionStatus.PAUSED.value or raw_outputs is None:
+        outputs: dict = {}
+    elif isinstance(raw_outputs, dict):
+        outputs = raw_outputs
+    elif isinstance(raw_outputs, Mapping):
+        outputs = dict(raw_outputs)
+    else:
+        outputs = {}
+    return WorkflowRunResponse.model_validate(
+        {
+            "id": workflow_run.id,
+            "workflow_id": workflow_run.workflow_id,
+            "status": status,
+            "inputs": workflow_run.inputs,
+            "outputs": outputs,
+            "error": workflow_run.error,
+            "total_steps": workflow_run.total_steps,
+            "total_tokens": workflow_run.total_tokens,
+            "created_at": workflow_run.created_at,
+            "finished_at": workflow_run.finished_at,
+            "elapsed_time": workflow_run.elapsed_time,
+        }
+    ).model_dump(mode="json")
+
+
+def _serialize_workflow_log_pagination(pagination) -> dict:
+    return WorkflowAppLogPaginationResponse.model_validate(pagination, from_attributes=True).model_dump(mode="json")
 
 
 @service_api_ns.route("/workflows/run/<string:workflow_run_id>")
@@ -112,7 +219,11 @@ class WorkflowRunDetailApi(Resource):
         }
     )
     @validate_app_token
-    @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
+    @service_api_ns.response(
+        200,
+        "Workflow run details retrieved successfully",
+        service_api_ns.models[WorkflowRunResponse.__name__],
+    )
     def get(self, app_model: App, workflow_run_id: str):
         """Get a workflow task running detail.
@@ -133,7 +244,7 @@
         )
         if not workflow_run:
             raise NotFound("Workflow run not found.")
-        return workflow_run
+        return _serialize_workflow_run(workflow_run)
 
 
 @service_api_ns.route("/workflows/run")
@@ -299,7 +410,11 @@ class WorkflowAppLogApi(Resource):
         }
     )
     @validate_app_token
-    @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
+    @service_api_ns.response(
+        200,
+        "Logs retrieved successfully",
+        service_api_ns.models[WorkflowAppLogPaginationResponse.__name__],
+    )
     def get(self, app_model: App):
         """Get workflow app logs.
@@ -327,4 +442,4 @@
             created_by_account=args.created_by_account,
         )
-    return workflow_app_log_pagination
+    return _serialize_workflow_log_pagination(workflow_app_log_pagination)
diff --git a/api/controllers/service_api/app/workflow_events.py b/api/controllers/service_api/app/workflow_events.py
new file mode 100644
index 0000000000..b281b271c0
--- /dev/null
+++ b/api/controllers/service_api/app/workflow_events.py
@@ -0,0 +1,142 @@
+"""
+Service API workflow resume event stream endpoints.
+"""
+
+import json
+from collections.abc import Generator
+
+from flask import Response, request
+from flask_restx import Resource
+from sqlalchemy.orm import sessionmaker
+from werkzeug.exceptions import NotFound
+
+from controllers.service_api import service_api_ns
+from controllers.service_api.app.error import NotWorkflowAppError
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
+from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
+from core.app.apps.base_app_generator import BaseAppGenerator
+from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
+from core.app.apps.message_generator import MessageGenerator
+from core.app.apps.workflow.app_generator import WorkflowAppGenerator
+from core.app.entities.task_entities import StreamEvent
+from core.workflow.human_input_policy import HumanInputSurface
+from extensions.ext_database import db
+from models.enums import CreatorUserRole
+from models.model import App, AppMode, EndUser
+from repositories.factory import DifyAPIRepositoryFactory
+from services.workflow_event_snapshot_service import build_workflow_event_stream
+
+
+@service_api_ns.route("/workflow/<string:task_id>/events")
+class WorkflowEventsApi(Resource):
+    """Service API for getting workflow execution events after resume."""
+
+    @service_api_ns.doc("get_workflow_events")
+    @service_api_ns.doc(description="Get the workflow execution event stream after resume")
+    @service_api_ns.doc(
+        params={
+            "task_id": "Workflow run ID",
+            "user": "End user identifier (query param)",
+            "include_state_snapshot": (
+                "Whether to replay from the persisted state snapshot; "
+                'specify `"true"` to include a status snapshot of executed nodes'
+            ),
+            "continue_on_pause": (
+                "Whether to keep the stream open across workflow_paused events; "
+                'specify `"true"` to continue streaming after `workflow_paused` events.'
+ ), + } + ) + @service_api_ns.doc( + responses={ + 200: "SSE event stream", + 401: "Unauthorized - invalid API token", + 404: "Workflow run not found", + } + ) + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) + def get(self, app_model: App, end_user: EndUser, task_id: str): + app_mode = AppMode.value_of(app_model.mode) + if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}: + raise NotWorkflowAppError() + + session_maker = sessionmaker(db.engine) + repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + workflow_run = repo.get_workflow_run_by_id_and_tenant_id( + tenant_id=app_model.tenant_id, + run_id=task_id, + ) + + if workflow_run is None: + raise NotFound("Workflow run not found") + + if workflow_run.app_id != app_model.id: + raise NotFound("Workflow run not found") + + if workflow_run.created_by_role != CreatorUserRole.END_USER: + raise NotFound("Workflow run not found") + + if workflow_run.created_by != end_user.id: + raise NotFound("Workflow run not found") + + workflow_run_entity = workflow_run + + if workflow_run_entity.finished_at is not None: + response = WorkflowResponseConverter.workflow_run_result_to_finish_response( + task_id=workflow_run_entity.id, + workflow_run=workflow_run_entity, + creator_user=end_user, + ) + + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + + def _generate_finished_events() -> Generator[str, None, None]: + yield f"data: {json.dumps(payload)}\n\n" + + event_generator = _generate_finished_events + else: + msg_generator = MessageGenerator() + generator: BaseAppGenerator + if app_mode == AppMode.ADVANCED_CHAT: + generator = AdvancedChatAppGenerator() + elif app_mode == AppMode.WORKFLOW: + generator = WorkflowAppGenerator() + else: + raise NotWorkflowAppError() + + include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true" + continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true" + terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None + + def _generate_stream_events(): + if include_state_snapshot: + return generator.convert_to_event_stream( + build_workflow_event_stream( + app_mode=app_mode, + workflow_run=workflow_run_entity, + tenant_id=app_model.tenant_id, + app_id=app_model.id, + session_maker=session_maker, + human_input_surface=HumanInputSurface.SERVICE_API, + close_on_pause=not continue_on_pause, + ) + ) + return generator.convert_to_event_stream( + msg_generator.retrieve_events( + app_mode, + workflow_run_entity.id, + terminal_events=terminal_events, + ), + ) + + event_generator = _generate_stream_events + + return Response( + event_generator(), + mimetype="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index fd954be6b1..3eb773fa7c 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,8 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal -from graphon.model_runtime.entities.model_entities import ModelType -from pydantic import BaseModel, Field, TypeAdapter, field_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from werkzeug.exceptions import Forbidden, NotFound import services @@ -19,6 +18,7 @@ from 
core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from fields.dataset_fields import dataset_detail_fields
 from fields.tag_fields import DataSetTag
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.login import current_user
 from models.account import Account
 from models.dataset import DatasetPermissionEnum
@@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel):
 
 class TagUnbindingPayload(BaseModel):
-    tag_id: str
+    """Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
+
+    tag_ids: list[str] = Field(default_factory=list)
+    tag_id: str | None = None
     target_id: str
 
+    @model_validator(mode="before")
+    @classmethod
+    def normalize_legacy_tag_id(cls, data: object) -> object:
+        if not isinstance(data, dict):
+            return data
+        if not data.get("tag_ids") and data.get("tag_id"):
+            return {**data, "tag_ids": [data["tag_id"]]}
+        return data
+
+    @model_validator(mode="after")
+    def validate_tag_ids(self) -> "TagUnbindingPayload":
+        if not self.tag_ids:
+            raise ValueError("tag_ids is required.")
+        return self
+
 
 class DatasetListQuery(BaseModel):
     page: int = Field(default=1, description="Page number")
@@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource):
 @service_api_ns.route("/datasets/tags/unbinding")
 class DatasetTagUnbindingApi(DatasetApiResource):
     @service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
-    @service_api_ns.doc("unbind_dataset_tag")
-    @service_api_ns.doc(description="Unbind a tag from a dataset")
+    @service_api_ns.doc("unbind_dataset_tags")
+    @service_api_ns.doc(description="Unbind tags from a dataset")
     @service_api_ns.doc(
         responses={
-            204: "Tag unbound successfully",
+            204: "Tags unbound successfully",
             401: "Unauthorized - invalid API token",
             403: "Forbidden - insufficient permissions",
         }
@@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
         payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
 
         TagService.delete_tag_binding(
-            TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
+            TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
         )
         return "", 204
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 9f1ce17ed9..0b09facf58 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -1,4 +1,12 @@
+"""Service API endpoints for dataset document management.
+
+The canonical Service API paths use hyphenated route segments. Legacy underscore
+aliases remain registered for backward compatibility, but they must stay marked
+deprecated in generated API docs so clients migrate toward the canonical paths.
+""" + import json +from collections.abc import Mapping from contextlib import ExitStack from typing import Self from uuid import UUID @@ -10,6 +18,7 @@ from sqlalchemy import desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services +from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload from controllers.common.errors import ( FilenameNotExistsError, FileTooLargeError, @@ -100,15 +109,6 @@ class DocumentListQuery(BaseModel): status: str | None = Field(default=None, description="Document status filter") -DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 - - -class DocumentBatchDownloadZipPayload(BaseModel): - """Request payload for bulk downloading uploaded documents as a ZIP archive.""" - - document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS) - - register_enum_models(service_api_ns, RetrievalMethod) register_schema_models( @@ -125,12 +125,137 @@ register_schema_models( ) -@service_api_ns.route( - "/datasets//document/create_by_text", - "/datasets//document/create-by-text", -) +def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]: + """Create a document from text for both canonical and legacy routes.""" + payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {}) + args = payload.model_dump(exclude_none=True) + + dataset_id_str = str(dataset_id) + tenant_id_str = str(tenant_id) + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1) + ) + + if not dataset: + raise ValueError("Dataset does not exist.") + + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") + + embedding_model_provider = payload.embedding_model_provider + embedding_model = payload.embedding_model + if embedding_model_provider and embedding_model: + DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model) + + retrieval_model = payload.retrieval_model + if ( + retrieval_model + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name + ): + DatasetService.check_reranking_model_setting( + tenant_id_str, + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, + ) + + if not current_user: + raise ValueError("current_user is required") + + upload_file = FileService(db.engine).upload_text( + text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str + ) + data_source = { + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, + } + args["data_source"] = data_source + knowledge_config = KnowledgeConfig.model_validate(args) + DocumentService.document_create_args_validate(knowledge_config) + + if not current_user: + raise ValueError("current_user is required") + + try: + documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + knowledge_config=knowledge_config, + account=current_user, + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + document = documents[0] + + documents_and_batch_fields = {"document": marshal(document, 
document_fields), "batch": batch} + return documents_and_batch_fields, 200 + + +def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]: + """Update a document from text for both canonical and legacy routes.""" + payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + ) + args = payload.model_dump(exclude_none=True) + if not dataset: + raise ValueError("Dataset does not exist.") + + retrieval_model = payload.retrieval_model + if ( + retrieval_model + and retrieval_model.reranking_model + and retrieval_model.reranking_model.reranking_provider_name + and retrieval_model.reranking_model.reranking_model_name + ): + DatasetService.check_reranking_model_setting( + tenant_id, + retrieval_model.reranking_model.reranking_provider_name, + retrieval_model.reranking_model.reranking_model_name, + ) + + # indexing_technique is already set in dataset since this is an update + args["indexing_technique"] = dataset.indexing_technique + + if args.get("text"): + text = args.get("text") + name = args.get("name") + if not current_user: + raise ValueError("current_user is required") + upload_file = FileService(db.engine).upload_text( + text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id + ) + data_source = { + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, + } + args["data_source"] = data_source + + args["original_document_id"] = str(document_id) + knowledge_config = KnowledgeConfig.model_validate(args) + DocumentService.document_create_args_validate(knowledge_config) + + try: + documents, batch = DocumentService.save_document_with_dataset_id( + dataset=dataset, + knowledge_config=knowledge_config, + account=current_user, + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + document = documents[0] + + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} + return documents_and_batch_fields, 200 + + +@service_api_ns.route("/datasets//document/create-by-text") class DocumentAddByTextApi(DatasetApiResource): - """Resource for documents.""" + """Resource for the canonical text document creation route.""" @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__]) @service_api_ns.doc("create_document_by_text") @@ -146,81 +271,43 @@ class DocumentAddByTextApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") - def post(self, tenant_id, dataset_id): + def post(self, tenant_id: str, dataset_id: UUID): """Create document by text.""" - payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {}) - args = payload.model_dump(exclude_none=True) + return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id) - dataset_id = str(dataset_id) - tenant_id = str(tenant_id) - dataset = db.session.scalar( - select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + +@service_api_ns.route("/datasets//document/create_by_text") +class 
DeprecatedDocumentAddByTextApi(DatasetApiResource):
+    """Deprecated resource alias for text document creation."""
+
+    @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
+    @service_api_ns.doc("create_document_by_text_deprecated")
+    @service_api_ns.doc(deprecated=True)
+    @service_api_ns.doc(
+        description=(
+            "Deprecated legacy alias for creating a new document by providing text content. "
+            "Use /datasets/{dataset_id}/document/create-by-text instead."
         )
-
-        if not dataset:
-            raise ValueError("Dataset does not exist.")
-
-        if not dataset.indexing_technique and not args["indexing_technique"]:
-            raise ValueError("indexing_technique is required.")
-
-        embedding_model_provider = payload.embedding_model_provider
-        embedding_model = payload.embedding_model
-        if embedding_model_provider and embedding_model:
-            DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
-
-        retrieval_model = payload.retrieval_model
-        if (
-            retrieval_model
-            and retrieval_model.reranking_model
-            and retrieval_model.reranking_model.reranking_provider_name
-            and retrieval_model.reranking_model.reranking_model_name
-        ):
-            DatasetService.check_reranking_model_setting(
-                tenant_id,
-                retrieval_model.reranking_model.reranking_provider_name,
-                retrieval_model.reranking_model.reranking_model_name,
-            )
-
-        if not current_user:
-            raise ValueError("current_user is required")
-
-        upload_file = FileService(db.engine).upload_text(
-            text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
-        )
-        data_source = {
-            "type": "upload_file",
-            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+    )
+    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
+    @service_api_ns.doc(
+        responses={
+            200: "Document created successfully",
+            401: "Unauthorized - invalid API token",
+            400: "Bad request - invalid parameters",
         }
-        args["data_source"] = data_source
-        knowledge_config = KnowledgeConfig.model_validate(args)
-        # validate args
-        DocumentService.document_create_args_validate(knowledge_config)
-
-        if not current_user:
-            raise ValueError("current_user is required")
-
-        try:
-            documents, batch = DocumentService.save_document_with_dataset_id(
-                dataset=dataset,
-                knowledge_config=knowledge_config,
-                account=current_user,
-                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
-                created_from="api",
-            )
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        document = documents[0]
-
-        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
-        return documents_and_batch_fields, 200
+    )
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_resource_check("documents", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
+    def post(self, tenant_id: str, dataset_id: UUID):
+        """Create document by text through the deprecated underscore alias."""
+        return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)


-@service_api_ns.route(
-    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
-    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
-)
+@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
 class DocumentUpdateByTextApi(DatasetApiResource):
-    """Resource for update documents."""
+    """Resource for the canonical text document update route."""

     @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
     @service_api_ns.doc("update_document_by_text")
@@ -237,62 +324,35 @@ class DocumentUpdateByTextApi(DatasetApiResource):
     @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
     def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
         """Update document by text."""
-        payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
-        dataset = db.session.scalar(
-            select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
+        return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
+
+
+@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
+class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
+    """Deprecated resource alias for text document updates."""
+
+    @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
+    @service_api_ns.doc("update_document_by_text_deprecated")
+    @service_api_ns.doc(deprecated=True)
+    @service_api_ns.doc(
+        description=(
+            "Deprecated legacy alias for updating an existing document by providing text content. "
+            "Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
         )
-        args = payload.model_dump(exclude_none=True)
-        if not dataset:
-            raise ValueError("Dataset does not exist.")
-
-        retrieval_model = payload.retrieval_model
-        if (
-            retrieval_model
-            and retrieval_model.reranking_model
-            and retrieval_model.reranking_model.reranking_provider_name
-            and retrieval_model.reranking_model.reranking_model_name
-        ):
-            DatasetService.check_reranking_model_setting(
-                tenant_id,
-                retrieval_model.reranking_model.reranking_provider_name,
-                retrieval_model.reranking_model.reranking_model_name,
-            )
-
-        # indexing_technique is already set in dataset since this is an update
-        args["indexing_technique"] = dataset.indexing_technique
-
-        if args.get("text"):
-            text = args.get("text")
-            name = args.get("name")
-            if not current_user:
-                raise ValueError("current_user is required")
-            upload_file = FileService(db.engine).upload_text(
-                text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
-            )
-            data_source = {
-                "type": "upload_file",
-                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
-            }
-            args["data_source"] = data_source
-        # validate args
-        args["original_document_id"] = str(document_id)
-        knowledge_config = KnowledgeConfig.model_validate(args)
-        DocumentService.document_create_args_validate(knowledge_config)
-
-        try:
-            documents, batch = DocumentService.save_document_with_dataset_id(
-                dataset=dataset,
-                knowledge_config=knowledge_config,
-                account=current_user,
-                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
-                created_from="api",
-            )
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        document = documents[0]
-
-        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
-        return documents_and_batch_fields, 200
+    )
+    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+    @service_api_ns.doc(
+        responses={
+            200: "Document updated successfully",
+            401: "Unauthorized - invalid API token",
+            404: "Document not found",
+        }
+    )
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
+    def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
+        """Update document by text through the deprecated underscore alias."""
+        return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)


 @service_api_ns.route(
@@ -408,15 +468,98 @@ class DocumentAddByFileApi(DatasetApiResource):

         return documents_and_batch_fields, 200


+def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
+    """Update a document from an uploaded file for canonical and deprecated routes."""
+    dataset_id_str = str(dataset_id)
+    tenant_id_str = str(tenant_id)
+    dataset = db.session.scalar(
+        select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
+    )
+
+    if not dataset:
+        raise ValueError("Dataset does not exist.")
+
+    if dataset.provider == "external":
+        raise ValueError("External datasets are not supported.")
+
+    args: dict[str, object] = {}
+    if "data" in request.form:
+        args = json.loads(request.form["data"])
+    if "doc_form" not in args:
+        args["doc_form"] = dataset.chunk_structure or "text_model"
+    if "doc_language" not in args:
+        args["doc_language"] = "English"
+
+    # indexing_technique is already set in dataset since this is an update
+    args["indexing_technique"] = dataset.indexing_technique
+
+    if "file" in request.files:
+        # save file info
+        file = request.files["file"]
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        if not file.filename:
+            raise FilenameNotExistsError
+
+        if not current_user:
+            raise ValueError("current_user is required")
+
+        try:
+            upload_file = FileService(db.engine).upload_file(
+                filename=file.filename,
+                content=file.read(),
+                mimetype=file.mimetype,
+                user=current_user,
+                source="datasets",
+            )
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+        data_source = {
+            "type": "upload_file",
+            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
+        }
+        args["data_source"] = data_source
+
+    # validate args
+    args["original_document_id"] = str(document_id)
+
+    knowledge_config = KnowledgeConfig.model_validate(args)
+    DocumentService.document_create_args_validate(knowledge_config)
+
+    try:
+        documents, _ = DocumentService.save_document_with_dataset_id(
+            dataset=dataset,
+            knowledge_config=knowledge_config,
+            account=dataset.created_by_account,
+            dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+            created_from="api",
+        )
+    except ProviderTokenNotInitError as ex:
+        raise ProviderNotInitializeError(ex.description)
+    document = documents[0]
+    documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
+    return documents_and_batch_fields, 200
+
+
 @service_api_ns.route(
     "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
     "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
 )
-class DocumentUpdateByFileApi(DatasetApiResource):
-    """Resource for update documents."""
+class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
+    """Deprecated resource aliases for file document updates."""

-    @service_api_ns.doc("update_document_by_file")
-    @service_api_ns.doc(description="Update an existing document by uploading a file")
+    @service_api_ns.doc("update_document_by_file_deprecated")
+    @service_api_ns.doc(deprecated=True)
+    @service_api_ns.doc(
+        description=(
+            "Deprecated legacy alias for updating an existing document by uploading a file. "
+            "Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
+        )
+    )
     @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
     @service_api_ns.doc(
         responses={
@@ -427,82 +570,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
     )
     @cloud_edition_billing_resource_check("vector_space", "dataset")
     @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
-    def post(self, tenant_id, dataset_id, document_id):
-        """Update document by upload file."""
-        dataset = db.session.scalar(
-            select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
-        )
-
-        if not dataset:
-            raise ValueError("Dataset does not exist.")
-
-        if dataset.provider == "external":
-            raise ValueError("External datasets are not supported.")
-
-        args = {}
-        if "data" in request.form:
-            args = json.loads(request.form["data"])
-        if "doc_form" not in args:
-            args["doc_form"] = dataset.chunk_structure or "text_model"
-        if "doc_language" not in args:
-            args["doc_language"] = "English"
-
-        # get dataset info
-        dataset_id = str(dataset_id)
-        tenant_id = str(tenant_id)
-
-        # indexing_technique is already set in dataset since this is an update
-        args["indexing_technique"] = dataset.indexing_technique
-
-        if "file" in request.files:
-            # save file info
-            file = request.files["file"]
-
-            if len(request.files) > 1:
-                raise TooManyFilesError()
-
-            if not file.filename:
-                raise FilenameNotExistsError
-
-            if not current_user:
-                raise ValueError("current_user is required")
-
-            try:
-                upload_file = FileService(db.engine).upload_file(
-                    filename=file.filename,
-                    content=file.read(),
-                    mimetype=file.mimetype,
-                    user=current_user,
-                    source="datasets",
-                )
-            except services.errors.file.FileTooLargeError as file_too_large_error:
-                raise FileTooLargeError(file_too_large_error.description)
-            except services.errors.file.UnsupportedFileTypeError:
-                raise UnsupportedFileTypeError()
-            data_source = {
-                "type": "upload_file",
-                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
-            }
-            args["data_source"] = data_source
-        # validate args
-        args["original_document_id"] = str(document_id)
-
-        knowledge_config = KnowledgeConfig.model_validate(args)
-        DocumentService.document_create_args_validate(knowledge_config)
-
-        try:
-            documents, _ = DocumentService.save_document_with_dataset_id(
-                dataset=dataset,
-                knowledge_config=knowledge_config,
-                account=dataset.created_by_account,
-                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
-                created_from="api",
-            )
-        except ProviderTokenNotInitError as ex:
-            raise ProviderNotInitializeError(ex.description)
-        document = documents[0]
-        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
-        return documents_and_batch_fields, 200
+    def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
+        """Update document by file through the deprecated file-update aliases."""
+        return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)


 @service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
@@ -527,7 +597,7 @@ class DocumentListApi(DatasetApiResource):
         if not dataset:
             raise NotFound("Dataset not found.")

-        query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
+        query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id)

         if query_params.status:
             query = DocumentService.apply_display_status_filter(query, query_params.status)
@@ -816,6 +886,22 @@ class DocumentApi(DatasetApiResource):

         return response

+    @service_api_ns.doc("update_document_by_file")
+    @service_api_ns.doc(description="Update an existing document by uploading a file")
+    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
+    @service_api_ns.doc(
+        responses={
+            200: "Document updated successfully",
+            401: "Unauthorized - invalid API token",
+            404: "Document not found",
+        }
+    )
+    @cloud_edition_billing_resource_check("vector_space", "dataset")
+    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
+    def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
+        """Update document by file on the canonical document resource."""
+        return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
+
     @service_api_ns.doc("delete_document")
     @service_api_ns.doc(description="Delete a document")
     @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py
index 52166f7fcc..21db7d0cb8 100644
--- a/api/controllers/service_api/dataset/metadata.py
+++ b/api/controllers/service_api/dataset/metadata.py
@@ -2,9 +2,9 @@ from typing import Literal

 from flask_login import current_user
 from flask_restx import marshal
-from pydantic import BaseModel
 from werkzeug.exceptions import NotFound

+from controllers.common.controller_schemas import MetadataUpdatePayload
 from controllers.common.schema import register_schema_model, register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
@@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
 )
 from services.metadata_service import MetadataService

-
-class MetadataUpdatePayload(BaseModel):
-    name: str
-
-
 register_schema_model(service_api_ns, MetadataUpdatePayload)

 register_schema_models(
     service_api_ns,
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index 5b16da81e0..5992fa7410 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -2,12 +2,12 @@ from typing import Any

 from flask import request
 from flask_restx import marshal
-from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field
 from sqlalchemy import select
 from werkzeug.exceptions import NotFound

 from configs import dify_config
+from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
 from controllers.common.schema import register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import ProviderNotInitializeError
@@ -22,6 +22,7 @@ from core.model_manager import ModelManager
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from extensions.ext_database import db
 from fields.segment_fields import child_chunk_fields, segment_fields
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs.login import current_account_with_tenant
 from models.dataset import Dataset
 from services.dataset_service import DatasetService, DocumentService, SegmentService
@@ -32,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
 from services.summary_index_service import SummaryIndexService


-def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
+def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
     """Marshal a single segment and enrich it with summary content."""
-    segment_dict = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
+    segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
     summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
     segment_dict["summary"] = summary.summary_content if summary else None
     return segment_dict


-def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
+def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
     """Marshal multiple segments and enrich them with summary content (batch query)."""
     segment_ids = [segment.id for segment in segments]
-    summaries: dict = {}
+    summaries: dict[str, str | None] = {}
     if segment_ids:
         summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
         summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}

-    result = []
+    result: list[dict[str, Any]] = []
     for segment in segments:
-        segment_dict = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
+        segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields))  # type: ignore[arg-type]
         segment_dict["summary"] = summaries.get(segment.id)
         result.append(segment_dict)
     return result
@@ -69,20 +70,12 @@ class SegmentUpdatePayload(BaseModel):
     segment: SegmentUpdateArgs


-class ChildChunkCreatePayload(BaseModel):
-    content: str
-
-
 class ChildChunkListQuery(BaseModel):
     limit: int = Field(default=20, ge=1)
     keyword: str | None = None
     page: int = Field(default=1, ge=1)


-class ChildChunkUpdatePayload(BaseModel):
-    content: str
-
-
 register_schema_models(
     service_api_ns,
     SegmentCreatePayload,
diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py
index c0a6cb0a76..5ac65fc4e6 100644
--- a/api/controllers/service_api/workspace/models.py
+++ b/api/controllers/service_api/workspace/models.py
@@ -1,9 +1,9 @@
 from flask_login import current_user
 from flask_restx import Resource
-from graphon.model_runtime.utils.encoders import jsonable_encoder

 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import validate_dataset_token
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from services.model_provider_service import ModelProviderService
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index 0ef4471018..8ddbc3abb8 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -2,7 +2,6 @@ import logging

 from flask import request
 from flask_restx import fields, marshal_with
-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import field_validator
 from werkzeug.exceptions import InternalServerError

@@ -22,8 +21,9 @@ from controllers.web.error import (
 )
 from controllers.web.wraps import WebApiResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
-from models.model import App
+from models.model import App, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -69,12 +69,12 @@ class AudioApi(WebApiResource):
             500: "Internal Server Error",
         }
     )
-    def post(self, app_model: App, end_user):
+    def post(self, app_model: App, end_user: EndUser):
         """Convert audio to text"""
         file = request.files["file"]

         try:
-            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
+            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.external_user_id)

             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -117,7 +117,7 @@ class TextApi(WebApiResource):
             500: "Internal Server Error",
         }
     )
-    def post(self, app_model: App, end_user):
+    def post(self, app_model: App, end_user: EndUser):
         """Convert text to audio"""
         try:
             payload = TextToAudioPayload.model_validate(web_ns.payload or {})
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index e37f9af5f0..0528184d79 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -1,7 +1,6 @@
 import logging
 from typing import Any, Literal

-from graphon.model_runtime.errors.invoke import InvokeError
 from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound

@@ -26,6 +25,7 @@ from core.errors.error import (
     ProviderTokenNotInitError,
     QuotaExceededError,
 )
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from libs.helper import uuid_value
 from models.model import AppMode
diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py
index 80c3289fb4..61fd794c22 100644
--- a/api/controllers/web/forgot_password.py
+++ b/api/controllers/web/forgot_password.py
@@ -3,7 +3,6 @@ import secrets

 from flask import request
 from flask_restx import Resource
-from sqlalchemy.orm import sessionmaker

 from controllers.common.schema import register_schema_models
 from controllers.console.auth.error import (
@@ -62,9 +61,7 @@ class ForgotPasswordSendEmailApi(Resource):
         else:
             language = "en-US"

-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
-            token = None
+        account = AccountService.get_account_by_email_with_case_fallback(request_email)
         if account is None:
             raise AuthenticationFailedError()
         else:
@@ -161,13 +158,14 @@ class ForgotPasswordResetApi(Resource):

         email = reset_data.get("email", "")

-        with sessionmaker(db.engine).begin() as session:
-            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
+        account = AccountService.get_account_by_email_with_case_fallback(email)

-            if account:
-                self._update_existing_account(account, password_hashed, salt)
-            else:
-                raise AuthenticationFailedError()
+        if account:
+            account = db.session.merge(account)
+            self._update_existing_account(account, password_hashed, salt)
+            db.session.commit()
+        else:
+            raise AuthenticationFailedError()

         return {"result": "success"}
diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py
index aff0b42d95..1ddf2e0717 100644
--- a/api/controllers/web/human_input_form.py
+++ b/api/controllers/web/human_input_form.py
@@ -5,14 +5,15 @@ Web App Human Input Form APIs.
 import json
 import logging
 from datetime import datetime
+from typing import Any, NotRequired, TypedDict

 from flask import Response, request
 from flask_restx import Resource
-from pydantic import BaseModel
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden

 from configs import dify_config
+from controllers.common.human_input import HumanInputFormSubmitPayload
 from controllers.web import web_ns
 from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
 from controllers.web.site import serialize_app_site_payload
@@ -25,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ

 logger = logging.getLogger(__name__)


-class HumanInputFormSubmitPayload(BaseModel):
-    inputs: dict
-    action: str
-
-
 _FORM_SUBMIT_RATE_LIMITER = RateLimiter(
     prefix="web_form_submit_rate_limit",
     max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@@ -58,10 +54,19 @@ def _to_timestamp(value: datetime) -> int:
     return int(value.timestamp())


+class FormDefinitionPayload(TypedDict):
+    form_content: Any
+    inputs: Any
+    resolved_default_values: dict[str, str]
+    user_actions: Any
+    expiration_time: int
+    site: NotRequired[dict]
+
+
 def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
     """Return the form payload (optionally with site) as a JSON response."""
     definition_payload = form.get_definition().model_dump()
-    payload = {
+    payload: FormDefinitionPayload = {
         "form_content": definition_payload["rendered_content"],
         "inputs": definition_payload["inputs"],
         "resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
@@ -92,7 +97,7 @@ class HumanInputFormApi(Resource):
             _FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)

         service = HumanInputService(db.engine)
-        # TODO(QuantumGhost): forbid submision for form tokens
+        # TODO(QuantumGhost): forbid submission for form tokens
         # that are only for console.
         form = service.get_form_by_token(form_token)
diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py
index ae0e6789ef..2255dd0332 100644
--- a/api/controllers/web/login.py
+++ b/api/controllers/web/login.py
@@ -1,7 +1,10 @@
+import logging
+
 from flask import make_response, request
 from flask_restx import Resource
 from jwt import InvalidTokenError
 from pydantic import BaseModel, Field, field_validator
+from werkzeug.exceptions import Unauthorized

 import services
 from configs import dify_config
@@ -20,7 +23,7 @@ from controllers.console.wraps import (
 )
 from controllers.web import web_ns
 from controllers.web.wraps import decode_jwt_token
-from libs.helper import EmailStr
+from libs.helper import EmailStr, extract_remote_ip
 from libs.passport import PassportService
 from libs.password import valid_password
 from libs.token import (
@@ -29,9 +32,11 @@ from libs.token import (
 )
 from services.account_service import AccountService
 from services.app_service import AppService
-from services.entities.auth_entities import LoginPayloadBase
+from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
 from services.webapp_auth_service import WebAppAuthService

+logger = logging.getLogger(__name__)
+

 class LoginPayload(LoginPayloadBase):
     @field_validator("password")
@@ -76,14 +81,18 @@ class LoginApi(Resource):
     def post(self):
         """Authenticate user and login."""
         payload = LoginPayload.model_validate(web_ns.payload or {})
+        normalized_email = payload.email.lower()

         try:
             account = WebAppAuthService.authenticate(payload.email, payload.password)
         except services.errors.account.AccountLoginError:
+            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
             raise AccountBannedError()
         except services.errors.account.AccountPasswordError:
+            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
             raise AuthenticationFailedError()
         except services.errors.account.AccountNotFoundError:
+            _log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
             raise AuthenticationFailedError()

         token = WebAppAuthService.login(account=account)
@@ -212,21 +221,30 @@ class EmailCodeLoginApi(Resource):

         token_data = WebAppAuthService.get_email_code_login_data(payload.token)
         if token_data is None:
+            _log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
             raise InvalidTokenError()

         token_email = token_data.get("email")
         if not isinstance(token_email, str):
+            _log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
             raise InvalidEmailError()
         normalized_token_email = token_email.lower()
         if normalized_token_email != user_email:
+            _log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
             raise InvalidEmailError()

         if token_data["code"] != payload.code:
+            _log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
             raise EmailCodeError()

         WebAppAuthService.revoke_email_code_login_token(payload.token)
-        account = WebAppAuthService.get_user_through_email(token_email)
+        try:
+            account = WebAppAuthService.get_user_through_email(token_email)
+        except Unauthorized as exc:
+            _log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
+            raise AccountBannedError() from exc
         if not account:
+            _log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
             raise AuthenticationFailedError()

         token = WebAppAuthService.login(account=account)
@@ -234,3 +252,12 @@ class EmailCodeLoginApi(Resource):
         response = make_response({"result": "success", "data": {"access_token": token}})
         # set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
         return response
+
+
+def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
+    logger.warning(
+        "Web login failed: email=%s reason=%s ip_address=%s",
+        email,
+        reason,
+        extract_remote_ip(request),
+    )
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 25cb6b2b9e..07ecf8035b 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -2,11 +2,10 @@ import logging
 from typing import Literal

 from flask import request
-from graphon.model_runtime.errors.invoke import InvokeError
-from pydantic import BaseModel, Field, TypeAdapter, field_validator
+from pydantic import BaseModel, Field, TypeAdapter
 from werkzeug.exceptions import InternalServerError, NotFound

-from controllers.common.controller_schemas import MessageFeedbackPayload
+from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
 from controllers.common.schema import register_schema_models
 from controllers.web import web_ns
 from controllers.web.error import (
@@ -24,8 +23,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from fields.conversation_fields import ResultResponse
 from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
-from libs.helper import uuid_value
 from models.enums import FeedbackRating
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService
@@ -41,19 +40,6 @@ from services.message_service import MessageService

 logger = logging.getLogger(__name__)


-class MessageListQuery(BaseModel):
-    conversation_id: str = Field(description="Conversation UUID")
-    first_id: str | None = Field(default=None, description="First message ID for pagination")
-    limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
-
-    @field_validator("conversation_id", "first_id")
-    @classmethod
-    def validate_uuid(cls, value: str | None) -> str | None:
-        if value is None:
-            return value
-        return uuid_value(value)
-
-
 class MessageMoreLikeThisQuery(BaseModel):
     response_mode: Literal["blocking", "streaming"] = Field(
         description="Response mode",
diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py
index 66082893b8..0293df74b0 100644
--- a/api/controllers/web/passport.py
+++ b/api/controllers/web/passport.py
@@ -1,5 +1,6 @@
 import uuid
 from datetime import UTC, datetime, timedelta
+from typing import Any

 from flask import make_response, request
 from flask_restx import Resource
@@ -103,21 +104,23 @@ class PassportResource(Resource):
         return response


-def decode_enterprise_webapp_user_id(jwt_token: str | None):
+def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
     """
     Decode the enterprise user session from the Authorization header.
     """
     if not jwt_token:
         return None

-    decoded = PassportService().verify(jwt_token)
+    decoded: dict[str, Any] = PassportService().verify(jwt_token)
     source = decoded.get("token_source")
     if not source or source != "webapp_login_token":
         raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
     return decoded


-def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
+def exchange_token_for_existing_web_user(
+    app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
+):
     """
     Exchange a token for an existing web user session.
     """
diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index 38aeccc642..fe31e9d4ac 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -1,7 +1,6 @@
 import urllib.parse

 import httpx
-from graphon.file import helpers as file_helpers
 from pydantic import BaseModel, Field, HttpUrl

 import services
@@ -14,6 +13,7 @@ from controllers.common.errors import (
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
 from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
+from graphon.file import helpers as file_helpers
 from services.file_service import FileService

 from ..common.schema import register_schema_models
diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py
index 1a0c6d4252..7d2080dd91 100644
--- a/api/controllers/web/site.py
+++ b/api/controllers/web/site.py
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import Any, cast

 from flask_restx import fields, marshal, marshal_with
 from sqlalchemy import select
@@ -113,12 +113,12 @@ class AppSiteInfo:
         }


-def serialize_site(site: Site) -> dict:
+def serialize_site(site: Site) -> dict[str, Any]:
     """Serialize Site model using the same schema as AppSiteApi."""
-    return cast(dict, marshal(site, AppSiteApi.site_fields))
+    return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))


-def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
+def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
     can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
     app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
-    return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
+    return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py
index 796e090976..98211193a0 100644
--- a/api/controllers/web/workflow.py
+++ b/api/controllers/web/workflow.py
@@ -1,7 +1,5 @@
 import logging

-from graphon.graph_engine.manager import GraphEngineManager
-from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.exceptions import InternalServerError

 from controllers.common.controller_schemas import WorkflowRunPayload
@@ -24,6 +22,8 @@ from core.errors.error import (
     QuotaExceededError,
 )
 from extensions.ext_redis import redis_client
+from graphon.graph_engine.manager import GraphEngineManager
+from graphon.model_runtime.errors.invoke import InvokeError
 from libs import helper
 from models.model import App, AppMode, EndUser
 from services.app_generate_service import AppGenerateService
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 06c746990d..c22102c2ba 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -4,20 +4,6 @@ import uuid
 from decimal import Decimal
 from typing import Union, cast

-from graphon.file import file_manager
-from graphon.model_runtime.entities import (
-    AssistantPromptMessage,
-    LLMUsage,
-    PromptMessage,
-    PromptMessageTool,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
-    UserPromptMessage,
-)
-from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
-from graphon.model_runtime.entities.model_entities import ModelFeature
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from sqlalchemy import func, select

 from core.agent.entities import AgentEntity, AgentToolEntity
@@ -43,6 +29,20 @@ from core.tools.tool_manager import ToolManager
 from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from factories import file_factory
+from graphon.file import file_manager
+from graphon.model_runtime.entities import (
+    AssistantPromptMessage,
+    LLMUsage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
+from graphon.model_runtime.entities.model_entities import ModelFeature
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
 from models.enums import CreatorUserRole
 from models.model import Conversation, Message, MessageAgentThought, MessageFile
diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 11e2aa062d..0bc93ad34d 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -2,16 +2,7 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any
-
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from graphon.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessage,
-    PromptMessageTool,
-    ToolPromptMessage,
-    UserPromptMessage,
-)
+from typing import Any, TypedDict

 from core.agent.base_agent_runner import BaseAgentRunner
 from core.agent.entities import AgentScratchpadUnit
@@ -24,11 +15,26 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
 from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from graphon.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
 from models.model import Message

 logger = logging.getLogger(__name__)


+class ActionDict(TypedDict):
+    """Shape produced by AgentScratchpadUnit.Action.to_dict()."""
+
+    action: str
+    action_input: dict[str, Any] | str
+
+
 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
     _ignore_observation_providers = ["wenxin"]
@@ -331,7 +337,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

         return tool_invoke_response, tool_invoke_meta

-    def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
+    def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action:
         """
         convert dict to action
         """
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index 2b2e26987e..a2186be100 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -1,5 +1,6 @@
 import json

+from core.agent.cot_agent_runner import CotAgentRunner
 from graphon.file import file_manager
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
@@ -11,8 +12,6 @@ from graphon.model_runtime.entities import (
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
 from graphon.model_runtime.utils.encoders import jsonable_encoder

-from core.agent.cot_agent_runner import CotAgentRunner
-

 class CotChatAgentRunner(CotAgentRunner):
     def _organize_system_prompt(self) -> SystemPromptMessage:
diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py
index d4c52a8eb1..51a30998ae 100644
--- a/api/core/agent/cot_completion_agent_runner.py
+++ b/api/core/agent/cot_completion_agent_runner.py
@@ -1,5 +1,6 @@
 import json

+from core.agent.cot_agent_runner import CotAgentRunner
 from graphon.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -8,8 +9,6 @@ from graphon.model_runtime.entities.message_entities import (
 )
 from graphon.model_runtime.utils.encoders import jsonable_encoder

-from core.agent.cot_agent_runner import CotAgentRunner
-

 class CotCompletionAgentRunner(CotAgentRunner):
     def _organize_instruction_prompt(self) -> str:
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index fdffde85d0..29de0b8b1c 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -4,6 +4,13 @@ from collections.abc import Generator
 from copy import deepcopy
 from typing import Any, Union

+from core.agent.base_agent_runner import BaseAgentRunner
+from core.agent.errors import AgentMaxIterationError
+from core.app.apps.base_app_queue_manager import PublishFrom
+from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from core.tools.entities.tool_entities import ToolInvokeMeta
+from core.tools.tool_engine import ToolEngine
 from graphon.file import file_manager
 from graphon.model_runtime.entities import (
     AssistantPromptMessage,
@@ -19,14 +26,6 @@ from graphon.model_runtime.entities import (
     UserPromptMessage,
 )
 from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
-
-from core.agent.base_agent_runner import BaseAgentRunner
-from core.agent.errors import AgentMaxIterationError
-from core.app.apps.base_app_queue_manager import PublishFrom
-from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
-from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
-from core.tools.entities.tool_entities import ToolInvokeMeta
-from core.tools.tool_engine import ToolEngine
 from models.model import Message

 logger = logging.getLogger(__name__)

@@ -300,7 +299,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):

             # update prompt tool
             for prompt_tool in prompt_messages_tools:
-                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
+                tool_instance = tool_instances.get(prompt_tool.name)
+                if tool_instance:
+                    self.update_prompt_message_tool(tool_instance, prompt_tool)

             iteration_step += 1
diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py
index 46c1f1230d..f341ca5a0b 100644
--- a/api/core/agent/output_parser/cot_output_parser.py
+++ b/api/core/agent/output_parser/cot_output_parser.py
@@ -1,17 +1,16 @@
 import json
 import re
 from collections.abc import Generator
-from typing import Union
-
-from graphon.model_runtime.entities.llm_entities import LLMResultChunk
+from typing import Any, Union

 from core.agent.entities import AgentScratchpadUnit
+from graphon.model_runtime.entities.llm_entities import LLMResultChunk


 class CotAgentOutputParser:
     @classmethod
     def handle_react_stream_output(
-        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
+        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict[str, Any]
     ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
         def parse_action(action) -> Union[str, AgentScratchpadUnit.Action]:
             action_name = None
diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py
index 90aa7b5fd4..8d25863a91 100644
--- a/api/core/agent/plugin_entities.py
+++ b/api/core/agent/plugin_entities.py
@@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
     identity: AgentStrategyIdentity
     parameters: list[AgentStrategyParameter] = Field(default_factory=list)
     description: I18nObject = Field(..., description="The description of the agent strategy")
-    output_schema: dict | None = None
+    output_schema: dict[str, Any] | None = None
     features: list[AgentFeature] | None = None
     meta_version: str | None = None
     # pydantic configs
diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
index 7d1b11c008..c8ec7cb44d 100644
--- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
+++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py
@@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:

     @classmethod
     def validate_and_set_defaults(
-        cls, tenant_id: str, config: dict, only_structure_validate: bool = False
-    ) -> tuple[dict, list[str]]:
+        cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
+    ) -> tuple[dict[str, Any], list[str]]:
         if not config.get("sensitive_word_avoidance"):
             config["sensitive_word_avoidance"] = {"enabled": False}
diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
index f04a8df119..3d857a4e9c 100644
--- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
@@ -138,7 +138,9 @@ class DatasetConfigManager:
         )

     @classmethod
-    def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(
+        cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
+    ) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for dataset feature

@@ -172,7 +174,7 @@ class DatasetConfigManager:
         return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]

     @classmethod
-    def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
+    def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
         """
         Extract dataset config for legacy compatibility
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index b7dd55632e..5df3df2b3e 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -1,14 +1,13 @@
 from typing import cast

-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
 from core.app.app_config.entities import EasyUIBasedAppConfig
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel


 class ModelConfigConverter:
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
index 5cc385c378..02498c23e1 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py
@@ -1,10 +1,9 @@
 from collections.abc import Mapping
 from typing import Any

-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-
 from core.app.app_config.entities import ModelConfigEntity
 from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from models.model import AppModelConfigDict
 from models.provider_ids import ModelProviderID

@@ -41,7 +40,7 @@ class ModelConfigManager:
         )

     @classmethod
-    def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for model config

@@ -108,7 +107,7 @@ class ModelConfigManager:
         return dict(config), ["model"]

     @classmethod
-    def validate_model_completion_params(cls, cp: dict):
+    def validate_model_completion_params(cls, cp: dict[str, Any]):
         # model.completion_params
         if not isinstance(cp, dict):
             raise ValueError("model.completion_params must be of object type")
diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
index 76196e7034..4c07445df3 100644
--- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
@@ -1,7 +1,5 @@
 from typing import Any

-from graphon.model_runtime.entities.message_entities import PromptMessageRole
-
 from core.app.app_config.entities import (
     AdvancedChatMessageEntity,
     AdvancedChatPromptTemplateEntity,
@@ -9,6 +7,7 @@ from core.app.app_config.entities import (
     PromptTemplateEntity,
 )
 from core.prompt.simple_prompt_transform import ModelMode
+from graphon.model_runtime.entities.message_entities import PromptMessageRole
 from models.model import AppMode, AppModelConfigDict


@@ -65,7 +64,7 @@ class PromptTemplateConfigManager:
     )

     @classmethod
-    def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate pre_prompt and set defaults for prompt feature
         depending on the config['model']

@@ -130,7 +129,7 @@ class PromptTemplateConfigManager:
         return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]

     @classmethod
-    def validate_post_prompt_and_set_defaults(cls, config: dict):
+    def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
         """
         Validate post_prompt and set defaults for prompt feature
diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
index f0b71c5801..ddb500cccf 100644
--- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py
@@ -1,10 +1,9 @@
 import re
-from typing import cast
-
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
+from typing import Any, cast

 from core.app.app_config.entities import ExternalDataVariableEntity
 from core.external_data_tool.factory import ExternalDataToolFactory
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from models.model import AppModelConfigDict

 _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
@@ -82,7 +81,7 @@ class BasicVariablesConfigManager:
         return variable_entities, external_data_variables

     @classmethod
-    def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for user input form

@@ -99,7 +98,7 @@ class BasicVariablesConfigManager:
         return config, related_config_keys

     @classmethod
-    def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+    def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for user input form

@@ -164,7 +163,9 @@ class BasicVariablesConfigManager:
         return config, ["user_input_form"]

     @classmethod
-    def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
+    def validate_external_data_tools_and_set_defaults(
+        cls, tenant_id: str, config: dict[str, Any]
+    ) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for external data fetch feature
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 819aca864c..53563dc5da 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,14 +1,14 @@
 from enum import StrEnum, auto
 from typing import Any, Literal

-from graphon.file import FileUploadConfig
-from graphon.model_runtime.entities.llm_entities import LLMMode
-from graphon.model_runtime.entities.message_entities import PromptMessageRole
-from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity
 from pydantic import BaseModel, Field

 from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
 from core.rag.entities import MetadataFilteringCondition
+from graphon.file import FileUploadConfig
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.message_entities import PromptMessageRole
+from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity
 from models.model import AppMode
diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py
index e96517c426..8f20ef2ff9 100644
--- a/api/core/app/app_config/features/file_upload/manager.py
+++ b/api/core/app/app_config/features/file_upload/manager.py
@@ -1,9 +1,8 @@
 from collections.abc import Mapping
 from typing import Any

-from graphon.file import FileUploadConfig
-
 from constants import DEFAULT_FILE_NUMBER_LIMITS
+from graphon.file import FileUploadConfig


 class FileUploadConfigManager:
@@ -30,7 +29,7 @@ class FileUploadConfigManager:
         return FileUploadConfig.model_validate(file_upload_dict)

     @classmethod
-    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for file upload feature
diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py
index ef71bb348a..b167c04ab5 100644
--- a/api/core/app/app_config/features/more_like_this/manager.py
+++ b/api/core/app/app_config/features/more_like_this/manager.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel, ConfigDict, Field, ValidationError


@@ -13,7 +15,7 @@ class AppConfigModel(BaseModel):

 class MoreLikeThisConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> bool:
+    def convert(cls, config: dict[str, Any]) -> bool:
         """
         Convert model config to model config

@@ -23,7 +25,7 @@ class MoreLikeThisConfigManager:
         return AppConfigModel.model_validate(validated_config).more_like_this.enabled

     @classmethod
-    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         try:
             return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
         except ValidationError:
diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py
index 92b4185abf..33f5aec183 100644
--- a/api/core/app/app_config/features/opening_statement/manager.py
+++ b/api/core/app/app_config/features/opening_statement/manager.py
@@ -1,6 +1,9 @@
+from typing import Any
+
+
 class OpeningStatementConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> tuple[str, list]:
+    def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
         """
         Convert model config to model config

@@ -15,7 +18,7 @@ class OpeningStatementConfigManager:
         return opening_statement, suggested_questions_list

     @classmethod
-    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for opening statement feature
diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py
index d098abac2f..8157fb41db 100644
--- a/api/core/app/app_config/features/retrieval_resource/manager.py
+++ b/api/core/app/app_config/features/retrieval_resource/manager.py
@@ -1,6 +1,9 @@
+from typing import Any
+
+
 class RetrievalResourceConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> bool:
+    def convert(cls, config: dict[str, Any]) -> bool:
         show_retrieve_source = False
         retriever_resource_dict = config.get("retriever_resource")
         if retriever_resource_dict:
@@ -10,7 +13,7 @@ class RetrievalResourceConfigManager:
         return show_retrieve_source

     @classmethod
-    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for retriever resource feature
diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py
index e10ae03e04..679b8c343b 100644
--- a/api/core/app/app_config/features/speech_to_text/manager.py
+++ b/api/core/app/app_config/features/speech_to_text/manager.py
@@ -1,6 +1,9 @@
+from typing import Any
+
+
 class SpeechToTextConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> bool:
+    def convert(cls, config: dict[str, Any]) -> bool:
         """
         Convert model config to model config

@@ -15,7 +18,7 @@ class SpeechToTextConfigManager:
         return speech_to_text

     @classmethod
-    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for speech to text feature
diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
index 9ac5114d12..0c36992c77 100644
--- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
+++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py
@@ -1,6 +1,11 @@
+from typing import Any
+
+CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH = 1000
+
+
 class SuggestedQuestionsAfterAnswerConfigManager:
     @classmethod
-    def convert(cls, config: dict) -> bool:
+    def convert(cls, config: dict[str, Any]) -> bool:
         """
         Convert model config to model config

@@ -15,9 +20,13 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         return suggested_questions_after_answer

     @classmethod
-    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
-        Validate and set defaults for suggested questions feature
+        Validate and set defaults for suggested questions feature.
+
+        Optional fields:
+        - prompt: custom instruction prompt.
+        - model: provider/model configuration for suggested question generation.

         :param config: app model config args
         """
@@ -36,4 +45,27 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
             raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")

+        prompt = config["suggested_questions_after_answer"].get("prompt")
+        if prompt is not None and not isinstance(prompt, str):
+            raise ValueError("prompt in suggested_questions_after_answer must be of string type")
+        if isinstance(prompt, str) and len(prompt) > CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH:
+            raise ValueError(
+                f"prompt in suggested_questions_after_answer must be less than or equal to "
+                f"{CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH} characters"
+            )
+
+        if "model" in config["suggested_questions_after_answer"]:
+            model_config = config["suggested_questions_after_answer"]["model"]
+            if not isinstance(model_config, dict):
+                raise ValueError("model in suggested_questions_after_answer must be of object type")
+
+            if "provider" not in model_config or not isinstance(model_config["provider"], str):
+                raise ValueError("provider in suggested_questions_after_answer.model must be of string type")
+
+            if "name" not in model_config or not isinstance(model_config["name"], str):
+                raise ValueError("name in suggested_questions_after_answer.model must be of string type")
+
+            if "completion_params" in model_config and not isinstance(model_config["completion_params"], dict):
+                raise ValueError("completion_params in suggested_questions_after_answer.model must be of object type")
+
         return config, ["suggested_questions_after_answer"]
diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py
index 1c75981785..ca84ec9c3b 100644
--- a/api/core/app/app_config/features/text_to_speech/manager.py
+++ b/api/core/app/app_config/features/text_to_speech/manager.py
@@ -1,9 +1,11 @@
+from typing import Any
+
 from core.app.app_config.entities import TextToSpeechEntity


 class TextToSpeechConfigManager:
     @classmethod
-    def convert(cls, config: dict):
+    def convert(cls, config: dict[str, Any]):
         """
         Convert model config to model config

@@ -22,7 +24,7 @@ class TextToSpeechConfigManager:
         return text_to_speech

     @classmethod
-    def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
+    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
         """
         Validate and set defaults for text to speech feature
diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
index 62e0c31d1a..13ace32fd6 100644
--- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
+++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py
@@ -1,8 +1,7 @@
 import re

-from graphon.variables.input_entities import VariableEntity
-
 from core.app.app_config.entities import RagPipelineVariableEntity
+from graphon.variables.input_entities import VariableEntity
 from models.workflow import Workflow
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 985ded0f74..b79d5514b4 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -18,11 +18,6 @@ from constants import UUID_NIL
 if TYPE_CHECKING:
     from controllers.console.app.workflow import LoopNodeRunPayload

-from graphon.graph_engine.layers import GraphEngineLayer
-from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
-from graphon.runtime import GraphRuntimeState
-from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
-
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@@ -39,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
-from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
+from core.app.entities.task_entities import (
+    AdvancedChatPausedBlockingResponse,
+    ChatbotAppBlockingResponse,
+    ChatbotAppStreamResponse,
+)
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
 from core.helper.trace_id_helper import extract_external_trace_id_from_args
 from core.ops.ops_trace_manager import TraceQueueManager
@@ -48,6 +47,10 @@ from core.repositories import DifyCoreRepositoryFactory
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
 from extensions.ext_database import db
 from factories import file_factory
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
+from graphon.runtime import GraphRuntimeState
+from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.enums import WorkflowRunTriggeredFrom
@@ -656,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         user: Account | EndUser,
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
-    ) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
+    ) -> (
+        ChatbotAppBlockingResponse
+        | AdvancedChatPausedBlockingResponse
+        | Generator[ChatbotAppStreamResponse, None, None]
+    ):
         """
         Handle response.
         :param application_generate_entity: application generate entity
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index 7b4cb98bd4..4e57b4dedc 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -3,12 +3,6 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast

-from graphon.enums import WorkflowType
-from graphon.graph_engine.command_channels import RedisChannel
-from graphon.graph_engine.layers import GraphEngineLayer
-from graphon.runtime import GraphRuntimeState, VariablePool
-from graphon.variable_loader import VariableLoader
-from graphon.variables.variables import Variable
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

@@ -43,6 +37,12 @@ from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.otel import WorkflowAppRunnerHandler, trace_span
+from graphon.enums import WorkflowType
+from graphon.graph_engine.command_channels import RedisChannel
+from graphon.graph_engine.layers import GraphEngineLayer
+from graphon.runtime import GraphRuntimeState, VariablePool
+from graphon.variable_loader import VariableLoader
+from graphon.variables.variables import Variable
 from models import Workflow
 from models.model import App, Conversation, Message, MessageAnnotation
 from models.workflow import ConversationVariable
diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py
index 5c9bc43992..7cb0c9a8d3 100644
--- a/api/core/app/apps/advanced_chat/generate_response_converter.py
+++ b/api/core/app/apps/advanced_chat/generate_response_converter.py
@@ -3,7 +3,7 @@ from typing import Any, cast

 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
-    AppBlockingResponse,
+    AdvancedChatPausedBlockingResponse,
     AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
@@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
     NodeFinishStreamResponse,
     NodeStartStreamResponse,
     PingStreamResponse,
+    StreamEvent,
 )


-class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
-    _blocking_response_type = ChatbotAppBlockingResponse
-
+class AdvancedChatAppGenerateResponseConverter(
+    AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
+):
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
+    def convert_blocking_full_response(
+        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
+    ) -> dict[str, Any]:
         """
         Convert blocking full response.
:param blocking_response: blocking response :return: """ - blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) + if isinstance(blocking_response, AdvancedChatPausedBlockingResponse): + paused_data = blocking_response.data.model_dump(mode="json") + return { + "event": StreamEvent.WORKFLOW_PAUSED.value, + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, + "workflow_run_id": blocking_response.data.workflow_run_id, + "data": paused_data, + } + response = { - "event": "message", + "event": StreamEvent.MESSAGE.value, "task_id": blocking_response.task_id, "id": blocking_response.data.id, "message_id": blocking_response.data.message_id, @@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + def convert_blocking_simple_response( + cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking simple response. :param blocking_response: blocking response @@ -50,14 +70,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): response = cls.convert_blocking_full_response(blocking_response) metadata = response.get("metadata", {}) - response["metadata"] = cls._get_simple_metadata(metadata) + if isinstance(metadata, dict): + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, Any, None]: + ) -> Generator[dict[str, Any] | str, Any, None]: """ Convert stream full response. :param stream_response: stream response @@ -88,7 +109,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, Any, None]: + ) -> Generator[dict[str, Any] | str, Any, None]: """ Convert stream simple response. 
:param stream_response: stream response diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 0ce9ddce9e..82dbf5381d 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -9,12 +9,6 @@ from datetime import datetime from threading import Thread from typing import Any, Union -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -59,14 +53,18 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, ErrorStreamResponse, + HumanInputRequiredPauseReasonPayload, + HumanInputRequiredResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, PingStreamResponse, StreamResponse, + WorkflowPauseStreamResponse, WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -77,6 +75,12 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp from core.workflow.file_reference import resolve_file_record_id from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus @@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): if message.status == MessageStatus.PAUSED and message.answer: self._task_state.answer = message.answer - def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: + def process( + self, + ) -> Union[ + ChatbotAppBlockingResponse, + AdvancedChatPausedBlockingResponse, + Generator[ChatbotAppStreamResponse, None, None], + ]: """ Process generate task pipeline. :return: @@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]: """ Process blocking response. 
:return: """ + human_input_responses: list[HumanInputRequiredResponse] = [] for stream_response in generator: if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err + elif isinstance(stream_response, HumanInputRequiredResponse): + human_input_responses.append(stream_response) + elif isinstance(stream_response, WorkflowPauseStreamResponse): + return AdvancedChatPausedBlockingResponse( + task_id=stream_response.task_id, + data=AdvancedChatPausedBlockingResponse.Data( + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, + workflow_run_id=stream_response.data.workflow_run_id, + answer=self._task_state.answer, + metadata=self._message_end_to_stream_response().metadata, + created_at=self._message_created_at, + paused_nodes=stream_response.data.paused_nodes, + reasons=stream_response.data.reasons, + status=stream_response.data.status, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + ), + ) elif isinstance(stream_response, MessageEndStreamResponse): extras = {} if stream_response.metadata: @@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: continue + if human_input_responses: + return self._build_paused_blocking_response_from_human_input(human_input_responses) + raise ValueError("queue listening stopped unexpectedly.") + def _build_paused_blocking_response_from_human_input( + self, human_input_responses: list[HumanInputRequiredResponse] + ) -> AdvancedChatPausedBlockingResponse: + runtime_state = self._resolve_graph_runtime_state() + paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses)) + reasons = [ + HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json") + for response in human_input_responses + ] + + return AdvancedChatPausedBlockingResponse( + task_id=self._application_generate_entity.task_id, + data=AdvancedChatPausedBlockingResponse.Data( + id=self._message_id, + mode=self._conversation_mode, + conversation_id=self._conversation_id, + message_id=self._message_id, + workflow_run_id=human_input_responses[-1].workflow_run_id, + answer=self._task_state.answer, + metadata=self._message_end_to_stream_response().metadata, + created_at=self._message_created_at, + paused_nodes=paused_nodes, + reasons=reasons, + status=WorkflowExecutionStatus.PAUSED, + elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at, + total_tokens=runtime_state.total_tokens, + total_steps=runtime_state.node_run_steps, + ), + ) + def _to_stream_response( self, generator: Generator[StreamResponse, None, None] ) -> Generator[ChatbotAppStreamResponse, Any, None]: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 5872f6b264..5cdc477028 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, In from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import 
file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index a20d3f3c38..cae0eee0df 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,9 +1,6 @@ import logging from typing import cast -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -19,6 +16,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 0c146c388f..03bc0a9108 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast + +from pydantic import JsonValue from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -12,11 +14,9 @@ from core.app.entities.task_entities import ( ) -class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = ChatbotAppBlockingResponse - +class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]): @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking full response. :param blocking_response: blocking response :return: @@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking simple response. :param blocking_response: blocking response @@ -56,7 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. 
:param stream_response: stream response @@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -87,7 +87,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 6e5a86505c..abcbb2f943 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -1,19 +1,22 @@ import logging from abc import ABC, abstractmethod from collections.abc import Generator, Mapping -from typing import Any, Union +from typing import Any, Union, cast -from graphon.model_runtime.errors.invoke import InvokeError +from pydantic import JsonValue from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) -class AppGenerateResponseConverter(ABC): - _blocking_response_type: type[AppBlockingResponse] +class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC): + @classmethod + def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse: + return cast(TBlockingResponse, response) @classmethod def convert( @@ -21,45 +24,45 @@ class AppGenerateResponseConverter(ABC): ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if isinstance(response, AppBlockingResponse): - return cls.convert_blocking_full_response(response) + return cls.convert_blocking_full_response(cls._cast_blocking_response(response)) else: - def _generate_full_response() -> Generator[dict | str, Any, None]: + def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]: yield from cls.convert_stream_full_response(response) return _generate_full_response() else: if isinstance(response, AppBlockingResponse): - return cls.convert_blocking_simple_response(response) + return cls.convert_blocking_simple_response(cls._cast_blocking_response(response)) else: - def _generate_simple_response() -> Generator[dict | str, Any, None]: + def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]: yield from cls.convert_stream_simple_response(response) return _generate_simple_response() @classmethod @abstractmethod - def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> 
dict[str, Any]: raise NotImplementedError @classmethod @abstractmethod - def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]: + def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]: raise NotImplementedError @classmethod @abstractmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: raise NotImplementedError @classmethod @abstractmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: raise NotImplementedError @classmethod @@ -107,13 +110,13 @@ class AppGenerateResponseConverter(ABC): return metadata @classmethod - def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]: + def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]: """ Error to stream response. :param e: exception :return: """ - error_responses: dict[type[Exception], dict[str, Any]] = { + error_responses: dict[type[Exception], dict[str, JsonValue]] = { ValueError: {"code": "invalid_param", "status": 400}, ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, QuotaExceededError: { @@ -127,7 +130,7 @@ class AppGenerateResponseConverter(ABC): } # Determine the response based on the type of exception - data: dict[str, Any] | None = None + data: dict[str, JsonValue] | None = None for k, v in error_responses.items(): if isinstance(e, k): data = v diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 7eccd59d17..8e8ccf2b90 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -2,9 +2,6 @@ from collections.abc import Generator, Mapping, Sequence from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING, Any, Union, final -from graphon.enums import NodeType -from graphon.file import File, FileUploadConfig -from graphon.variables.input_entities import VariableEntityType from sqlalchemy.orm import Session from core.app.apps.draft_variable_saver import ( @@ -16,6 +13,9 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope from extensions.ext_database import db from factories import file_factory +from graphon.enums import NodeType +from graphon.file import File, FileUploadConfig +from graphon.variables.input_entities import VariableEntityType from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 20bf81aeec..d1771452c5 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -7,7 +7,6 @@ from enum import IntEnum, auto from typing import Any from cachetools import TTLCache, cachedmethod -from graphon.runtime import GraphRuntimeState from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta @@ -22,6 +21,7 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from extensions.ext_redis import redis_client +from graphon.runtime import GraphRuntimeState 
logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 4aebc0cb30..1251b397e2 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,17 +5,6 @@ from collections.abc import Generator, Mapping, Sequence from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError - from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( @@ -41,6 +30,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager from extensions.ext_database import db +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 891dcece73..58afefe296 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, copy_current_request_context, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeF from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account from models.model import App, EndUser from services.conversation_service import ConversationService diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 050f763e95..077c5239f3 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,8 +1,6 @@ import logging from typing import cast -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -18,6 +16,8 @@ from core.model_manager import ModelInstance from 
core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Conversation, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index f23ee7f89f..26efcbfafd 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast + +from pydantic import JsonValue from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -12,11 +14,9 @@ from core.app.entities.task_entities import ( ) -class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = ChatbotAppBlockingResponse - +class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]): @classmethod - def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking full response. :param blocking_response: blocking response @@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking simple response. :param blocking_response: blocking response @@ -56,7 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, @@ -87,7 +87,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. 
:param stream_response: stream response @@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "conversation_id": chunk.conversation_id, "message_id": chunk.message_id, diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index ab277857fe..2a90fbdad0 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,9 +4,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from graphon.runtime import GraphRuntimeState - from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.runtime import GraphRuntimeState if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index a515531616..7bab3f7bff 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -6,19 +6,6 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, TypedDict, Union -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import FILE_MODEL_IDENTITY, File -from graphon.runtime import GraphRuntimeState -from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment -from graphon.variables.variables import Variable -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from sqlalchemy.orm import Session @@ -65,9 +52,23 @@ from core.tools.tool_manager import ToolManager from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.human_input import HumanInputForm @@ -336,7 +337,26 @@ class WorkflowResponseConverter: except (TypeError, json.JSONDecodeError): definition_payload = {} display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) - form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session) + 
form_token_by_form_id = load_form_tokens_by_form_id( + human_input_form_ids, + session=session, + surface=( + HumanInputSurface.SERVICE_API + if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API + else None + ), + ) + + # Reconnect paths must preserve the same pause-reason contract as live streams; + # otherwise clients see schema drift after resume. + pause_reasons = enrich_human_input_pause_reasons( + pause_reasons, + form_tokens_by_form_id=form_token_by_form_id, + expiration_times_by_form_id={ + form_id: int(expiration_time.timestamp()) + for form_id, expiration_time in expiration_times_by_form_id.items() + }, + ) responses: list[StreamResponse] = [] diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 61339b316a..423bfdac51 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -6,7 +6,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, overload from flask import Flask, copy_current_request_context, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from sqlalchemy import select @@ -24,6 +23,7 @@ from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, I from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models import Account, App, EndUser, Message from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index b216f7cf7b..6bb1ecdcb1 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,8 +1,6 @@ import logging from typing import cast -from graphon.file import File -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -16,6 +14,8 @@ from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.model import App, Message logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index a4f574642d..ad978f58e0 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast + +from pydantic import JsonValue from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -12,17 +14,15 @@ from core.app.entities.task_entities import ( ) -class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = CompletionAppBlockingResponse - +class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]): @classmethod - def 
convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): """ Convert blocking full response. :param blocking_response: blocking response :return: """ - response = { + response: dict[str, Any] = { "event": "message", "task_id": blocking_response.task_id, "id": blocking_response.data.id, @@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): return response @classmethod - def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): """ Convert blocking simple response. :param blocking_response: blocking response @@ -55,7 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "message_id": chunk.message_id, "created_at": chunk.created_at, @@ -85,7 +85,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. 
:param stream_response: stream response @@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): yield "ping" continue - response_chunk = { + response_chunk: dict[str, JsonValue] = { "event": sub_stream_response.event.value, "message_id": chunk.message_id, "created_at": chunk.created_at, diff --git a/api/core/app/apps/message_generator.py b/api/core/app/apps/message_generator.py index 68631bb230..c04f20c796 100644 --- a/api/core/app/apps/message_generator.py +++ b/api/core/app/apps/message_generator.py @@ -1,6 +1,7 @@ -from collections.abc import Callable, Generator, Mapping +from collections.abc import Callable, Generator, Iterable, Mapping from core.app.apps.streaming_utils import stream_topic_events +from core.app.entities.task_entities import StreamEvent from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic from models.model import AppMode @@ -26,6 +27,7 @@ class MessageGenerator: idle_timeout=300, ping_interval: float = 10.0, on_subscribe: Callable[[], None] | None = None, + terminal_events: Iterable[str | StreamEvent] | None = None, ) -> Generator[Mapping | str, None, None]: topic = cls.get_response_topic(app_mode, workflow_run_id) return stream_topic_events( @@ -33,4 +35,5 @@ class MessageGenerator: idle_timeout=idle_timeout, ping_interval=ping_interval, on_subscribe=on_subscribe, + terminal_events=terminal_events, ) diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py index cfacd8640d..3913657ae8 100644 --- a/api/core/app/apps/pipeline/generate_response_converter.py +++ b/api/core/app/apps/pipeline/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -13,11 +13,9 @@ from core.app.entities.task_entities import ( ) -class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = WorkflowAppBlockingResponse - +class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]): @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: """ Convert blocking full response. :param blocking_response: blocking response @@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): return dict(blocking_response.model_dump()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: """ Convert blocking simple response. :param blocking_response: blocking response @@ -37,7 +35,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. 
:param stream_response: stream response @@ -66,7 +64,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. :param stream_response: stream response diff --git a/api/core/app/apps/pipeline/pipeline_config_manager.py b/api/core/app/apps/pipeline/pipeline_config_manager.py index 72b7f4bef6..8bbd745538 100644 --- a/api/core/app/apps/pipeline/pipeline_config_manager.py +++ b/api/core/app/apps/pipeline/pipeline_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig @@ -34,7 +36,9 @@ class PipelineConfigManager(BaseAppConfigManager): return pipeline_config @classmethod - def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: + def config_validate( + cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False + ) -> dict[str, Any]: """ Validate for pipeline config diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 139c7e73e0..4a76d0809e 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -10,8 +10,6 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, cast, overload from flask import Flask, current_app -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -29,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity -from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.app.entities.task_entities import ( + WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, + WorkflowAppStreamResponse, +) from core.datasource.entities.datasource_entities import ( DatasourceProviderType, OnlineDriveBrowseFilesRequest, @@ -43,6 +45,8 @@ from core.repositories.factory import ( WorkflowNodeExecutionRepository, ) from extensions.ext_database import db +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline @@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator): user: Account | EndUser, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, - ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]: + ) -> ( + 
WorkflowAppBlockingResponse + | WorkflowAppPausedBlockingResponse + | Generator[WorkflowAppStreamResponse, None, None] + ): """ Handle response. :param application_generate_entity: application generate entity @@ -782,7 +790,7 @@ class PipelineGenerator(BaseAppGenerator): user_id: str, all_files: list, datasource_info: Mapping[str, Any], - next_page_parameters: dict | None = None, + next_page_parameters: dict[str, Any] | None = None, ): """ Get files in a folder. diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 36daaf09e9..2ee0ae27eb 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -2,12 +2,6 @@ import logging import time from typing import cast -from graphon.enums import WorkflowType -from graphon.graph import Graph -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader -from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -26,6 +20,12 @@ from core.workflow.system_variables import build_bootstrap_variables, build_syst from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from models.dataset import Document, Pipeline from models.model import EndUser from models.workflow import Workflow diff --git a/api/core/app/apps/streaming_utils.py b/api/core/app/apps/streaming_utils.py index af3441aca3..5743bad4b6 100644 --- a/api/core/app/apps/streaming_utils.py +++ b/api/core/app/apps/streaming_utils.py @@ -59,7 +59,7 @@ def stream_topic_events( def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]: - if not terminal_events: + if terminal_events is None: return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value} values: set[str] = set() for item in terminal_events: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6074e81d1e..e811c2b2e0 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,10 +8,6 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal, overload from flask import Flask, current_app -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError -from graphon.runtime import GraphRuntimeState -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -29,7 +25,11 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.apps.workflow.generate_task_pipeline import 
WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse +from core.app.entities.task_entities import ( + WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, + WorkflowAppStreamResponse, +) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args @@ -38,6 +38,10 @@ from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from libs.flask_utils import preserve_flask_contexts from models.account import Account from models.enums import WorkflowRunTriggeredFrom @@ -612,7 +616,11 @@ class WorkflowAppGenerator(BaseAppGenerator): user: Account | EndUser, draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, - ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]: + ) -> ( + WorkflowAppBlockingResponse + | WorkflowAppPausedBlockingResponse + | Generator[WorkflowAppStreamResponse, None, None] + ): """ Handle response. :param application_generate_entity: application generate entity diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 2cb8088971..cfb9208486 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -3,12 +3,6 @@ import time from collections.abc import Sequence from typing import cast -from graphon.enums import WorkflowType -from graphon.graph_engine.command_channels import RedisChannel -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import VariableLoader - from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -21,6 +15,11 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader from libs.datetime_utils import naive_utc_now from models.workflow import Workflow diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index c64f44a603..4037388798 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Any, cast from 
core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -9,24 +9,29 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, WorkflowAppStreamResponse, ) -class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): - _blocking_response_type = WorkflowAppBlockingResponse - +class WorkflowAppGenerateResponseConverter( + AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse] +): @classmethod - def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] + def convert_blocking_full_response( + cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking full response. :param blocking_response: blocking response :return: """ - return blocking_response.model_dump() + return dict(blocking_response.model_dump()) @classmethod - def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override] + def convert_blocking_simple_response( + cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse + ) -> dict[str, Any]: """ Convert blocking simple response. :param blocking_response: blocking response @@ -37,7 +42,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -66,7 +71,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): @classmethod def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] - ) -> Generator[dict | str, None, None]: + ) -> Generator[dict[str, Any] | str, None, None]: """ Convert stream simple response. 
:param stream_response: stream response diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index f1b8b08eaa..87d9b73078 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,9 +4,6 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Union -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState from sqlalchemy.orm import Session, sessionmaker from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -45,12 +42,15 @@ from core.app.entities.queue_entities import ( ) from core.app.entities.task_entities import ( ErrorStreamResponse, + HumanInputRequiredPauseReasonPayload, + HumanInputRequiredResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, PingStreamResponse, StreamResponse, TextChunkStreamResponse, WorkflowAppBlockingResponse, + WorkflowAppPausedBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, WorkflowPauseStreamResponse, @@ -61,6 +61,9 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from models import Account from models.enums import CreatorUserRole from models.model import EndUser @@ -118,7 +121,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state - def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: + def process( + self, + ) -> Union[ + WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None] + ]: """ Process generate task pipeline. :return: @@ -129,19 +136,24 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]: """ To blocking response. 
:return: """ + human_input_responses: list[HumanInputRequiredResponse] = [] for stream_response in generator: if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err + elif isinstance(stream_response, HumanInputRequiredResponse): + human_input_responses.append(stream_response) elif isinstance(stream_response, WorkflowPauseStreamResponse): - response = WorkflowAppBlockingResponse( + return WorkflowAppPausedBlockingResponse( task_id=self._application_generate_entity.task_id, workflow_run_id=stream_response.data.workflow_run_id, - data=WorkflowAppBlockingResponse.Data( + data=WorkflowAppPausedBlockingResponse.Data( id=stream_response.data.workflow_run_id, workflow_id=self._workflow.id, status=stream_response.data.status, @@ -152,12 +164,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): total_steps=stream_response.data.total_steps, created_at=stream_response.data.created_at, finished_at=None, + paused_nodes=stream_response.data.paused_nodes, + reasons=stream_response.data.reasons, ), ) - return response elif isinstance(stream_response, WorkflowFinishStreamResponse): - response = WorkflowAppBlockingResponse( + return WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, workflow_run_id=stream_response.data.id, data=WorkflowAppBlockingResponse.Data( @@ -174,12 +187,44 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): ), ) - return response else: continue + if human_input_responses: + return self._build_paused_blocking_response_from_human_input(human_input_responses) + raise ValueError("queue listening stopped unexpectedly.") + def _build_paused_blocking_response_from_human_input( + self, human_input_responses: list[HumanInputRequiredResponse] + ) -> WorkflowAppPausedBlockingResponse: + runtime_state = self._resolve_graph_runtime_state() + paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses)) + created_at = int(runtime_state.start_at) + reasons = [ + HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json") + for response in human_input_responses + ] + + return WorkflowAppPausedBlockingResponse( + task_id=self._application_generate_entity.task_id, + workflow_run_id=human_input_responses[-1].workflow_run_id, + data=WorkflowAppPausedBlockingResponse.Data( + id=human_input_responses[-1].workflow_run_id, + workflow_id=self._workflow.id, + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + error=None, + elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at, + total_tokens=runtime_state.total_tokens, + total_steps=runtime_state.node_run_steps, + created_at=created_at, + finished_at=None, + paused_nodes=paused_nodes, + reasons=reasons, + ), + ) + def _to_stream_response( self, generator: Generator[StreamResponse, None, None] ) -> Generator[WorkflowAppStreamResponse, None, None]: @@ -682,15 +727,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None): invoke_from = self._application_generate_entity.invoke_from - if invoke_from == InvokeFrom.SERVICE_API: - created_from = WorkflowAppLogCreatedFrom.SERVICE_API - elif invoke_from == InvokeFrom.EXPLORE: - created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP - elif invoke_from == InvokeFrom.WEB_APP: - created_from = WorkflowAppLogCreatedFrom.WEB_APP - else: - # not save log for debugging - return + match invoke_from: + case InvokeFrom.SERVICE_API: + created_from = 
WorkflowAppLogCreatedFrom.SERVICE_API + case InvokeFrom.EXPLORE: + created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP + case InvokeFrom.WEB_APP: + created_from = WorkflowAppLogCreatedFrom.WEB_APP + case InvokeFrom.DEBUGGER | InvokeFrom.TRIGGER | InvokeFrom.PUBLISHED_PIPELINE | InvokeFrom.VALIDATION: + # not save log for debugging + return if not workflow_run_id: return diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 437432611d..047b54c86c 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,39 +3,6 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph import Graph -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from pydantic import ValidationError from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -82,6 +49,39 @@ from core.workflow.system_variables import ( from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunIterationFailedEvent, + NodeRunIterationNextEvent, + NodeRunIterationStartedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunLoopNextEvent, + NodeRunLoopStartedEvent, + NodeRunLoopSucceededEvent, + NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index a3fb7b4c5d..09992f4bbf 100644 --- 
a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -2,13 +2,13 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import TYPE_CHECKING, Any -from graphon.file import File, FileUploadConfig -from graphon.model_runtime.entities.model_entities import AIModelEntity from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 482f995d8e..221b7fb058 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,14 +3,14 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import PauseReason -from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities import RetrievalSourceMetadata +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReason +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 62df85b13f..ad05566521 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,15 +1,16 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any +from typing import Any, Literal -from graphon.entities import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.nodes.human_input.entities import FormInput, UserAction -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, JsonValue from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities import RetrievalSourceMetadata +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): @@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse): data: Data +class HumanInputRequiredPauseReasonPayload(BaseModel): + """ + Public pause-reason payload used by blocking responses when only + ``human_input_required`` events are available. 
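+ Field values mirror ``HumanInputRequiredResponse.Data`` one-to-one (see ``from_response_data``), and instances are serialized into the paused blocking response's ``reasons`` list via ``model_dump(mode="json")``.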
+ """ + + TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED + form_id: str + node_id: str + node_title: str + form_content: str + inputs: Sequence[FormInput] = Field(default_factory=list) + actions: Sequence[UserAction] = Field(default_factory=list) + display_in_ui: bool = False + form_token: str | None = None + resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) + expiration_time: int + + @classmethod + def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload": + return cls( + form_id=data.form_id, + node_id=data.node_id, + node_title=data.node_title, + form_content=data.form_content, + inputs=data.inputs, + actions=data.actions, + display_in_ui=data.display_in_ui, + form_token=data.form_token, + resolved_default_values=data.resolved_default_values, + expiration_time=data.expiration_time, + ) + + class HumanInputFormFilledResponse(StreamResponse): class Data(BaseModel): """ @@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse): workflow_run_id: str data: Data - def to_ignore_detail_dict(self): + def to_ignore_detail_dict(self) -> dict[str, JsonValue]: return { "event": self.event.value, "task_id": self.task_id, @@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse): workflow_run_id: str data: Data - def to_ignore_detail_dict(self): + def to_ignore_detail_dict(self) -> dict[str, JsonValue]: return { "event": self.event.value, "task_id": self.task_id, @@ -521,7 +556,7 @@ class IterationNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = Field(default_factory=dict) + extras: dict[str, Any] = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} inputs_truncated: bool = False @@ -547,7 +582,7 @@ class IterationNodeNextStreamResponse(StreamResponse): title: str index: int created_at: int - extras: dict = Field(default_factory=dict) + extras: dict[str, Any] = Field(default_factory=dict) event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -571,7 +606,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse): outputs: Mapping | None = None outputs_truncated: bool = False created_at: int - extras: dict | None = None + extras: dict[str, Any] | None = None inputs: Mapping | None = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus @@ -602,7 +637,7 @@ class LoopNodeStartStreamResponse(StreamResponse): node_type: str title: str created_at: int - extras: dict = Field(default_factory=dict) + extras: dict[str, Any] = Field(default_factory=dict) metadata: Mapping = {} inputs: Mapping = {} inputs_truncated: bool = False @@ -653,7 +688,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse): outputs: Mapping | None = None outputs_truncated: bool = False created_at: int - extras: dict | None = None + extras: dict[str, Any] | None = None inputs: Mapping | None = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus @@ -774,6 +809,34 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): data: Data +class AdvancedChatPausedBlockingResponse(AppBlockingResponse): + """ + ChatbotAppPausedBlockingResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + mode: str + conversation_id: str + message_id: str + workflow_run_id: str + answer: str + metadata: Mapping[str, object] = Field(default_factory=dict) + created_at: int + paused_nodes: Sequence[str] = Field(default_factory=list) + reasons: Sequence[Mapping[str, 
Any]] = Field(default_factory=list) + status: WorkflowExecutionStatus + elapsed_time: float + total_tokens: int + total_steps: int + + data: Data + + class CompletionAppBlockingResponse(AppBlockingResponse): """ CompletionAppBlockingResponse entity @@ -819,6 +882,33 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): data: Data +class WorkflowAppPausedBlockingResponse(AppBlockingResponse): + """ + WorkflowAppPausedBlockingResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + workflow_id: str + status: WorkflowExecutionStatus + outputs: Mapping[str, Any] | None = None + error: str | None = None + elapsed_time: float + total_tokens: int + total_steps: int + created_at: int + finished_at: int | None + paused_nodes: Sequence[str] = Field(default_factory=list) + reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list) + + workflow_run_id: str + data: Data + + class AgentLogStreamResponse(StreamResponse): """ AgentLogStreamResponse entity diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py index d2d2fea4fb..d59f5125e3 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,9 +1,8 @@ import logging -from graphon.model_runtime.entities.message_entities import PromptMessage - from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation +from graphon.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/file_access/scope.py index 80d504ef1c..a583301f9b 100644 --- a/api/core/app/file_access/scope.py +++ b/api/core/app/file_access/scope.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Generator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass @@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None: @contextmanager -def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]: +def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: token = _current_file_access_scope.set(scope) try: yield diff --git a/api/core/app/layers/conversation_variable_persist_layer.py index e09869f5f8..d5e6b04a4a 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -9,11 +9,10 @@ scope updates that matter to chat applications. 
import logging -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent - from core.workflow.system_variables import SystemVariableKey, get_system_text from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index c027f42788..9811f9f830 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,14 +1,14 @@ from dataclasses import dataclass from typing import Annotated, Literal, Self -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index 8c8daf8712..bb9fc1b6fa 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -3,10 +3,10 @@ import uuid from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore + from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent - from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index 77c7bec67e..b60fe82ffe 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -2,12 +2,12 @@ import logging from datetime import UTC, datetime from typing import Any, ClassVar -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from pydantic import TypeAdapter from core.db.session_factory import session_factory from core.workflow.system_variables import SystemVariableKey, get_system_text +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index 278d0cb30b..5631caa1a5 100644 --- 
a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -1,22 +1,35 @@ from __future__ import annotations +from copy import deepcopy from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.nodes.llm.entities import ModelConfig -from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from graphon.nodes.llm.protocols import CredentialsProvider - from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from graphon.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: + """Resolves and returns LLM credentials for a given provider and model. + + Fetched credentials are stored in :attr:`credentials_cache` and reused for + subsequent ``fetch`` calls for the same ``(provider_name, model_name)``. + Because of that cache, a single instance can return stale credentials after + the tenant or provider configuration changes (e.g. API key rotation). + + Do **not** keep one instance for the lifetime of a process or across + unrelated invocations. Create a new provider per request, workflow run, or + other bounded scope where up-to-date credentials matter. + """ + tenant_id: str provider_manager: ProviderManager + credentials_cache: dict[tuple[str, str], dict[str, Any]] def __init__( self, @@ -31,8 +44,12 @@ class DifyCredentialsProvider: user_id=run_context.user_id, ) self.provider_manager = provider_manager + self.credentials_cache = {} def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + if (provider_name, model_name) in self.credentials_cache: + return deepcopy(self.credentials_cache[(provider_name, model_name)]) + provider_configurations = self.provider_manager.get_configurations(self.tenant_id) provider_configuration = provider_configurations.get(provider_name) if not provider_configuration: @@ -47,6 +64,7 @@ class DifyCredentialsProvider: if credentials is None: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials) return credentials @@ -66,7 +84,8 @@ class DifyModelFactory: provider_manager=create_plugin_provider_manager( tenant_id=run_context.tenant_id, user_id=run_context.user_id, - ) + ), + enable_credentials_cache=True, ) self.model_manager = model_manager @@ -85,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro tenant_id=run_context.tenant_id, user_id=run_context.user_id, ) - model_manager = ModelManager(provider_manager=provider_manager) + model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True) return ( DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 0bb10190c4..b6039e1e4e 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,4 +1,3 @@ -from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update from sqlalchemy.orm import 
sessionmaker @@ -8,6 +7,7 @@ from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance from extensions.ext_database import db +from graphon.model_runtime.entities.llm_entities import LLMUsage from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 10b9c36d3e..9e688589db 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,7 +1,6 @@ import logging import time -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,6 +17,7 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 6bb177fe02..e2e07ebaff 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -4,13 +4,6 @@ from collections.abc import Generator from threading import Thread from typing import Any, cast -from graphon.file import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -60,6 +53,13 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index 77310baf74..1dd713821f 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,9 +1,8 @@ from typing import TypedDict +from core.tools.signature import sign_tool_file from graphon.file import FileTransferMethod from graphon.file import helpers as file_helpers - -from core.tools.signature import sign_tool_file from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH = 10 
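The cache contract documented on `DifyCredentialsProvider` above is easiest to see in isolation. A minimal sketch of the intended usage, assuming only what the patch shows (`build_dify_model_access` returns the `DifyCredentialsProvider` as the first tuple element, and `fetch` returns a deep copy of the cached entry); the provider name, model name, and `run_context` are illustrative placeholders:

```python
# Sketch only: illustrates the per-scope caching contract described in the
# DifyCredentialsProvider docstring. "openai", "gpt-4o", and run_context are
# hypothetical placeholders, not values taken from the patch.
from core.app.llm.model_access import build_dify_model_access


def run_once(run_context):
    # Create the provider at the start of one bounded scope (a request or a
    # single workflow run), as the docstring advises.
    credentials_provider, *_ = build_dify_model_access(run_context)

    first = credentials_provider.fetch("openai", "gpt-4o")   # resolved via ProviderManager
    second = credentials_provider.fetch("openai", "gpt-4o")  # served from credentials_cache

    assert first == second      # same values ...
    assert first is not second  # ... but deep copies, so callers cannot mutate the cache

    # Do not keep credentials_provider alive across unrelated runs: after an
    # API key rotation it would keep serving the stale cached entry.
```

The patch enabling `enable_credentials_cache=True` on the `ModelManager` instances built in the same scope fits this per-run caching model.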
diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 8604235ef2..3a6f9d575a 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -9,17 +9,17 @@ import urllib.parse from collections.abc import Generator from typing import TYPE_CHECKING, Literal -from graphon.file import FileTransferMethod -from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from graphon.file.runtime import set_workflow_file_runtime - from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol from core.db.session_factory import session_factory -from core.helper.ssrf_proxy import ssrf_proxy +from core.helper.ssrf_proxy import graphon_ssrf_proxy from core.tools.signature import sign_tool_file from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage +from graphon.file import FileTransferMethod +from graphon.file.protocols import WorkflowFileRuntimeProtocol +from graphon.file.runtime import set_workflow_file_runtime +from graphon.http.protocols import HttpResponseProtocol if TYPE_CHECKING: from graphon.file import File @@ -44,7 +44,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): return dify_config.MULTIMODAL_SEND_FORMAT def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - return ssrf_proxy.get(url, follow_redirects=follow_redirects) + return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index c577ce0754..4a7918032e 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -7,17 +7,16 @@ This layer centralizes model-quota deduction outside node implementations. 
import logging from typing import TYPE_CHECKING, cast, final, override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm import deduct_llm_quota, ensure_llm_quota_available +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent from graphon.nodes.base.node import Node -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.llm import deduct_llm_quota, ensure_llm_quota_available -from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance - if TYPE_CHECKING: from graphon.nodes.llm.node import LLMNode from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 99e8015c0b..8b5a5b9d7f 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -12,10 +12,6 @@ from contextvars import Token from dataclasses import dataclass from typing import cast, final, override -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.graph_engine.layers import GraphEngineLayer -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node from opentelemetry import context as context_api from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context @@ -28,6 +24,10 @@ from extensions.otel.parser import ( ToolNodeOTelParser, ) from extensions.otel.runtime import is_instrument_flag_enabled +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index ada065a943..d521304615 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -14,6 +14,13 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Union +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from graphon.entities import WorkflowExecution, WorkflowNodeExecution from graphon.enums import ( WorkflowExecutionStatus, @@ -38,14 +45,6 @@ from graphon.graph_events import ( NodeRunSucceededEvent, ) from graphon.node_events import NodeRunResult - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from 
core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now @@ -350,7 +349,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): execution.total_tokens = runtime_state.total_tokens execution.total_steps = runtime_state.node_run_steps execution.outputs = execution.outputs or runtime_state.outputs - execution.exceptions_count = runtime_state.exceptions_count + execution.exceptions_count = max(execution.exceptions_count, runtime_state.exceptions_count) def _update_node_execution( self, diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index 3d8a7a54f3..9e3c187210 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,9 +6,6 @@ import re import threading from collections.abc import Iterable -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelType - from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -18,6 +15,8 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index a5297fa33a..f0dcb13b62 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,9 +3,6 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts @@ -31,6 +28,9 @@ from core.plugin.impl.datasource import PluginDatasourceManager from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from factories import file_factory +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService @@ -352,11 +352,11 @@ class DatasourceManager: raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") file_info = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." 
+ upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.CUSTOM, + file_type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference(record_id=str(upload_file.id)), diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 890f1ca319..352e6bfd49 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ from typing import Any, Literal, TypedDict -from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject, I18nObjectDict +from graphon.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): @@ -14,7 +14,7 @@ class DatasourceApiEntity(BaseModel): description: I18nObject parameters: list[DatasourceParameter] | None = None labels: list[str] = Field(default_factory=list) - output_schema: dict | None = None + output_schema: dict[str, Any] | None = None ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None @@ -30,7 +30,7 @@ class DatasourceProviderApiEntityDict(TypedDict): icon: str | dict label: I18nObjectDict type: str - team_credentials: dict | None + team_credentials: dict[str, Any] | None is_team_authorization: bool allow_delete: bool datasources: list[Any] @@ -45,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel): icon: str | dict label: I18nObject # label type: str - masked_credentials: dict | None = None - original_credentials: dict | None = None + masked_credentials: dict[str, Any] | None = None + original_credentials: dict[str, Any] | None = None is_team_authorization: bool = False allow_delete: bool = True plugin_id: str | None = Field(default="", description="The plugin id of the datasource") diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index f20bab53f0..443b503a69 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -2,7 +2,7 @@ from __future__ import annotations import enum from enum import StrEnum -from typing import Any +from typing import Any, TypedDict from pydantic import BaseModel, Field, ValidationInfo, field_validator from yarl import URL @@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel): identity: DatasourceIdentity parameters: list[DatasourceParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The label of the datasource") - output_schema: dict | None = None + output_schema: dict[str, Any] | None = None @field_validator("parameters", mode="before") @classmethod @@ -179,6 +179,12 @@ class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity): datasources: list[DatasourceEntity] = Field(default_factory=list) +class DatasourceInvokeMetaDict(TypedDict): + time_cost: float + error: str | None + tool_config: dict[str, Any] | None + + class DatasourceInvokeMeta(BaseModel): """ Datasource invoke meta @@ -186,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel): time_cost: float = Field(..., description="The time cost of the tool invoke") error: str | None = None - tool_config: dict | None = None + tool_config: dict[str, Any] | None = None @classmethod def empty(cls) -> DatasourceInvokeMeta: @@ -202,12 +208,13 @@ class 
DatasourceInvokeMeta(BaseModel): """ return cls(time_cost=0.0, error=error, tool_config={}) - def to_dict(self) -> dict: - return { + def to_dict(self) -> DatasourceInvokeMetaDict: + result: DatasourceInvokeMetaDict = { "time_cost": self.time_cost, "error": self.error, "tool_config": self.tool_config, } + return result class DatasourceLabel(BaseModel): @@ -235,7 +242,7 @@ class OnlineDocumentPage(BaseModel): page_id: str = Field(..., description="The page id") page_name: str = Field(..., description="The page title") - page_icon: dict | None = Field(None, description="The page icon") + page_icon: dict[str, Any] | None = Field(None, description="The page icon") type: str = Field(..., description="The type of the page") last_edited_time: str = Field(..., description="The last edited time") parent_id: str | None = Field(None, description="The parent page id") @@ -294,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel): Get website crawl request """ - crawl_parameters: dict = Field(..., description="The crawl parameters") + crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters") class WebSiteInfoDetail(BaseModel): @@ -351,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel): bucket: str | None = Field(None, description="The file bucket") files: list[OnlineDriveFile] = Field(..., description="The file list") is_truncated: bool = Field(False, description="Whether the result is truncated") - next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page") + next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page") class OnlineDriveBrowseFilesRequest(BaseModel): @@ -362,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel): bucket: str | None = Field(None, description="The file bucket") prefix: str = Field(..., description="The parent folder ID") max_keys: int = Field(20, description="Page size for pagination") - next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page") + next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page") class OnlineDriveBrowseFilesResponse(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index c012e128f4..6a3f9e684a 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,11 +2,10 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type -from graphon.file import File, FileTransferMethod, FileType - from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index d304c982cd..04ae193396 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias -from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field +from graphon.nodes.human_input.entities import FormInput, UserAction 
from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index b1ba3c3e2a..a13938f3fb 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field, field_validator @@ -37,7 +39,7 @@ class PipelineDocument(BaseModel): id: str position: int data_source_type: str - data_source_info: dict | None = None + data_source_info: dict[str, Any] | None = None name: str indexing_status: str error: str | None = None diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index a440829b46..bfa4f56915 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,7 +6,6 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse -from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -16,6 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from graphon.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 84d95c38c6..e99a131500 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,10 +1,11 @@ from collections.abc import Sequence from enum import StrEnum, auto +from pydantic import BaseModel, ConfigDict + from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel from graphon.model_runtime.entities.provider_entities import ProviderEntity -from pydantic import BaseModel, ConfigDict class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index f3b2c31465..38b87e2cd1 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -6,17 +6,8 @@ import re from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError +from typing import Any -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.__base.ai_model import AIModel -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -33,6 +24,16 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import 
( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -111,7 +112,7 @@ class ProviderConfiguration(BaseModel): return ModelProviderFactory(model_runtime=self._bound_model_runtime) return create_plugin_model_provider_factory(tenant_id=self.tenant_id) - def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: + def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None: """ Get current credentials. @@ -233,7 +234,7 @@ class ProviderConfiguration(BaseModel): return session.execute(stmt).scalar_one_or_none() - def _get_specific_provider_credential(self, credential_id: str) -> dict | None: + def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None: """ Get a specific provider credential by ID. :param credential_id: Credential ID @@ -297,7 +298,7 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_provider_credential(self, credential_id: str | None = None) -> dict | None: + def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None: """ Get provider credentials. @@ -317,32 +318,28 @@ class ProviderConfiguration(BaseModel): else [], ) - def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None): + def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""): """ Validate custom credentials. 
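Secret fields submitted as ``[__HIDDEN__]`` are first restored from the stored credential, the resulting set is validated through the model provider factory, and secret values are encrypted again before the validated credentials are returned.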
:param credentials: provider credentials :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate - :param session: optional database session :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.provider_credential_schema.credential_form_schemas + if self.provider.provider_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) - credential_record = s.execute(stmt).scalar_one_or_none() - # fix origin data + credential_record = session.execute(stmt).scalar_one_or_none() if credential_record and credential_record.encrypted_config: if not credential_record.encrypted_config.startswith("{"): original_credentials = {"openai_api_key": credential_record.encrypted_config} @@ -353,31 +350,23 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # encrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.provider_credentials_validate( + provider=self.provider.provider, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def _generate_provider_credential_name(self, session) -> str: """ @@ -447,21 +436,23 @@ class ProviderConfiguration(BaseModel): provider_names.append(model_provider_id.provider_name) return provider_names - def create_provider_credential(self, credentials: dict, credential_name: str | None): + def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None): """ Add custom provider credentials. 
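Name uniqueness checks and default-name generation run in a short-lived session, validation happens with no session held open, and the new records are written in a separate session afterwards.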
:param credentials: provider credentials :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: - if self._check_provider_credential_name_exists(credential_name=credential_name, session=session): + if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session): raise ValueError(f"Credential with name '{credential_name}' already exists.") else: - credential_name = self._generate_provider_credential_name(session) + credential_name = self._generate_provider_credential_name(pre_session) - credentials = self.validate_provider_credentials(credentials=credentials, session=session) + credentials = self.validate_provider_credentials(credentials=credentials) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) try: new_record = ProviderCredential( @@ -474,7 +465,6 @@ class ProviderConfiguration(BaseModel): session.flush() if not provider_record: - # If provider record does not exist, create it provider_record = Provider( tenant_id=self.tenant_id, provider_name=self.provider.provider, @@ -515,7 +505,7 @@ class ProviderConfiguration(BaseModel): def update_provider_credential( self, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ): @@ -527,15 +517,15 @@ class ProviderConfiguration(BaseModel): :param credential_name: credential name :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_provider_credential_name_exists( - credential_name=credential_name, session=session, exclude_id=credential_id + credential_name=credential_name, session=pre_session, exclude_id=credential_id ): raise ValueError(f"Credential with name '{credential_name}' already exists.") - credentials = self.validate_provider_credentials( - credentials=credentials, credential_id=credential_id, session=session - ) + credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id) + + with Session(db.engine) as session: provider_record = self._get_provider_record(session) stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, @@ -543,12 +533,10 @@ class ProviderConfiguration(BaseModel): ProviderCredential.provider_name.in_(self._get_provider_names()), ) - # Get the credential record to update credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: @@ -760,7 +748,7 @@ class ProviderConfiguration(BaseModel): def _get_specific_custom_model_credential( self, model_type: ModelType, model: str, credential_id: str - ) -> dict | None: + ) -> dict[str, Any] | None: """ Get a specific provider credential by ID. :param credential_id: Credential ID @@ -832,7 +820,9 @@ class ProviderConfiguration(BaseModel): stmt = stmt.where(ProviderModelCredential.id != exclude_id) return session.execute(stmt).scalar_one_or_none() is not None - def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None: + def get_custom_model_credential( + self, model_type: ModelType, model: str, credential_id: str | None + ) -> dict[str, Any] | None: """ Get custom model credentials. 
@@ -872,9 +862,8 @@ class ProviderConfiguration(BaseModel): self, model_type: ModelType, model: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str = "", - session: Session | None = None, ): """ Validate custom model credentials. @@ -885,16 +874,14 @@ class ProviderConfiguration(BaseModel): :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate :return: """ + provider_credential_secret_variables = self.extract_secret_variables( + self.provider.model_credential_schema.credential_form_schemas + if self.provider.model_credential_schema + else [] + ) - def _validate(s: Session): - # Get provider credential secret variables - provider_credential_secret_variables = self.extract_secret_variables( - self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema - else [] - ) - - if credential_id: + if credential_id: + with Session(db.engine) as session: try: stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, @@ -903,7 +890,7 @@ class ProviderConfiguration(BaseModel): ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type, ) - credential_record = s.execute(stmt).scalar_one_or_none() + credential_record = session.execute(stmt).scalar_one_or_none() original_credentials = ( json.loads(credential_record.encrypted_config) if credential_record and credential_record.encrypted_config @@ -912,34 +899,26 @@ class ProviderConfiguration(BaseModel): except JSONDecodeError: original_credentials = {} - # decrypt credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - credentials[key] = encrypter.decrypt_token( - tenant_id=self.tenant_id, token=original_credentials[key] - ) - - model_provider_factory = self.get_model_provider_factory() - validated_credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials - ) - - for key, value in validated_credentials.items(): + for key, value in credentials.items(): if key in provider_credential_secret_variables: - validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + if value == HIDDEN_VALUE and key in original_credentials: + credentials[key] = encrypter.decrypt_token( + tenant_id=self.tenant_id, token=original_credentials[key] + ) - return validated_credentials + model_provider_factory = self.get_model_provider_factory() + validated_credentials = model_provider_factory.model_credentials_validate( + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials + ) - if session: - return _validate(session) - else: - with Session(db.engine) as new_session: - return _validate(new_session) + for key, value in validated_credentials.items(): + if key in provider_credential_secret_variables and isinstance(value, str): + validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value) + + return validated_credentials def create_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None + self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None ) -> None: """ Create a custom model credential. 
@@ -949,20 +928,22 @@ class ProviderConfiguration(BaseModel): :param credentials: model credentials dict :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name: if self._check_custom_model_credential_name_exists( - model=model, model_type=model_type, credential_name=credential_name, session=session + model=model, model_type=model_type, credential_name=credential_name, session=pre_session ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") else: credential_name = self._generate_custom_model_credential_name( - model=model, model_type=model_type, session=session + model=model, model_type=model_type, session=pre_session ) - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, model=model, credentials=credentials, session=session - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, model=model, credentials=credentials + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) try: @@ -977,7 +958,6 @@ class ProviderConfiguration(BaseModel): session.add(credential) session.flush() - # save provider model if not provider_model_record: provider_model_record = ProviderModel( tenant_id=self.tenant_id, @@ -1002,7 +982,12 @@ class ProviderConfiguration(BaseModel): raise def update_custom_model_credential( - self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str + self, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + credential_name: str | None, + credential_id: str, ) -> None: """ Update a custom model credential. 
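The creation path above adds the credential row, calls session.flush() so its ID is assigned, and only then creates or updates the provider model record, committing both together. A runnable sketch of that flush-then-link pattern, using two hypothetical tables in place of Dify's models:

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Credential(Base):
    __tablename__ = "credential"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

class ProviderModelRow(Base):
    __tablename__ = "provider_model"
    id: Mapped[int] = mapped_column(primary_key=True)
    credential_id: Mapped[int] = mapped_column(ForeignKey("credential.id"))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    cred = Credential(name="default")
    session.add(cred)
    session.flush()  # assigns cred.id without committing
    session.add(ProviderModelRow(credential_id=cred.id))
    session.commit()  # both rows land in one transaction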
@@ -1014,23 +999,24 @@ class ProviderConfiguration(BaseModel): :param credential_id: credential id :return: """ - with Session(db.engine) as session: + with Session(db.engine) as pre_session: if credential_name and self._check_custom_model_credential_name_exists( model=model, model_type=model_type, credential_name=credential_name, - session=session, + session=pre_session, exclude_id=credential_id, ): raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.") - # validate custom model config - credentials = self.validate_custom_model_credentials( - model_type=model_type, - model=model, - credentials=credentials, - credential_id=credential_id, - session=session, - ) + + credentials = self.validate_custom_model_credentials( + model_type=model_type, + model=model, + credentials=credentials, + credential_id=credential_id, + ) + + with Session(db.engine) as session: provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session) stmt = select(ProviderModelCredential).where( @@ -1045,7 +1031,6 @@ class ProviderConfiguration(BaseModel): raise ValueError("Credential record not found.") try: - # Update credential credential_record.encrypted_config = json.dumps(credentials) credential_record.updated_at = naive_utc_now() if credential_name: @@ -1412,7 +1397,9 @@ class ProviderConfiguration(BaseModel): # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) - def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None: + def get_model_schema( + self, model_type: ModelType, model: str, credentials: dict[str, Any] | None + ) -> AIModelEntity | None: """ Get model schema """ @@ -1471,7 +1458,7 @@ class ProviderConfiguration(BaseModel): return secret_input_form_variables - def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]): + def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]): """ Obfuscated credentials. diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 2c8767a32b..72b29c2277 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,9 +1,8 @@ from __future__ import annotations from enum import StrEnum, auto -from typing import Union +from typing import Any, Union -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -13,6 +12,7 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): @@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel): enabled: bool current_quota_type: ProviderQuotaType | None = None quota_configurations: list[QuotaConfiguration] = [] - credentials: dict | None = None + credentials: dict[str, Any] | None = None class CustomProviderConfiguration(BaseModel): @@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel): Model class for provider custom configuration. 
""" - credentials: dict + credentials: dict[str, Any] current_credential_id: str | None = None current_credential_name: str | None = None available_credentials: list[CredentialConfiguration] = [] @@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel): model: str model_type: ModelType - credentials: dict | None + credentials: dict[str, Any] | None current_credential_id: str | None = None current_credential_name: str | None = None available_model_credentials: list[CredentialConfiguration] = [] @@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel): id: str name: str - credentials: dict + credentials: dict[str, Any] credential_source_type: str | None = None credential_id: str | None = None diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index f9e6099049..01139d07e2 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast import httpx @@ -14,7 +14,7 @@ class APIBasedExtensionRequestor: self.api_endpoint = api_endpoint self.api_key = api_key - def request(self, point: APIBasedExtensionPoint, params: dict): + def request(self, point: APIBasedExtensionPoint, params: dict[str, Any]) -> dict[str, Any]: """ Request the api. @@ -49,4 +49,4 @@ class APIBasedExtensionRequestor: if response.status_code != 200: raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}") - return cast(dict, response.json()) + return cast(dict[str, Any], response.json()) diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index c2789a7a35..c08e319aac 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -21,8 +21,8 @@ class ExtensionModule(StrEnum): class ModuleExtension(BaseModel): extension_class: Any | None = None name: str - label: dict | None = None - form_schema: list | None = None + label: dict[str, Any] | None = None + form_schema: list[dict[str, Any]] | None = None builtin: bool = True position: int | None = None @@ -32,9 +32,9 @@ class Extensible: name: str tenant_id: str - config: dict | None = None + config: dict[str, Any] | None = None - def __init__(self, tenant_id: str, config: dict | None = None): + def __init__(self, tenant_id: str, config: dict[str, Any] | None = None): self.tenant_id = tenant_id self.config = config diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 564801f189..8ce068cfbb 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any, TypedDict + from sqlalchemy import select from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor @@ -7,6 +10,16 @@ from extensions.ext_database import db from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +class ApiToolConfig(TypedDict, total=False): + """Expected config shape for ApiExternalDataTool. + + Not used directly in method signatures (base class accepts dict[str, Any]); + kept here to document the keys this tool reads from config. + """ + + api_based_extension_id: str + + class ApiExternalDataTool(ExternalDataTool): """ The api external data tool. 
@@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool): """the unique name of external data tool""" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool): if not api_based_extension: raise ValueError("api_based_extension_id is invalid") - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/base.py b/api/core/external_data_tool/base.py index cbec2e4e42..12bea4e9e5 100644 --- a/api/core/external_data_tool/base.py +++ b/api/core/external_data_tool/base.py @@ -1,4 +1,6 @@ from abc import ABC, abstractmethod +from collections.abc import Mapping +from typing import Any from core.extension.extensible import Extensible, ExtensionModule @@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC): variable: str """the tool variable name of app tool""" - def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None): + def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None): super().__init__(tenant_id, config) self.app_id = app_id self.variable = variable @classmethod @abstractmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC): raise NotImplementedError @abstractmethod - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str: """ Query the external data tool. diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 6c542d681b..f404aa7286 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict): + def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]): extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) self.__extension_instance = extension_class( tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) @classmethod - def validate_config(cls, name: str, tenant_id: str, config: dict): + def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]) -> None: """ Validate the incoming form config data. 
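query now takes Mapping[str, Any] instead of dict, the usual way to signal read-only access: callers can still pass plain dicts, but the implementation can no longer mutate them without a type error. A minimal demonstration (the render function is illustrative, not part of the diff):

from collections.abc import Mapping
from typing import Any

def render(inputs: Mapping[str, Any]) -> str:
    # inputs["k"] = ... would be rejected by a type checker here,
    # because Mapping declares no __setitem__.
    return ", ".join(f"{k}={v}" for k, v in inputs.items())

print(render({"user": "alice", "topic": "weather"}))  # plain dicts satisfy Mapping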
diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 35bfcfb6a5..951e065b2c 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,7 +4,6 @@ from threading import Lock from typing import Any import httpx -from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -14,6 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client +from graphon.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b96a9ce380..38864a1830 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -102,7 +102,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode() + inputs_json_str = dumps_with_segments(inputs).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/helper/creators.py b/api/core/helper/creators.py new file mode 100644 index 0000000000..b01e16f18a --- /dev/null +++ b/api/core/helper/creators.py @@ -0,0 +1,41 @@ +""" +Helper module for Creators Platform integration. + +Provides functionality to upload DSL files to the Creators Platform +and generate redirect URLs with OAuth authorization codes. 
+""" + +import logging +from urllib.parse import urlencode + +import httpx +from yarl import URL + +from configs import dify_config + +logger = logging.getLogger(__name__) + +creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL)) + + +def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str: + url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload") + response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30) + response.raise_for_status() + data = response.json() + claim_code = data.get("data", {}).get("claim_code") + if not claim_code: + raise ValueError("Creators Platform did not return a valid claim_code") + return claim_code + + +def get_redirect_url(user_account_id: str, claim_code: str) -> str: + base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/") + params: dict[str, str] = {"dsl_claim_code": claim_code} + client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "") + if client_id: + from services.oauth_server import OAuthServerService + + oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id) + params["oauth_code"] = oauth_code + return f"{base_url}?{urlencode(params)}" diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 00fcfe0b80..10d79a8239 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -1,6 +1,7 @@ import json from enum import StrEnum from json import JSONDecodeError +from typing import Any from extensions.ext_redis import redis_client @@ -15,7 +16,7 @@ class ProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - def get(self) -> dict | None: + def get(self) -> dict[str, Any] | None: """ Get cached model provider credentials. @@ -33,7 +34,7 @@ class ProviderCredentialsCache: else: return None - def set(self, credentials: dict): + def set(self, credentials: dict[str, Any]): """ Cache model provider credentials. 
diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index a1e782a094..f169f247cf 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,14 +2,13 @@ import logging import secrets from typing import cast -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeBadRequestError -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel - from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel from models.provider import ProviderType logger = logging.getLogger(__name__) diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index ffb5148386..9f167ca49c 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC): """Generate cache key based on subclass implementation""" pass - def get(self) -> dict | None: + def get(self) -> dict[str, Any] | None: """Get cached provider credentials""" cached_credentials = redis_client.get(self.cache_key) if cached_credentials: @@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): class NoOpProviderCredentialCache: """No-op provider credential cache""" - def get(self) -> dict | None: + def get(self) -> dict[str, Any] | None: """Get cached provider credentials""" return None diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index e38592bb7b..91e92712b7 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -12,6 +12,7 @@ from pydantic import TypeAdapter, ValidationError from configs import dify_config from core.helper.http_client_pooling import get_pooled_http_client from core.tools.errors import ToolSSRFError +from graphon.http.response import HttpResponse logger = logging.getLogger(__name__) @@ -267,4 +268,47 @@ class SSRFProxy: return patch(url=url, max_retries=max_retries, **kwargs) +def _to_graphon_http_response(response: httpx.Response) -> HttpResponse: + """Convert an ``httpx`` response into Graphon's transport-agnostic wrapper.""" + return HttpResponse( + status_code=response.status_code, + headers=dict(response.headers), + content=response.content, + url=str(response.url) if response.url else None, + reason_phrase=response.reason_phrase, + fallback_text=response.text, + ) + + +class GraphonSSRFProxy: + """Adapter exposing SSRF helpers behind Graphon's ``HttpClientProtocol``.""" + + @property + def max_retries_exceeded_error(self) -> type[Exception]: + return max_retries_exceeded_error + + @property + def request_error(self) -> type[Exception]: + return request_error + + def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(get(url=url, max_retries=max_retries, **kwargs)) + + def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(head(url=url, max_retries=max_retries, **kwargs)) + + def post(self, url: str, 
max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(post(url=url, max_retries=max_retries, **kwargs)) + + def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(put(url=url, max_retries=max_retries, **kwargs)) + + def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(delete(url=url, max_retries=max_retries, **kwargs)) + + def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> HttpResponse: + return _to_graphon_http_response(patch(url=url, max_retries=max_retries, **kwargs)) + + ssrf_proxy = SSRFProxy() +graphon_ssrf_proxy = GraphonSSRFProxy() diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index 54674d4ff6..bf5bf9af03 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -1,6 +1,7 @@ import json from enum import StrEnum from json import JSONDecodeError +from typing import Any from extensions.ext_redis import redis_client @@ -18,7 +19,7 @@ class ToolParameterCache: f":identity_id:{identity_id}" ) - def get(self) -> dict | None: + def get(self) -> dict[str, Any] | None: """ Get cached model provider credentials. @@ -36,7 +37,7 @@ class ToolParameterCache: else: return None - def set(self, parameters: dict): + def set(self, parameters: dict[str, Any]): """Cache model provider credentials.""" redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 60f5434bc1..8bcb899b23 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,10 +1,12 @@ +from typing import Any + from flask import Flask -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel +from graphon.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): @@ -28,7 +30,7 @@ class FreeHostingQuota(HostingQuota): class HostingProvider(BaseModel): enabled: bool = False - credentials: dict | None = None + credentials: dict[str, Any] | None = None quota_unit: QuotaUnit | None = None quotas: list[HostingQuota] = [] diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b8d5ca2f50..b6e33396d1 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,7 +9,6 @@ from collections.abc import Mapping from typing import Any from flask import Flask, current_app -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, func, select, update from sqlalchemy.orm.exc import ObjectDeletedError @@ -35,6 +34,7 @@ from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models import Account @@ -735,7 +735,9 @@ class IndexingRunner: @staticmethod def _update_document_index_status( - document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None + 
document_id: str, + after_indexing_status: IndexingStatus, + extra_update_params: Mapping[Any, Any] | None = None, ): """ Update the document indexing status. @@ -762,7 +764,7 @@ class IndexingRunner: db.session.commit() @staticmethod - def _update_segments_by_document(dataset_document_id: str, update_params: dict): + def _update_segments_by_document(dataset_document_id: str, update_params: Mapping[Any, Any]): """ Update the document segment by document id. """ diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index aa258c9f89..af2611bb0b 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -2,14 +2,9 @@ import json import logging import re from collections.abc import Sequence -from typing import Protocol, TypedDict, cast +from typing import Any, NotRequired, Protocol, TypedDict, cast import json_repair -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from core.app.app_config.entities import ModelConfig @@ -23,8 +18,6 @@ from core.llm_generator.prompts import ( LLM_MODIFY_CODE_SYSTEM, LLM_MODIFY_PROMPT_SYSTEM, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, - SUGGESTED_QUESTIONS_MAX_TOKENS, - SUGGESTED_QUESTIONS_TEMPERATURE, SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) @@ -35,12 +28,47 @@ from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models import App, Message, WorkflowNodeExecutionModel from models.workflow import Workflow logger = logging.getLogger(__name__) +class SuggestedQuestionsModelConfig(TypedDict): + provider: str + name: str + completion_params: NotRequired[dict[str, object]] + + +def _normalize_completion_params(completion_params: dict[str, object]) -> tuple[dict[str, object], list[str]]: + """ + Normalize raw completion params into invocation parameters and stop sequences. + + This mirrors the app-model access path by separating ``stop`` from provider + parameters before invocation, then drops non-positive token limits because + some plugin-backed models reject ``0`` after mapping ``max_tokens`` to their + provider-specific output-token field. 
+ """ + normalized_parameters = dict(completion_params) + stop_value = normalized_parameters.pop("stop", []) + if isinstance(stop_value, list) and all(isinstance(item, str) for item in stop_value): + stop = stop_value + else: + stop = [] + + for token_limit_key in ("max_tokens", "max_output_tokens"): + token_limit = normalized_parameters.get(token_limit_key) + if isinstance(token_limit, int | float) and token_limit <= 0: + normalized_parameters.pop(token_limit_key, None) + + return normalized_parameters, stop + + class WorkflowServiceInterface(Protocol): def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None: pass @@ -123,8 +151,15 @@ class LLMGenerator: return name @classmethod - def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]: - output_parser = SuggestedQuestionsAfterAnswerOutputParser() + def generate_suggested_questions_after_answer( + cls, + tenant_id: str, + histories: str, + *, + instruction_prompt: str | None = None, + model_config: object | None = None, + ) -> Sequence[str]: + output_parser = SuggestedQuestionsAfterAnswerOutputParser(instruction_prompt=instruction_prompt) format_instructions = output_parser.get_format_instructions() prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n") @@ -133,10 +168,36 @@ class LLMGenerator: try: model_manager = ModelManager.for_tenant(tenant_id=tenant_id) - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - ) + configured_model = cast(dict[str, object], model_config) if isinstance(model_config, dict) else {} + provider = configured_model.get("provider") + model_name = configured_model.get("name") + use_configured_model = False + + if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name: + try: + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=provider, + model=model_name, + ) + use_configured_model = True + except Exception: + logger.warning( + "Failed to use configured suggested-questions model %s/%s, fallback to default model", + provider, + model_name, + exc_info=True, + ) + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) + else: + model_instance = model_manager.get_default_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + ) except InvokeAuthorizationError: return [] @@ -145,19 +206,29 @@ class LLMGenerator: questions: Sequence[str] = [] try: + configured_completion_params = configured_model.get("completion_params") + if use_configured_model and isinstance(configured_completion_params, dict): + model_parameters, stop = _normalize_completion_params(configured_completion_params) + elif use_configured_model: + model_parameters = {} + stop = [] + else: + # Default-model generation keeps the built-in suggested-questions tuning. 
+ model_parameters = { + "max_tokens": 2560, + "temperature": 0.0, + } + stop = [] + response: LLMResult = model_instance.invoke_llm( prompt_messages=list(prompt_messages), - model_parameters={ - "max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS, - "temperature": SUGGESTED_QUESTIONS_TEMPERATURE, - }, + model_parameters=model_parameters, + stop=stop, stream=False, ) text_content = response.message.get_text_content() questions = output_parser.parse(text_content) if text_content else [] - except InvokeError: - questions = [] except Exception: logger.exception("Failed to generate suggested questions after answer") questions = [] @@ -533,7 +604,7 @@ class LLMGenerator: def __instruction_modify_common( tenant_id: str, model_config: ModelConfig, - last_run: dict | None, + last_run: dict[str, Any] | None, current: str | None, error_message: str | None, instruction: str, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index a1710f11ac..d2e375626f 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,6 +5,11 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance from graphon.model_runtime.callbacks.base_callback import Callback from graphon.model_runtime.entities.llm_entities import ( LLMResult, @@ -21,11 +26,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance class ResponseFormat(StrEnum): @@ -200,9 +200,9 @@ def _handle_native_json_schema( provider: str, model_schema: AIModelEntity, structured_output_schema: Mapping, - model_parameters: dict, + model_parameters: dict[str, Any], rules: list[ParameterRule], -): +) -> dict[str, Any]: """ Handle structured output for models with native JSON schema support. @@ -224,7 +224,7 @@ def _handle_native_json_schema( return model_parameters -def _set_response_format(model_parameters: dict, rules: list): +def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None: """ Set the appropriate response format parameter based on model rules. @@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema return {"schema": processed_schema, "name": "llm_response"} -def remove_additional_properties(schema: dict): +def remove_additional_properties(schema: dict[str, Any]) -> None: """ Remove additionalProperties fields from JSON schema. Used for models like Gemini that don't support this property. @@ -349,7 +349,7 @@ def remove_additional_properties(schema: dict): remove_additional_properties(item) -def convert_boolean_to_string(schema: dict): +def convert_boolean_to_string(schema: dict[str, Any]) -> None: """ Convert boolean type specifications to string in JSON schema. 
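For reference, the behavior of _normalize_completion_params introduced in llm_generator.py above can be exercised standalone. This sketch mirrors its documented logic (separate "stop" from the provider parameters, then drop non-positive token limits that some plugin-backed models reject); it is a restatement for illustration, not the shipped helper:

def normalize_completion_params(raw: dict[str, object]) -> tuple[dict[str, object], list[str]]:
    params = dict(raw)
    stop = params.pop("stop", [])
    if not (isinstance(stop, list) and all(isinstance(s, str) for s in stop)):
        stop = []
    for key in ("max_tokens", "max_output_tokens"):
        limit = params.get(key)
        if isinstance(limit, (int, float)) and limit <= 0:
            params.pop(key, None)
    return params, stop

print(normalize_completion_params({"temperature": 0.2, "max_tokens": 0, "stop": ["\n\n"]}))
# -> ({'temperature': 0.2}, ['\n\n'])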
diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index eec771181f..7ac340926d 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -3,17 +3,28 @@ import logging import re from collections.abc import Sequence -from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT +from core.llm_generator.prompts import DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT logger = logging.getLogger(__name__) class SuggestedQuestionsAfterAnswerOutputParser: + def __init__(self, instruction_prompt: str | None = None) -> None: + self._instruction_prompt = self._build_instruction_prompt(instruction_prompt) + + @staticmethod + def _build_instruction_prompt(instruction_prompt: str | None) -> str: + if not instruction_prompt or not instruction_prompt.strip(): + return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + + return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].' + def get_format_instructions(self) -> str: - return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + return self._instruction_prompt def parse(self, text: str) -> Sequence[str]: - action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) + stripped_text = text.strip() + action_match = re.search(r"\[.*?\]", stripped_text, re.DOTALL) questions: list[str] = [] if action_match is not None: try: @@ -23,4 +34,6 @@ class SuggestedQuestionsAfterAnswerOutputParser: else: if isinstance(json_obj, list): questions = [question for question in json_obj if isinstance(question, str)] + elif stripped_text: + logger.warning("Failed to find suggested questions payload array in text: %r", stripped_text[:200]) return questions diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index ee9a016c95..3c6f8c468a 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -1,5 +1,4 @@ # Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh -import os CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”. @@ -96,8 +95,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = ( ) -# Default prompt for suggested questions (can be overridden by environment variable) -_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = ( +# Default prompt and model parameters for suggested questions. +DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "Please help me predict the three most likely questions that human would ask, " "and keep each question under 20 characters.\n" "MAKE SURE your output is the SAME language as the Assistant's latest response. 
" @@ -105,15 +104,6 @@ _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = ( '["question1","question2","question3"]\n' ) -# Environment variable override for suggested questions prompt -SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv( - "SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT -) - -# Configurable LLM parameters for suggested questions (can be overridden by environment variables) -SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256")) -SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0")) - GENERATOR_QA_PROMPT = ( " The user will send a long text. Generate a Question and Answer pairs only using the knowledge" " in the long text. Please think step by step." diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py index 9baf6c4682..ae7be91c17 100644 --- a/api/core/logging/structured_formatter.py +++ b/api/core/logging/structured_formatter.py @@ -3,7 +3,7 @@ import logging import traceback from datetime import UTC, datetime -from typing import Any, TypedDict +from typing import Any, NotRequired, TypedDict import orjson @@ -16,6 +16,19 @@ class IdentityDict(TypedDict, total=False): user_type: str +class LogDict(TypedDict): + ts: str + severity: str + service: str + caller: str + message: str + trace_id: NotRequired[str] + span_id: NotRequired[str] + identity: NotRequired[IdentityDict] + attributes: NotRequired[dict[str, Any]] + stack_trace: NotRequired[str] + + class StructuredJSONFormatter(logging.Formatter): """ JSON log formatter following the specified schema: @@ -55,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter): return json.dumps(log_dict, default=str, ensure_ascii=False) - def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]: + def _build_log_dict(self, record: logging.LogRecord) -> LogDict: # Core fields - log_dict: dict[str, Any] = { + log_dict: LogDict = { "ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"), "severity": self.SEVERITY_MAP.get(record.levelno, "INFO"), "service": self._service_name, diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 5c3cd0d8f8..acba3e666b 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -303,9 +303,16 @@ class StreamableHTTPTransport: if response.status_code == 404: if isinstance(message.root, JSONRPCRequest): + error_msg = ( + f"MCP server URL returned 404 Not Found: {self.url} " + "— verify the server URL is correct and the server is running" + if is_initialization + else "Session terminated by server" + ) self._send_session_terminated_error( ctx.server_to_client_queue, message.root.id, + message=error_msg, ) return @@ -381,12 +388,13 @@ class StreamableHTTPTransport: self, server_to_client_queue: ServerToClientQueue, request_id: RequestId, + message: str = "Session terminated by server", ): """Send a session terminated error response.""" jsonrpc_error = JSONRPCError( jsonrpc="2.0", id=request_id, - error=ErrorData(code=32600, message="Session terminated by server"), + error=ErrorData(code=32600, message=message), ) session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) server_to_client_queue.put(session_message) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 72171d1536..884610ca82 100644 --- a/api/core/mcp/server/streamable_http.py +++ 
b/api/core/mcp/server/streamable_http.py @@ -3,12 +3,11 @@ import logging from collections.abc import Mapping from typing import Any, NotRequired, TypedDict, cast -from graphon.variables.input_entities import VariableEntity, VariableEntityType - from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 7e35044176..7b5a7635f1 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse -from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError +from graphon.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 5809d6f74a..d840ee213c 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,5 +1,14 @@ from collections.abc import Sequence +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from factories import file_factory from graphon.file import file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -10,15 +19,6 @@ from graphon.model_runtime.entities import ( UserPromptMessage, ) from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.file_access import DatabaseFileAccessController -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from extensions.ext_database import db -from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 86d042de3e..457c888e33 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,20 +1,7 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence -from typing import IO, Any, Literal, Optional, Union, cast, overload - -from graphon.model_runtime.callbacks.base_callback import Callback -from graphon.model_runtime.entities.llm_entities import LLMResult -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType -from graphon.model_runtime.entities.rerank_entities import 
MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from copy import deepcopy +from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload from configs import dify_config from core.entities import PluginCredentialType @@ -25,9 +12,24 @@ from core.errors.error import ProviderTokenNotInitError from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_redis import redis_client +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from models.provider import ProviderType logger = logging.getLogger(__name__) +P = ParamSpec("P") +R = TypeVar("R") class ModelInstance: @@ -35,11 +37,13 @@ class ModelInstance: Model instance class. """ - def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): + def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None: self.provider_model_bundle = provider_model_bundle self.model_name = model self.provider = provider_model_bundle.configuration.provider.provider - self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + if credentials is None: + credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + self.credentials = credentials # Runtime LLM invocation fields. 
self.parameters: Mapping[str, Any] = {} self.stop: Sequence[str] = () @@ -77,7 +81,7 @@ class ModelInstance: @staticmethod def _get_load_balancing_manager( - configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict + configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any] ) -> Optional["LBModelManager"]: """ Get load balancing model credentials @@ -115,7 +119,7 @@ class ModelInstance: def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True] = True, @@ -126,7 +130,7 @@ class ModelInstance: def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False] = False, @@ -137,7 +141,7 @@ class ModelInstance: def invoke_llm( self, prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, @@ -147,7 +151,7 @@ class ModelInstance: def invoke_llm( self, prompt_messages: Sequence[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: Sequence[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, @@ -169,7 +173,7 @@ class ModelInstance: return cast( Union[LLMResult, Generator], self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -194,7 +198,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, LargeLanguageModel): raise Exception("Model type instance is not LargeLanguageModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, prompt_messages=list(prompt_messages), @@ -214,7 +218,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, texts=texts, @@ -236,7 +240,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, @@ -253,7 +257,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TextEmbeddingModel): raise Exception("Model type instance is not TextEmbeddingModel") return self._round_robin_invoke( - function=self.model_type_instance.get_num_tokens, + self.model_type_instance.get_num_tokens, model=self.model_name, credentials=self.credentials, texts=texts, @@ -278,7 +282,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") 
return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, query=query, @@ -306,7 +310,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, RerankModel): raise Exception("Model type instance is not RerankModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke_multimodal_rerank, + self.model_type_instance.invoke_multimodal_rerank, model=self.model_name, credentials=self.credentials, query=query, @@ -325,7 +329,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, ModerationModel): raise Exception("Model type instance is not ModerationModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, text=text, @@ -341,7 +345,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, Speech2TextModel): raise Exception("Model type instance is not Speech2TextModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, file=file, @@ -358,14 +362,14 @@ class ModelInstance: if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") return self._round_robin_invoke( - function=self.model_type_instance.invoke, + self.model_type_instance.invoke, model=self.model_name, credentials=self.credentials, content_text=content_text, voice=voice, ) - def _round_robin_invoke[**P, R](self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + def _round_robin_invoke(self, function: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: """ Round-robin invoke :param function: function to invoke @@ -433,8 +437,30 @@ class ModelInstance: class ModelManager: - def __init__(self, provider_manager: ProviderManager): + """Resolves :class:`ModelInstance` objects for a tenant and provider. + + When ``enable_credentials_cache`` is ``True``, resolved credentials for each + ``(tenant_id, provider, model_type, model)`` are stored in + ``_credentials_cache`` and reused. That can return **stale** credentials after + API keys or provider settings change, so a manager constructed with + ``enable_credentials_cache=True`` should not be kept for the lifetime of a + process or shared across unrelated work. Prefer a new manager per request, + workflow run, or similar bounded scope. + + The default is ``enable_credentials_cache=False``; in that mode the internal + credential cache is not populated, and each ``get_model_instance`` call + loads credentials from the current provider configuration. 
+ """ + + def __init__( + self, + provider_manager: ProviderManager, + *, + enable_credentials_cache: bool = False, + ) -> None: self._provider_manager = provider_manager + self._credentials_cache: dict[tuple[str, str, str, str], Any] = {} + self._enable_credentials_cache = enable_credentials_cache @classmethod def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": @@ -462,8 +488,19 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - model_instance = ModelInstance(provider_model_bundle, model) - return model_instance + cred_cache_key = (tenant_id, provider, model_type.value, model) + + if cred_cache_key in self._credentials_cache: + return ModelInstance( + provider_model_bundle, + model, + deepcopy(self._credentials_cache[cred_cache_key]), + ) + + ret = ModelInstance(provider_model_bundle, model) + if self._enable_credentials_cache: + self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials) + return ret def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ @@ -528,7 +565,7 @@ class LBModelManager: model_type: ModelType, model: str, load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: dict | None = None, + managed_credentials: dict[str, Any] | None = None, ): """ Load balancing model manager diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 2d72b17a04..28165592fc 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field from sqlalchemy import select @@ -10,7 +12,7 @@ from models.api_based_extension import APIBasedExtension class ModerationInputParams(BaseModel): app_id: str = "" - inputs: dict = Field(default_factory=dict) + inputs: dict[str, Any] = Field(default_factory=dict) query: str = "" @@ -23,7 +25,7 @@ class ApiModeration(Moderation): name: str = "api" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -41,7 +43,7 @@ class ApiModeration(Moderation): if not extension: raise ValueError("API-based Extension not found. 
Please check it again.") - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" if self.config is None: @@ -73,7 +75,7 @@ class ApiModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) - def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict): + def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]): if self.config is None: raise ValueError("The config is not set.") extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 31dd0d5568..e090ee89ad 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from enum import StrEnum, auto +from typing import Any from pydantic import BaseModel, Field @@ -15,7 +16,7 @@ class ModerationInputsResult(BaseModel): flagged: bool = False action: ModerationAction preset_response: str = "" - inputs: dict = Field(default_factory=dict) + inputs: dict[str, Any] = Field(default_factory=dict) query: str = "" @@ -33,13 +34,13 @@ class Moderation(Extensible, ABC): module: ExtensionModule = ExtensionModule.MODERATION - def __init__(self, app_id: str, tenant_id: str, config: dict | None = None): + def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None): super().__init__(tenant_id, config) self.app_id = app_id @classmethod @abstractmethod - def validate_config(cls, tenant_id: str, config: dict) -> None: + def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None: """ Validate the incoming form config data. @@ -50,7 +51,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @abstractmethod - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: """ Moderation for inputs. 
After the user inputs, this method will be called to perform sensitive content review @@ -75,7 +76,7 @@ class Moderation(Extensible, ABC): raise NotImplementedError @classmethod - def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool): + def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool): # inputs_config inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): diff --git a/api/core/moderation/factory.py b/api/core/moderation/factory.py index c2c8be6d6d..c22306ac94 100644 --- a/api/core/moderation/factory.py +++ b/api/core/moderation/factory.py @@ -1,3 +1,5 @@ +from typing import Any + from core.extension.extensible import ExtensionModule from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult from extensions.ext_code_based_extension import code_based_extension @@ -6,12 +8,12 @@ from extensions.ext_code_based_extension import code_based_extension class ModerationFactory: __extension_instance: Moderation - def __init__(self, name: str, app_id: str, tenant_id: str, config: dict): + def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]): extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) self.__extension_instance = extension_class(app_id, tenant_id, config) @classmethod - def validate_config(cls, name: str, tenant_id: str, config: dict): + def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -24,7 +26,7 @@ class ModerationFactory: # FIXME: mypy error, try to fix it instead of using type: ignore extension_class.validate_config(tenant_id, config) # type: ignore - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: """ Moderation for inputs. After the user inputs, this method will be called to perform sensitive content review diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 8d8d153743..7d80d3a53c 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -8,7 +8,7 @@ class KeywordsModeration(Moderation): name: str = "keywords" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. 
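The Moderation contract above (a class-level validate_config plus an instance moderation_for_inputs) is what code-based extensions implement. A hypothetical, stripped-down extension showing that shape; it returns a bare bool instead of ModerationInputsResult purely to stay self-contained:

from typing import Any

class KeywordCheck:
    def __init__(self, config: dict[str, Any]):
        self.keywords = [str(k).lower() for k in config.get("keywords", [])]

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
        if not config.get("keywords"):
            raise ValueError("keywords must be a non-empty list")

    def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> bool:
        # Real implementations return a ModerationInputsResult with an action
        # and preset response; a bool keeps this sketch dependency-free.
        text = " ".join(str(v) for v in inputs.values()) + " " + query
        return any(k in text.lower() for k in self.keywords)

KeywordCheck.validate_config("tenant-1", {"keywords": ["blocked"]})
checker = KeywordCheck({"keywords": ["blocked"]})
print(checker.moderation_for_inputs({"q": "this mentions a blocked word"}))  # True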
@@ -28,7 +28,7 @@ class KeywordsModeration(Moderation): if len(keywords_row_len) > 100: raise ValueError("the number of rows for the keywords must be less than 100") - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" if self.config is None: @@ -66,7 +66,7 @@ class KeywordsModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) - def _is_violated(self, inputs: dict, keywords_list: list) -> bool: + def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool: return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values()) def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool: diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index dd038c77f1..6e6e94502c 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,14 +1,15 @@ -from graphon.model_runtime.entities.model_entities import ModelType +from typing import Any from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult +from graphon.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): name: str = "openai_moderation" @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -18,7 +19,7 @@ class OpenAIModeration(Moderation): """ cls._validate_inputs_and_outputs_config(config, True) - def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: + def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" if self.config is None: @@ -49,7 +50,7 @@ class OpenAIModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) - def _is_violated(self, inputs: dict): + def _is_violated(self, inputs: dict[str, Any]): text = "\n".join(str(inputs.values())) model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) model_instance = model_manager.get_model_instance( diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index fda00ac3b9..d78ce90aa1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -1,8 +1,8 @@ from enum import StrEnum -from pydantic import BaseModel, ValidationInfo, field_validator +from pydantic import BaseModel -from core.ops.utils import validate_integer_id, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import validate_project_name, validate_url class TracingProviderEnum(StrEnum): @@ -52,220 +52,5 @@ class BaseTracingConfig(BaseModel): return validate_project_name(v, default_name) -class ArizeConfig(BaseTracingConfig): - """ - Model class for Arize tracing config. 
- """ - - api_key: str | None = None - space_id: str | None = None - project: str | None = None - endpoint: str = "https://otlp.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://otlp.arize.com") - - -class PhoenixConfig(BaseTracingConfig): - """ - Model class for Phoenix tracing config. - """ - - api_key: str | None = None - project: str | None = None - endpoint: str = "https://app.phoenix.arize.com" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "default") - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://app.phoenix.arize.com") - - -class LangfuseConfig(BaseTracingConfig): - """ - Model class for Langfuse tracing config. - """ - - public_key: str - secret_key: str - host: str = "https://api.langfuse.com" - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://api.langfuse.com") - - -class LangSmithConfig(BaseTracingConfig): - """ - Model class for Langsmith tracing config. - """ - - api_key: str - project: str - endpoint: str = "https://api.smith.langchain.com" - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # LangSmith only allows HTTPS - return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) - - -class OpikConfig(BaseTracingConfig): - """ - Model class for Opik tracing config. - """ - - api_key: str | None = None - project: str | None = None - workspace: str | None = None - url: str = "https://www.comet.com/opik/api/" - - @field_validator("project") - @classmethod - def project_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "Default Project") - - @field_validator("url") - @classmethod - def url_validator(cls, v, info: ValidationInfo): - return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") - - -class WeaveConfig(BaseTracingConfig): - """ - Model class for Weave tracing config. - """ - - api_key: str - entity: str | None = None - project: str - endpoint: str = "https://trace.wandb.ai" - host: str | None = None - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # Weave only allows HTTPS for endpoint - return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) - - @field_validator("host") - @classmethod - def host_validator(cls, v, info: ValidationInfo): - if v is not None and v.strip() != "": - return validate_url(v, v, allowed_schemes=("https", "http")) - return v - - -class AliyunConfig(BaseTracingConfig): - """ - Model class for Aliyun tracing config. 
- """ - - app_name: str = "dify_app" - license_key: str - endpoint: str - - @field_validator("app_name") - @classmethod - def app_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - @field_validator("license_key") - @classmethod - def license_key_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("License key cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - # aliyun uses two URL formats, which may include a URL path - return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") - - -class TencentConfig(BaseTracingConfig): - """ - Tencent APM tracing config - """ - - token: str - endpoint: str - service_name: str - - @field_validator("token") - @classmethod - def token_validator(cls, v, info: ValidationInfo): - if not v or v.strip() == "": - raise ValueError("Token cannot be empty") - return v - - @field_validator("endpoint") - @classmethod - def endpoint_validator(cls, v, info: ValidationInfo): - return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com") - - @field_validator("service_name") - @classmethod - def service_name_validator(cls, v, info: ValidationInfo): - return cls.validate_project_field(v, "dify_app") - - -class MLflowConfig(BaseTracingConfig): - """ - Model class for MLflow tracing config. - """ - - tracking_uri: str = "http://localhost:5000" - experiment_id: str = "0" # Default experiment id in MLflow is 0 - username: str | None = None - password: str | None = None - - @field_validator("tracking_uri") - @classmethod - def tracking_uri_validator(cls, v, info: ValidationInfo): - if isinstance(v, str) and v.startswith("databricks"): - raise ValueError( - "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." - ) - return validate_url_with_path(v, "http://localhost:5000") - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - -class DatabricksConfig(BaseTracingConfig): - """ - Model class for Databricks (Databricks-managed MLflow) tracing config. 
- """ - - experiment_id: str - host: str - client_id: str | None = None - client_secret: str | None = None - personal_access_token: str | None = None - - @field_validator("experiment_id") - @classmethod - def experiment_id_validator(cls, v, info: ValidationInfo): - return validate_integer_id(v) - - OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index fd235faf80..e7ba6e502b 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -204,114 +204,117 @@ class TracingProviderConfigEntry(TypedDict): class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): def __getitem__(self, provider: str) -> TracingProviderConfigEntry: - match provider: - case TracingProviderEnum.LANGFUSE: - from core.ops.entities.config_entity import LangfuseConfig - from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + try: + match provider: + case TracingProviderEnum.LANGFUSE: + from dify_trace_langfuse.config import LangfuseConfig + from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace - return { - "config_class": LangfuseConfig, - "secret_keys": ["public_key", "secret_key"], - "other_keys": ["host", "project_key"], - "trace_instance": LangFuseDataTrace, - } + return { + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, + } - case TracingProviderEnum.LANGSMITH: - from core.ops.entities.config_entity import LangSmithConfig - from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace + case TracingProviderEnum.LANGSMITH: + from dify_trace_langsmith.config import LangSmithConfig + from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace - return { - "config_class": LangSmithConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": LangSmithDataTrace, - } + return { + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + } - case TracingProviderEnum.OPIK: - from core.ops.entities.config_entity import OpikConfig - from core.ops.opik_trace.opik_trace import OpikDataTrace + case TracingProviderEnum.OPIK: + from dify_trace_opik.config import OpikConfig + from dify_trace_opik.opik_trace import OpikDataTrace - return { - "config_class": OpikConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "url", "workspace"], - "trace_instance": OpikDataTrace, - } + return { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + } - case TracingProviderEnum.WEAVE: - from core.ops.entities.config_entity import WeaveConfig - from core.ops.weave_trace.weave_trace import WeaveDataTrace + case TracingProviderEnum.WEAVE: + from dify_trace_weave.config import WeaveConfig + from dify_trace_weave.weave_trace import WeaveDataTrace - return { - "config_class": WeaveConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint", "host"], - "trace_instance": WeaveDataTrace, - } - case TracingProviderEnum.ARIZE: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import ArizeConfig + return { + "config_class": WeaveConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "entity", "endpoint", 
"host"], + "trace_instance": WeaveDataTrace, + } + case TracingProviderEnum.ARIZE: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import ArizeConfig - return { - "config_class": ArizeConfig, - "secret_keys": ["api_key", "space_id"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.PHOENIX: - from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace - from core.ops.entities.config_entity import PhoenixConfig + return { + "config_class": ArizeConfig, + "secret_keys": ["api_key", "space_id"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.PHOENIX: + from dify_trace_arize_phoenix.arize_phoenix_trace import ArizePhoenixDataTrace + from dify_trace_arize_phoenix.config import PhoenixConfig - return { - "config_class": PhoenixConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": ArizePhoenixDataTrace, - } - case TracingProviderEnum.ALIYUN: - from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace - from core.ops.entities.config_entity import AliyunConfig + return { + "config_class": PhoenixConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": ArizePhoenixDataTrace, + } + case TracingProviderEnum.ALIYUN: + from dify_trace_aliyun.aliyun_trace import AliyunDataTrace + from dify_trace_aliyun.config import AliyunConfig - return { - "config_class": AliyunConfig, - "secret_keys": ["license_key"], - "other_keys": ["endpoint", "app_name"], - "trace_instance": AliyunDataTrace, - } - case TracingProviderEnum.MLFLOW: - from core.ops.entities.config_entity import MLflowConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": AliyunConfig, + "secret_keys": ["license_key"], + "other_keys": ["endpoint", "app_name"], + "trace_instance": AliyunDataTrace, + } + case TracingProviderEnum.MLFLOW: + from dify_trace_mlflow.config import MLflowConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": MLflowConfig, - "secret_keys": ["password"], - "other_keys": ["tracking_uri", "experiment_id", "username"], - "trace_instance": MLflowDataTrace, - } - case TracingProviderEnum.DATABRICKS: - from core.ops.entities.config_entity import DatabricksConfig - from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace + return { + "config_class": MLflowConfig, + "secret_keys": ["password"], + "other_keys": ["tracking_uri", "experiment_id", "username"], + "trace_instance": MLflowDataTrace, + } + case TracingProviderEnum.DATABRICKS: + from dify_trace_mlflow.config import DatabricksConfig + from dify_trace_mlflow.mlflow_trace import MLflowDataTrace - return { - "config_class": DatabricksConfig, - "secret_keys": ["personal_access_token", "client_secret"], - "other_keys": ["host", "client_id", "experiment_id"], - "trace_instance": MLflowDataTrace, - } + return { + "config_class": DatabricksConfig, + "secret_keys": ["personal_access_token", "client_secret"], + "other_keys": ["host", "client_id", "experiment_id"], + "trace_instance": MLflowDataTrace, + } - case TracingProviderEnum.TENCENT: - from core.ops.entities.config_entity import TencentConfig - from core.ops.tencent_trace.tencent_trace import TencentDataTrace + case TracingProviderEnum.TENCENT: + from dify_trace_tencent.config import TencentConfig + from 
dify_trace_tencent.tencent_trace import TencentDataTrace - return { - "config_class": TencentConfig, - "secret_keys": ["token"], - "other_keys": ["endpoint", "service_name"], - "trace_instance": TencentDataTrace, - } + return { + "config_class": TencentConfig, + "secret_keys": ["token"], + "other_keys": ["endpoint", "service_name"], + "trace_instance": TencentDataTrace, + } - case _: - raise KeyError(f"Unsupported tracing provider: {provider}") + case _: + raise KeyError(f"Unsupported tracing provider: {provider}") + except ImportError: + raise ImportError(f"Provider {provider} is not installed.") provider_config_map = OpsTraceProviderConfigMap() @@ -324,7 +327,7 @@ class OpsTraceManager: @classmethod def encrypt_tracing_config( - cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None + cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any], current_trace_config=None ): """ Encrypt tracing config. @@ -363,7 +366,7 @@ class OpsTraceManager: return encrypted_config.model_dump() @classmethod - def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict): + def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any]): """ Decrypt tracing config :param tenant_id: tenant id @@ -408,7 +411,7 @@ class OpsTraceManager: return dict(decrypted_config) @classmethod - def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict): + def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict[str, Any]): """ Decrypt tracing config :param tracing_provider: tracing provider @@ -581,7 +584,7 @@ class OpsTraceManager: return app_trace_config @staticmethod - def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str): + def check_trace_config_is_effective(tracing_config: dict[str, Any], tracing_provider: str): """ Check trace config is effective :param tracing_config: tracing config @@ -596,7 +599,7 @@ class OpsTraceManager: return trace_instance(config).api_check() @staticmethod - def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): + def get_trace_config_project_key(tracing_config: dict[str, Any], tracing_provider: str): """ get trace config is project key :param tracing_config: tracing config @@ -611,7 +614,7 @@ class OpsTraceManager: return trace_instance(config).get_project_key() @staticmethod - def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): + def get_trace_config_project_url(tracing_config: dict[str, Any], tracing_provider: str): """ get trace config is project key :param tracing_config: tracing config @@ -1322,8 +1325,8 @@ class TraceTask: error=error, ) - def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict: - node_data: dict = kwargs.get("node_execution_data", {}) + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict[str, Any]: + node_data: dict[str, Any] = kwargs.get("node_execution_data", {}) if not node_data: return {} @@ -1431,7 +1434,7 @@ class TraceTask: return node_trace return DraftNodeExecutionTrace(**node_trace.model_dump()) - def _extract_streaming_metrics(self, message_data) -> dict: + def _extract_streaming_metrics(self, message_data) -> dict[str, Any]: if not message_data.message_metadata: return {} diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index e2d2be92cb..c76cb865c3 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ 
b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Any, Union, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -207,7 +207,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): ) @classmethod - def _get_user(cls, user_id: str) -> Union[EndUser, Account]: + def _get_user(cls, user_id: str) -> EndUser | Account: """ get the user by user id """ diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index c715b9171c..c92438960a 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -1,20 +1,7 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator - -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelType +from typing import Any from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -32,6 +19,19 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Tenant @@ -226,7 +226,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): # invoke model response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) - def handle() -> Generator[dict, None, None]: + def handle() -> Generator[dict[str, Any], None, None]: for chunk in response: yield {"result": hexlify(chunk).decode("utf-8")} diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 9478997494..9550e49992 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,3 +1,4 @@ +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from graphon.enums import BuiltinNodeTypes from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig from graphon.nodes.parameter_extractor.entities import ( @@ -8,8 +9,6 @@ from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) - -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from services.workflow_service import WorkflowService diff --git a/api/core/plugin/entities/endpoint.py b/api/core/plugin/entities/endpoint.py index e5bca140f8..6419963668 100644 --- a/api/core/plugin/entities/endpoint.py +++ b/api/core/plugin/entities/endpoint.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any from pydantic import BaseModel, Field, model_validator @@ -31,7 +32,7 @@ class EndpointEntity(BasePluginEntity): entity of an 
endpoint """ - settings: dict + settings: dict[str, Any] tenant_id: str plugin_id: str expired_at: datetime diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 2177e8af90..03398873e3 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,12 @@ -from graphon.model_runtime.entities.provider_entities import ProviderEntity +from typing import Any + from pydantic import BaseModel, Field, computed_field, model_validator from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): @@ -40,7 +42,7 @@ class MarketplacePluginDeclaration(BaseModel): @model_validator(mode="before") @classmethod - def transform_declaration(cls, data: dict): + def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]: if "endpoint" in data and not data["endpoint"]: del data["endpoint"] if "model" in data and not data["model"]: diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index b095b4998d..89e0e8881c 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,7 +3,6 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any -from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -14,6 +13,7 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): @@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel): @model_validator(mode="before") @classmethod - def validate_category(cls, values: dict): + def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]: # auto detect category if values.get("tool"): values["category"] = PluginCategory.Tool diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index b57180690e..257638ad77 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,8 +6,6 @@ from datetime import datetime from enum import StrEnum from typing import Any -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin @@ -18,6 +16,8 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity class PluginDaemonBasicResponse[T: BaseModel | dict | list | 
bool | str](BaseModel): @@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel): """ result: bool - credentials: dict | None = None + credentials: dict[str, Any] | None = None class PluginModelSchemaEntity(BaseModel): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 059f3fa9be..1474883204 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -4,6 +4,10 @@ from collections.abc import Mapping from typing import Any, Literal from flask import Response +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from core.entities.provider_entities import BasicProviderConfig +from core.plugin.utils.http_parser import deserialize_response from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -21,10 +25,6 @@ from graphon.nodes.parameter_extractor.entities import ( from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from core.entities.provider_entities import BasicProviderConfig -from core.plugin.utils.http_parser import deserialize_response class InvokeCredentials(BaseModel): @@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel): tool_type: Literal["builtin", "workflow", "api", "mcp"] provider: str tool: str - tool_parameters: dict + tool_parameters: dict[str, Any] credential_id: str | None = None @@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel): opt: Literal["encrypt", "decrypt", "clear"] namespace: Literal["endpoint"] identity: str - data: dict = Field(default_factory=dict) + data: dict[str, Any] = Field(default_factory=dict) config: list[BasicProviderConfig] = Field(default_factory=list) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7f36560b49..9ee8469892 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -5,14 +5,6 @@ from collections.abc import Callable, Generator from typing import Any, cast import httpx -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from yarl import URL @@ -37,6 +29,14 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index ce1ef71494..56c08addba 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient): Fetch datasource providers for the given tenant. 
""" - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]) -> dict[str, Any]: if json_response.get("data"): for provider in json_response.get("data", []): declaration = provider.get("declaration", {}) or {} @@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient): Fetch datasource providers for the given tenant. """ - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]) -> dict[str, Any]: if json_response.get("data"): for provider in json_response.get("data", []): declaration = provider.get("declaration", {}) or {} @@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient): tool_provider_id = DatasourceProviderID(provider_id) - def transformer(json_response: dict[str, Any]) -> dict: + def transformer(json_response: dict[str, Any]) -> dict[str, Any]: data = json_response.get("data") if data: for datasource in data.get("declaration", {}).get("datasources", []): diff --git a/api/core/plugin/impl/endpoint.py b/api/core/plugin/impl/endpoint.py index 2db5185a2c..b335b42763 100644 --- a/api/core/plugin/impl/endpoint.py +++ b/api/core/plugin/impl/endpoint.py @@ -1,3 +1,5 @@ +from typing import Any + from core.plugin.entities.endpoint import EndpointEntityWithInstance from core.plugin.impl.base import BasePluginClient from core.plugin.impl.exc import PluginDaemonInternalServerError @@ -5,7 +7,12 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError class PluginEndpointClient(BasePluginClient): def create_endpoint( - self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict + self, + tenant_id: str, + user_id: str, + plugin_unique_identifier: str, + name: str, + settings: dict[str, Any], ) -> bool: """ Create an endpoint for the given plugin. @@ -49,7 +56,9 @@ class PluginEndpointClient(BasePluginClient): params={"plugin_id": plugin_id, "page": page, "page_size": page_size}, ) - def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict): + def update_endpoint( + self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any] + ) -> bool: """ Update the settings of the given endpoint. 
""" diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 1e38c24717..47608bdfa6 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -2,13 +2,6 @@ import binascii from collections.abc import Generator, Sequence from typing import IO, Any -from graphon.model_runtime.entities.llm_entities import LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult -from graphon.model_runtime.utils.encoders import jsonable_encoder - from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, @@ -20,6 +13,12 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): @@ -50,7 +49,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], ) -> AIModelEntity | None: """ Get model schema @@ -80,7 +79,7 @@ class PluginModelClient(BasePluginClient): return None def validate_provider_credentials( - self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict + self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any] ) -> bool: """ validate the credentials of the provider @@ -118,7 +117,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], ) -> bool: """ validate the credentials of the provider @@ -157,9 +156,9 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, + model_parameters: dict[str, Any] | None = None, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, @@ -206,7 +205,7 @@ class PluginModelClient(BasePluginClient): provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], prompt_messages: list[PromptMessage], tools: list[PromptMessageTool] | None = None, ) -> int: @@ -248,7 +247,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], texts: list[str], input_type: str, ) -> EmbeddingResult: @@ -290,7 +289,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], documents: list[dict], input_type: str, ) -> EmbeddingResult: @@ -332,7 +331,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: 
dict[str, Any], texts: list[str], ) -> list[int]: """ @@ -372,7 +371,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], query: str, docs: list[str], score_threshold: float | None = None, @@ -418,7 +417,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], query: MultimodalRerankInput, docs: list[MultimodalRerankInput], score_threshold: float | None = None, @@ -463,7 +462,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], content_text: str, voice: str, ) -> Generator[bytes, None, None]: @@ -508,7 +507,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], language: str | None = None, ): """ @@ -552,7 +551,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], file: IO[bytes], ) -> str: """ @@ -592,7 +591,7 @@ class PluginModelClient(BasePluginClient): plugin_id: str, provider: str, model: str, - credentials: dict, + credentials: dict[str, Any], text: str, ) -> bool: """ diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 22c846b6de..4e66d58b5e 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -6,13 +6,6 @@ from collections.abc import Generator, Iterable, Sequence from threading import Lock from typing import IO, Any, Union -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.runtime import ModelRuntime from pydantic import ValidationError from redis import RedisError @@ -21,6 +14,13 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.runtime import ModelRuntime from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) @@ -66,15 +66,15 @@ class PluginModelRuntime(ModelRuntime): if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") file_name = ( - provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + 
provider_schema.icon_small.zh_hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_us ) elif icon_type.lower() == "icon_small_dark": if not provider_schema.icon_small_dark: raise ValueError(f"Provider {provider} does not have small dark icon.") file_name = ( - provider_schema.icon_small_dark.zh_Hans + provider_schema.icon_small_dark.zh_hans if lang.lower() == "zh_hans" - else provider_schema.icon_small_dark.en_US + else provider_schema.icon_small_dark.en_us ) else: raise ValueError(f"Unsupported icon type: {icon_type}.") diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 4b29a6fc56..35abd2ae8c 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -2,9 +2,8 @@ from __future__ import annotations from typing import TYPE_CHECKING -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory - from core.plugin.impl.model import PluginModelClient +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory if TYPE_CHECKING: from core.model_manager import ModelManager diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index c75c30a98a..8a7175bb51 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import Any from requests import HTTPError @@ -263,7 +264,7 @@ class PluginInstaller(BasePluginClient): original_plugin_unique_identifier: str, new_plugin_unique_identifier: str, source: PluginInstallationSource, - meta: dict, + meta: dict[str, Any], ) -> PluginInstallTaskStartResponse: """ Upgrade a plugin. diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 90350f8400..12d8e282b2 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,8 +1,7 @@ from typing import Any -from graphon.file import File - from core.tools.entities.tool_entities import ToolSelector +from graphon.file import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/plugin/utils/http_parser.py b/api/core/plugin/utils/http_parser.py index ce943929be..af0ff10bfb 100644 --- a/api/core/plugin/utils/http_parser.py +++ b/api/core/plugin/utils/http_parser.py @@ -151,6 +151,12 @@ def deserialize_response(raw_data: bytes) -> Response: response = Response(response=body, status=status_code) + # Replace Flask's default headers (e.g. Content-Type, Content-Length) with the + # parsed ones so we faithfully reproduce the original response. Use Headers.add + # rather than dict-style assignment so that repeated headers such as Set-Cookie + # (and any other multi-valued header per RFC 9110) are preserved instead of + # being overwritten. 
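# --- Aside (illustrative, not part of the patch): add() vs dict-style assignment ---
# flask.Response.headers is a werkzeug Headers instance. Dict-style assignment
# drops every existing value for the key before setting the new one, while
# add() appends. A runnable demonstration with made-up cookie values:

from werkzeug.datastructures import Headers

h = Headers()
h["Set-Cookie"] = "a=1"
h["Set-Cookie"] = "b=2"      # assignment overwrites: only the last cookie survives
assert h.getlist("Set-Cookie") == ["b=2"]

h = Headers()
h.add("Set-Cookie", "a=1")
h.add("Set-Cookie", "b=2")   # add() keeps both, as RFC 9110 allows for repeated headers
assert h.getlist("Set-Cookie") == ["a=1", "b=2"]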
+ response.headers.clear() for line in lines[1:]: if not line: continue @@ -158,6 +164,6 @@ def deserialize_response(raw_data: bytes) -> Response: if ":" not in line_str: continue name, value = line_str.split(":", 1) - response.headers[name] = value.strip() + response.headers.add(name, value.strip()) return response diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 19b5e9223a..24e05ef865 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,13 @@ from collections.abc import Mapping, Sequence from typing import cast +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, file_manager from graphon.model_runtime.entities import ( AssistantPromptMessage, @@ -13,14 +20,6 @@ from graphon.model_runtime.entities import ( from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from graphon.runtime import VariablePool -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser - class AdvancedPromptTransform(PromptTransform): """ diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 9be70199b7..7c6280fe93 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,17 +1,16 @@ from typing import cast -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.prompt_transform import PromptTransform +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 4539ae9f11..6ff2f44cdc 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,12 +1,11 @@ from typing import Any -from graphon.model_runtime.entities.message_entities import PromptMessage -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey - from 
core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index c706353ffe..1665bdeb52 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -2,8 +2,14 @@ import json import os from collections.abc import Mapping, Sequence from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import file_manager from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, @@ -13,13 +19,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: @@ -34,6 +33,13 @@ class ModelMode(StrEnum): prompt_file_contents: dict[str, Any] = {} +class PromptTemplateConfigDict(TypedDict): + prompt_template: PromptTemplateParser + custom_variable_keys: list[str] + special_variable_keys: list[str] + prompt_rules: dict[str, Any] + + class SimplePromptTransform(PromptTransform): """ Simple Prompt Transform for Chatbot App Basic Mode. 
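# --- Aside (illustrative, not part of the patch): what the TypedDict buys ---
# Returning PromptTemplateConfigDict instead of dict[str, object] gives every
# key its own static type, so callers can drop casts; the isinstance checks in
# the next hunk remain only as runtime guards. A self-contained sketch using a
# simplified stand-in for the real TypedDict:

from typing import Any, TypedDict


class PromptTemplateConfig(TypedDict):
    custom_variable_keys: list[str]
    special_variable_keys: list[str]
    prompt_rules: dict[str, Any]


def get_prompt_template_config() -> PromptTemplateConfig:
    return {
        "custom_variable_keys": ["name"],
        "special_variable_keys": ["#context#", "#query#"],
        "prompt_rules": {"query_prompt": "{{#query#}}"},
    }


config = get_prompt_template_config()
keys: list[str] = config["custom_variable_keys"]  # typed as list[str], no cast needed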
@@ -89,11 +95,11 @@ class SimplePromptTransform(PromptTransform): app_mode: AppMode, model_config: ModelConfigWithCredentialsEntity, pre_prompt: str, - inputs: dict, + inputs: dict[str, Any], query: str | None = None, context: str | None = None, histories: str | None = None, - ) -> tuple[str, dict]: + ) -> tuple[str, dict[str, Any]]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -105,18 +111,13 @@ class SimplePromptTransform(PromptTransform): with_memory_prompt=histories is not None, ) - custom_variable_keys_obj = prompt_template_config["custom_variable_keys"] - special_variable_keys_obj = prompt_template_config["special_variable_keys"] + custom_variable_keys = prompt_template_config["custom_variable_keys"] + if not isinstance(custom_variable_keys, list): + raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys)}") - # Type check for custom_variable_keys - if not isinstance(custom_variable_keys_obj, list): - raise TypeError(f"Expected list for custom_variable_keys, got {type(custom_variable_keys_obj)}") - custom_variable_keys = cast(list[str], custom_variable_keys_obj) - - # Type check for special_variable_keys - if not isinstance(special_variable_keys_obj, list): - raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys_obj)}") - special_variable_keys = cast(list[str], special_variable_keys_obj) + special_variable_keys = prompt_template_config["special_variable_keys"] + if not isinstance(special_variable_keys, list): + raise TypeError(f"Expected list for special_variable_keys, got {type(special_variable_keys)}") variables = {k: inputs[k] for k in custom_variable_keys if k in inputs} @@ -150,7 +151,7 @@ class SimplePromptTransform(PromptTransform): has_context: bool, query_in_prompt: bool, with_memory_prompt: bool = False, - ) -> dict[str, object]: + ) -> PromptTemplateConfigDict: prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys: list[str] = [] @@ -173,18 +174,19 @@ class SimplePromptTransform(PromptTransform): prompt += prompt_rules.get("query_prompt", "{{#query#}}") special_variable_keys.append("#query#") - return { + result: PromptTemplateConfigDict = { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, "prompt_rules": prompt_rules, } + return result def _get_chat_model_prompt_messages( self, app_mode: AppMode, pre_prompt: str, - inputs: dict, + inputs: dict[str, Any], query: str, context: str | None, files: Sequence["File"], @@ -231,7 +233,7 @@ class SimplePromptTransform(PromptTransform): self, app_mode: AppMode, pre_prompt: str, - inputs: dict, + inputs: dict[str, Any], query: str, context: str | None, files: Sequence["File"], @@ -310,7 +312,7 @@ class SimplePromptTransform(PromptTransform): return prompt_message - def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str): + def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict[str, Any]: """ Get simple prompt rule. 
:param app_mode: app mode @@ -322,7 +324,7 @@ class SimplePromptTransform(PromptTransform): # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: - return cast(dict, prompt_file_contents[prompt_file_name]) + return cast(dict[str, Any], prompt_file_contents[prompt_file_name]) # Get the absolute path of the subdirectory prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") @@ -335,7 +337,7 @@ class SimplePromptTransform(PromptTransform): # Store the content of the prompt file prompt_file_contents[prompt_file_name] = content - return cast(dict, content) + return cast(dict[str, Any], content) def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index dbda749925..ba76eb0c4e 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from typing import Any, cast +from core.prompt.simple_prompt_transform import ModelMode from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -11,8 +12,6 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, ) -from core.prompt.simple_prompt_transform import ModelMode - class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index e3b3f83c20..b290ae456e 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -6,20 +6,12 @@ from collections.abc import Sequence from json import JSONDecodeError from typing import TYPE_CHECKING, Any -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( @@ -41,6 +33,14 @@ from core.helper.position_helper import is_filtered from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ( LoadBalancingModelConfig, Provider, @@ -70,12 +70,32 @@ class ProviderManager: Request-bound managers may carry caller identity in that runtime, and the resulting ``ProviderConfiguration`` objects must reuse it for downstream model-type and schema lookups. + + Configuration assembly is cached per manager instance so call chains that + share one request-scoped manager can reuse the same provider graph instead + of rebuilding it for every lookup. 
Call ``clear_configurations_cache()`` + when a long-lived manager needs to observe writes performed within the same + instance scope. """ + decoding_rsa_key: Any | None + decoding_cipher_rsa: Any | None + _model_runtime: ModelRuntime + _configurations_cache: dict[str, ProviderConfigurations] + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None self._model_runtime = model_runtime + self._configurations_cache = {} + + def clear_configurations_cache(self, tenant_id: str | None = None) -> None: + """Drop assembled provider configurations cached on this manager instance.""" + if tenant_id is None: + self._configurations_cache.clear() + return + + self._configurations_cache.pop(tenant_id, None) def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -114,6 +134,10 @@ class ProviderManager: :param tenant_id: :return: """ + cached_configurations = self._configurations_cache.get(tenant_id) + if cached_configurations is not None: + return cached_configurations + # Get all provider records of the workspace provider_name_to_provider_records_dict = self._get_all_providers(tenant_id) @@ -273,6 +297,8 @@ class ProviderManager: provider_configurations[str(provider_id_entity)] = provider_configuration + self._configurations_cache[tenant_id] = provider_configurations + # Return the encapsulated object return provider_configurations @@ -419,7 +445,7 @@ class ProviderManager: @staticmethod def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: provider_name_to_provider_records_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) providers = session.scalars(stmt) for provider in providers: @@ -436,7 +462,7 @@ class ProviderManager: :return: """ provider_name_to_provider_model_records_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) provider_models = session.scalars(stmt) for provider_model in provider_models: @@ -452,7 +478,7 @@ class ProviderManager: :return: """ provider_name_to_preferred_provider_type_records_dict = {} - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) preferred_provider_types = session.scalars(stmt) provider_name_to_preferred_provider_type_records_dict = { @@ -470,7 +496,7 @@ class ProviderManager: :return: """ provider_name_to_provider_model_settings_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) provider_model_settings = session.scalars(stmt) for provider_model_setting in provider_model_settings: @@ -488,7 +514,7 @@ class ProviderManager: :return: """ provider_name_to_provider_model_credentials_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id) provider_model_credentials = session.scalars(stmt) for 
provider_model_credential in provider_model_credentials: @@ -518,7 +544,7 @@ class ProviderManager: return {} provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) provider_load_balancing_configs = session.scalars(stmt) for provider_load_balancing_config in provider_load_balancing_configs: @@ -552,7 +578,7 @@ class ProviderManager: :param provider_name: provider name :return: """ - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = ( select(ProviderCredential) .where( @@ -582,7 +608,7 @@ class ProviderManager: :param model_type: model type :return: """ - with Session(db.engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: stmt = ( select(ProviderModelCredential) .where( @@ -856,7 +882,7 @@ class ProviderManager: secret_variables: list[str], cache_type: ProviderCredentialsCacheType, is_provider: bool = False, - ) -> dict: + ) -> dict[str, Any]: """Get and decrypt credentials with caching.""" credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 9ce91f52ff..ca530748ed 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,8 +1,5 @@ from typing import TypedDict -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError - from core.model_manager import ModelInstance, ModelManager from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType @@ -11,6 +8,8 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index ed264878d3..392af351b6 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -139,8 +139,10 @@ class Jieba(BaseKeyword): "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table - keyword_data_source_type = dataset_keyword_table.data_source_type + keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file" if keyword_data_source_type == "database": + if dataset_keyword_table is None: + return dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict) db.session.commit() else: @@ -154,7 +156,8 @@ class Jieba(BaseKeyword): if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return dict(keyword_table_dict["__data__"]["table"]) + data: Any = keyword_table_dict["__data__"] + return dict(data["table"]) else: keyword_data_source_type = 
dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 84f35c25f8..2af8238cc4 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -1,4 +1,5 @@ import re +from collections.abc import Callable from operator import itemgetter from typing import cast @@ -80,12 +81,14 @@ class JiebaKeywordTableHandler: def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs): # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable. - top_k = kwargs.pop("topK", top_k) + top_k = cast(int | None, kwargs.pop("topK", top_k)) + if top_k is None: + top_k = 20 cut = getattr(jieba, "cut", None) if self._lcut: tokens = self._lcut(sentence) elif callable(cut): - tokens = list(cut(sentence)) + tokens = list(cast(Callable[[str], list[str]], cut)(sentence)) else: tokens = re.findall(r"\w+", sentence) @@ -106,9 +109,9 @@ class JiebaKeywordTableHandler: """Extract keywords with JIEBA tfidf.""" keywords = self._tfidf.extract_tags( sentence=text, - topK=max_keywords_per_chunk, + topK=max_keywords_per_chunk or 10, ) - # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default. + # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default. keywords = cast(list[str], keywords) return set(self._expand_tokens_with_subtokens(set(keywords))) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index c1654ac130..b985ebbe1d 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired, TypedDict from flask import Flask, current_app -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only @@ -24,6 +23,7 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( ChildChunk, Dataset, @@ -158,7 +158,7 @@ class RetrievalService: ) if futures: - for future in concurrent.futures.as_completed(futures, timeout=3600): + for _ in concurrent.futures.as_completed(futures, timeout=3600): if exceptions: for f in futures: f.cancel() @@ -174,8 +174,8 @@ class RetrievalService: cls, dataset_id: str, query: str, - external_retrieval_model: dict | None = None, - metadata_filtering_conditions: dict | None = None, + external_retrieval_model: dict[str, Any] | None = None, + metadata_filtering_conditions: dict[str, Any] | None = None, ): stmt = select(Dataset).where(Dataset.id == dataset_id) dataset = db.session.scalar(stmt) @@ -195,15 +195,33 @@ class RetrievalService: ) return all_documents + @classmethod + def _filter_documents_by_vector_score_threshold( + cls, documents: list[Document], score_threshold: float | None + ) -> list[Document]: + """Keep documents whose stored retrieval score meets the threshold. 
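The filtering rule this docstring goes on to spell out can be pictured standalone as follows. `Doc` is a hypothetical stand-in for `core.rag.models.document.Document`, assuming only that documents carry an optional `metadata` mapping whose `score` key holds the stored retrieval score.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Doc:
    page_content: str
    metadata: dict[str, Any] | None = field(default_factory=dict)


def filter_by_score(documents: list[Doc], score_threshold: float | None) -> list[Doc]:
    # No threshold configured: keep everything, mirroring the early return.
    if score_threshold is None:
        return documents
    # A missing score defaults to 0, so unscored documents are dropped
    # whenever a positive threshold is set.
    return [d for d in documents if d.metadata and d.metadata.get("score", 0) >= score_threshold]


docs = [Doc("a", {"score": 0.9}), Doc("b", {"score": 0.2}), Doc("c", {})]
assert [d.page_content for d in filter_by_score(docs, 0.5)] == ["a"]
assert filter_by_score(docs, None) == docs
```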
+ + Used when hybrid search skips early vector thresholding but no rerank + runner applies a threshold afterward (same rule as ``calculate_vector_score``). + """ + if score_threshold is None: + return documents + return [ + document + for document in documents + if document.metadata and document.metadata.get("score", 0) >= score_threshold + ] + @classmethod def _deduplicate_documents(cls, documents: list[Document]) -> list[Document]: """Deduplicate documents in O(n) while preserving first-seen order. Rules: - - For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest - metadata["score"] among duplicates; if a later duplicate has no score, ignore it. - - For non-dify documents (or dify without doc_id): deduplicate by content key - (provider, page_content), keeping the first occurrence. + - If metadata["doc_id"] exists (any provider): deduplicate by (provider, doc_id) key; + keep the doc with the highest metadata["score"] among duplicates. If a later duplicate + has no score, ignore it. + - If metadata["doc_id"] is absent: deduplicate by content key (provider, page_content), + keeping the first occurrence. """ if not documents: return documents @@ -214,11 +232,10 @@ class RetrievalService: order: list[tuple] = [] for doc in documents: - is_dify = doc.provider == "dify" - doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None + doc_id = (doc.metadata or {}).get("doc_id") - if is_dify and doc_id: - key = ("dify", doc_id) + if doc_id: + key = (doc.provider or "dify", doc_id) if key not in chosen: chosen[key] = doc order.append(key) @@ -294,13 +311,20 @@ class RetrievalService: vector = Vector(dataset=dataset) documents = [] + # Hybrid search merges keyword / full-text / vector hits and then reranks + # (weighted fusion or reranking model). Applying the user score threshold at + # vector retrieval time uses embedding similarity, which is not comparable to + # reranked or fused scores and incorrectly drops high-quality chunks (#35233). 
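The comment above compresses the #35233 fix into a single conditional; a minimal, self-contained restatement follows, with a hypothetical two-member enum standing in for the real `core.rag.retrieval.retrieval_methods.RetrievalMethod`.

```python
from enum import Enum


class RetrievalMethod(str, Enum):  # hypothetical stand-in
    SEMANTIC_SEARCH = "semantic_search"
    HYBRID_SEARCH = "hybrid_search"


def embedding_threshold(method: RetrievalMethod, score_threshold: float | None) -> float | None:
    # Hybrid search defers thresholding until after fusion/rerank, because raw
    # embedding similarity is not comparable to fused or reranked scores.
    return 0.0 if method == RetrievalMethod.HYBRID_SEARCH else score_threshold


assert embedding_threshold(RetrievalMethod.HYBRID_SEARCH, 0.7) == 0.0
assert embedding_threshold(RetrievalMethod.SEMANTIC_SEARCH, 0.7) == 0.7
```

The later `_filter_documents_by_vector_score_threshold` call then re-applies the user threshold only when no rerank runner would otherwise enforce one, as shown further down in this hunk.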
+ embedding_score_threshold = ( + 0.0 if retrieval_method == RetrievalMethod.HYBRID_SEARCH else score_threshold + ) if query_type == QueryType.TEXT_QUERY: documents.extend( vector.search_by_vector( query, search_type="similarity_score_threshold", top_k=top_k, - score_threshold=score_threshold, + score_threshold=embedding_score_threshold, filter={"group_id": [dataset.id]}, document_ids_filter=document_ids_filter, ) @@ -312,7 +336,7 @@ class RetrievalService: vector.search_by_file( file_id=query, top_k=top_k, - score_threshold=score_threshold, + score_threshold=embedding_score_threshold, filter={"group_id": [dataset.id]}, document_ids_filter=document_ids_filter, ) @@ -527,6 +551,7 @@ class RetrievalService: child_index_nodes = session.execute(child_chunk_stmt).scalars().all() for i in child_index_nodes: + assert i.index_node_id segment_ids.append(i.segment_id) if i.segment_id in child_chunk_map: child_chunk_map[i.segment_id].append(i) @@ -844,6 +869,10 @@ class RetrievalService: top_n=top_k, query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, ) + if not data_post_processor.rerank_runner and score_threshold: + all_documents_item = self._filter_documents_by_vector_score_threshold( + all_documents_item, score_threshold + ) all_documents.extend(all_documents_item) diff --git a/api/core/rag/datasource/vdb/vector_backend_registry.py b/api/core/rag/datasource/vdb/vector_backend_registry.py new file mode 100644 index 0000000000..15f4357caf --- /dev/null +++ b/api/core/rag/datasource/vdb/vector_backend_registry.py @@ -0,0 +1,87 @@ +"""Vector store backend discovery. + +Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package +declares third-party dependencies and registers ``importlib`` entry points in group +``dify.vector_backends`` (see each package's ``pyproject.toml``). + +Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol +remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``). + +Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a +distribution; entry points take precedence when both exist. + +After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``. +""" + +from __future__ import annotations + +import importlib +import logging +from importlib.metadata import entry_points +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory + +logger = logging.getLogger(__name__) + +_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {} + +# module_path:class_name — optional fallback when no distribution registers the backend. 
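For the `dify.vector_backends` group the registry module above describes, discovery can be exercised directly with `importlib.metadata`. Only the group name comes from the module; the package and class names below are hypothetical.

```python
from importlib.metadata import entry_points

# A backend package would register itself in its pyproject.toml roughly as
# (hypothetical names):
#
#   [project.entry-points."dify.vector_backends"]
#   pgvector = "dify_vdb_pgvector.factory:PGVectorFactory"


def resolve_backend(vector_type: str):
    # select() narrows installed entry points to one group (Python 3.10+).
    for ep in entry_points().select(group="dify.vector_backends"):
        if ep.name == vector_type:
            return ep.load()  # imports the module, returns the factory class
    return None


installed = sorted(ep.name for ep in entry_points().select(group="dify.vector_backends"))
print(installed or "no vector backends installed in this environment")
```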
+_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {} + + +def clear_vector_factory_cache() -> None: + """Drop lazily loaded factories (for tests or plugin reload).""" + _VECTOR_FACTORY_CACHE.clear() + + +def _vector_backend_entry_points(): + return entry_points().select(group="dify.vector_backends") + + +def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None: + for ep in _vector_backend_entry_points(): + if ep.name != vector_type: + continue + try: + loaded = ep.load() + except Exception: + logger.exception("Failed to load vector backend entry point %s", ep.name) + raise + return loaded # type: ignore[return-value] + return None + + +def _unsupported(vector_type: str) -> ValueError: + installed = sorted(ep.name for ep in _vector_backend_entry_points()) + available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed." + return ValueError( + f"Vector store {vector_type!r} is not supported.{available_msg} " + "Install a plugin (uv sync --group vdb-all, or vdb- per api/pyproject.toml), " + "or register a dify.vector_backends entry point." + ) + + +def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]: + target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type) + if not target: + raise _unsupported(vector_type) + module_path, _, attr = target.partition(":") + module = importlib.import_module(module_path) + return getattr(module, attr) # type: ignore[no-any-return] + + +def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]: + """Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value.""" + if vector_type in _VECTOR_FACTORY_CACHE: + return _VECTOR_FACTORY_CACHE[vector_type] + + plugin_cls = _load_plugin_factory(vector_type) + if plugin_cls is not None: + _VECTOR_FACTORY_CACHE[vector_type] = plugin_cls + return plugin_cls + + cls = _load_builtin_factory(vector_type) + _VECTOR_FACTORY_CACHE[vector_type] = cls + return cls diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 0ef88e1010..1f82f7a081 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,11 +4,11 @@ import time from abc import ABC, abstractmethod from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from configs import dify_config from core.model_manager import ModelManager +from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding @@ -18,6 +18,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, Whitelist from models.model import UploadFile @@ -38,12 +39,84 @@ class AbstractVectorFactory(ABC): return index_struct_dict +class _LazyEmbeddings(Embeddings): + """Lazy proxy that defers materializing the real embedding model. + + Constructing the real embeddings (via ``ModelManager.get_model_instance``) + transitively calls ``FeatureService.get_features`` → ``BillingService`` + HTTP GETs (see ``provider_manager.py``). 
Cleanup paths + (``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings + at all, so deferring this until an ``embed_*`` method is actually invoked + keeps cleanup tasks resilient to transient billing-API failures and avoids + leaving stranded ``document_segments`` / ``child_chunks`` whenever billing + hiccups. + + Existing callers that perform create / search operations are unaffected: + the first ``embed_*`` call materializes the underlying model and the + behavior is identical from that point on. + """ + + def __init__(self, dataset: Dataset): + self._dataset = dataset + self._real: Embeddings | None = None + + def _ensure(self) -> Embeddings: + if self._real is None: + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) + embedding_model = model_manager.get_model_instance( + tenant_id=self._dataset.tenant_id, + provider=self._dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=self._dataset.embedding_model, + ) + self._real = CacheEmbedding(embedding_model) + return self._real + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return self._ensure().embed_documents(texts) + + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: + return self._ensure().embed_multimodal_documents(multimodel_documents) + + def embed_query(self, text: str) -> list[float]: + return self._ensure().embed_query(text) + + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: + return self._ensure().embed_multimodal_query(multimodel_document) + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + return await self._ensure().aembed_documents(texts) + + async def aembed_query(self, text: str) -> list[float]: + return await self._ensure().aembed_query(text) + + class Vector: def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: - attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + # `is_summary` and `original_chunk_id` are stored on summary vectors + # by `SummaryIndexService` and read back by `RetrievalService` to + # route summary hits through their original parent chunks. They + # must be listed here so vector backends that use this list as an + # explicit return-properties projection (notably Weaviate) actually + # return those fields; without them, summary hits silently + # collapse into `is_summary = False` branches and the summary + # retrieval path is a no-op. See #34884. + attributes = [ + "doc_id", + "dataset_id", + "document_id", + "doc_hash", + "doc_type", + "is_summary", + "original_chunk_id", + ] self._dataset = dataset - self._embeddings = self._get_embeddings() + # Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists) + # never transitively trigger billing API calls during ``Vector(dataset)`` + # construction. The real embedding model is materialized only when an + # ``embed_*`` method is actually invoked (i.e. create / search paths). 
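Stripped of the embedding specifics, the deferral mechanics of `_LazyEmbeddings` reduce to the proxy pattern below: construction is free, and the expensive dependency materializes exactly once, on first use. A generic sketch, not the real class.

```python
class LazyProxy:
    def __init__(self, factory):
        self._factory = factory
        self._real = None

    def _ensure(self):
        if self._real is None:
            self._real = self._factory()  # expensive work happens here, once
        return self._real

    def use(self):
        return self._ensure()


calls = []
proxy = LazyProxy(lambda: calls.append("built") or "model")
assert calls == []            # constructing the proxy built nothing
assert proxy.use() == "model"
assert calls == ["built"]     # built exactly once, on first use
proxy.use()
assert calls == ["built"]
```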
+ self._embeddings: Embeddings = _LazyEmbeddings(dataset) self._attributes = attributes self._vector_processor = self._init_vector() @@ -69,140 +142,22 @@ class Vector: @staticmethod def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: - match vector_type: - case VectorType.CHROMA: - from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory + return get_vector_factory_class(vector_type) - return ChromaVectorFactory - case VectorType.MILVUS: - from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory - - return MilvusVectorFactory - case VectorType.ALIBABACLOUD_MYSQL: - from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import ( - AlibabaCloudMySQLVectorFactory, - ) - - return AlibabaCloudMySQLVectorFactory - case VectorType.MYSCALE: - from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory - - return MyScaleVectorFactory - case VectorType.PGVECTOR: - from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory - - return PGVectorFactory - case VectorType.VASTBASE: - from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory - - return VastbaseVectorFactory - case VectorType.PGVECTO_RS: - from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory - - return PGVectoRSFactory - case VectorType.QDRANT: - from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory - - return QdrantVectorFactory - case VectorType.RELYT: - from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory - - return RelytVectorFactory - case VectorType.ELASTICSEARCH: - from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory - - return ElasticSearchVectorFactory - case VectorType.ELASTICSEARCH_JA: - from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import ( - ElasticSearchJaVectorFactory, - ) - - return ElasticSearchJaVectorFactory - case VectorType.TIDB_VECTOR: - from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory - - return TiDBVectorFactory - case VectorType.WEAVIATE: - from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory - - return WeaviateVectorFactory - case VectorType.TENCENT: - from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory - - return TencentVectorFactory - case VectorType.ORACLE: - from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory - - return OracleVectorFactory - case VectorType.OPENSEARCH: - from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory - - return OpenSearchVectorFactory - case VectorType.ANALYTICDB: - from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory - - return AnalyticdbVectorFactory - case VectorType.COUCHBASE: - from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory - - return CouchbaseVectorFactory - case VectorType.BAIDU: - from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory - - return BaiduVectorFactory - case VectorType.VIKINGDB: - from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory - - return VikingDBVectorFactory - case VectorType.UPSTASH: - from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory - - return UpstashVectorFactory - case VectorType.TIDB_ON_QDRANT: - from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory 
- - return TidbOnQdrantVectorFactory - case VectorType.LINDORM: - from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory - - return LindormVectorStoreFactory - case VectorType.OCEANBASE | VectorType.SEEKDB: - from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory - - return OceanBaseVectorFactory - case VectorType.OPENGAUSS: - from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory - - return OpenGaussFactory - case VectorType.TABLESTORE: - from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory - - return TableStoreVectorFactory - case VectorType.HUAWEI_CLOUD: - from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory - - return HuaweiCloudVectorFactory - case VectorType.MATRIXONE: - from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory - - return MatrixoneVectorFactory - case VectorType.CLICKZETTA: - from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory - - return ClickzettaVectorFactory - case VectorType.IRIS: - from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory - - return IrisVectorFactory - case VectorType.HOLOGRES: - from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory - - return HologresVectorFactory - case _: - raise ValueError(f"Vector store {vector_type} is not supported.") + @staticmethod + def _filter_empty_text_documents(documents: list[Document]) -> list[Document]: + filtered_documents = [document for document in documents if document.page_content.strip()] + skipped_count = len(documents) - len(filtered_documents) + if skipped_count: + logger.warning("skip %d empty documents before vector embedding", skipped_count) + return filtered_documents def create(self, texts: list | None = None, **kwargs): if texts: + texts = self._filter_empty_text_documents(texts) + if not texts: + return + start = time.time() logger.info("start embedding %s texts %s", len(texts), start) batch_size = 1000 @@ -260,8 +215,14 @@ class Vector: logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start) def add_texts(self, documents: list[Document], **kwargs): + documents = self._filter_empty_text_documents(documents) + if not documents: + return + if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) + if not documents: + return embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs) diff --git a/api/tests/integration_tests/vdb/test_vector_store.py b/api/core/rag/datasource/vdb/vector_integration_test_support.py similarity index 83% rename from api/tests/integration_tests/vdb/test_vector_store.py rename to api/core/rag/datasource/vdb/vector_integration_test_support.py index a033443cf8..3148b7d5c1 100644 --- a/api/tests/integration_tests/vdb/test_vector_store.py +++ b/api/core/rag/datasource/vdb/vector_integration_test_support.py @@ -1,10 +1,19 @@ +"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``). + +:class:`AbstractVectorTest` and helper functions live here so package tests can import +``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the +``tests.*`` package. 
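The relocated helper module above is meant for subclassing from workspace packages. In miniature, and with hypothetical names throughout, the pattern looks like this; the real base also binds a concrete `BaseVector` and derives the collection name via `Dataset.gen_collection_name_by_id`.

```python
import uuid


class AbstractVectorTestSketch:
    """Hypothetical miniature of the shared integration-test base."""

    def __init__(self) -> None:
        self.dataset_id = str(uuid.uuid4())
        # Assumed naming scheme for illustration only.
        self.collection_name = f"vector_index_{self.dataset_id}_test"
        self.example_doc_id = str(uuid.uuid4())


class FakeBackendVectorTest(AbstractVectorTestSketch):
    """What a package-level test adds: a concrete vector store binding."""

    def __init__(self) -> None:
        super().__init__()
        self.vector = object()  # stand-in for a real BaseVector instance


suite = FakeBackendVectorTest()
assert suite.collection_name.endswith("_test")
```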
+ +The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is +auto-discovered by pytest for all package tests. +""" + import uuid -from unittest.mock import MagicMock import pytest +from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.models.document import Document -from extensions import ext_redis from models.dataset import Dataset @@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document: return doc -@pytest.fixture -def setup_mock_redis(): - # get - ext_redis.redis_client.get = MagicMock(return_value=None) - - # set - ext_redis.redis_client.set = MagicMock(return_value=None) - - # lock - mock_redis_lock = MagicMock() - mock_redis_lock.__enter__ = MagicMock() - mock_redis_lock.__exit__ = MagicMock() - ext_redis.redis_client.lock = mock_redis_lock - - class AbstractVectorTest: + vector: BaseVector + def __init__(self): - self.vector = None self.dataset_id = str(uuid.uuid4()) self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test" self.example_doc_id = str(uuid.uuid4()) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 40f45953af..78305a6ac0 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,14 +3,15 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding +from models.enums import SegmentType class DatasetDocumentStore: @@ -127,6 +128,7 @@ class DatasetDocumentStore: if save_child: if doc.children: for position, child in enumerate(doc.children, start=1): + assert self._document_id child_segment = ChildChunk( tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, @@ -137,7 +139,7 @@ class DatasetDocumentStore: index_node_hash=child.metadata.get("doc_hash"), content=child.page_content, word_count=len(child.page_content), - type="automatic", + type=SegmentType.AUTOMATIC, created_by=self._user_id, ) db.session.add(child_segment) @@ -163,6 +165,7 @@ class DatasetDocumentStore: ) # add new child chunks for position, child in enumerate(doc.children, start=1): + assert self._document_id child_segment = ChildChunk( tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, @@ -173,7 +176,7 @@ class DatasetDocumentStore: index_node_hash=child.metadata.get("doc_hash"), content=child.page_content, word_count=len(child.page_content), - type="automatic", + type=SegmentType.AUTOMATIC, created_by=self._user_id, ) db.session.add(child_segment) @@ -244,7 +247,7 @@ class DatasetDocumentStore: return document_segment def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None): - if multimodel_documents: + if multimodel_documents and self._document_id is not None: for multimodel_document in multimodel_documents: binding = SegmentAttachmentBinding( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 8d1c0da392..a9995778f7 100644 --- 
a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,8 +4,6 @@ import pickle from typing import Any, cast import numpy as np -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy import select from sqlalchemy.exc import IntegrityError @@ -15,6 +13,8 @@ from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from models.dataset import Embedding @@ -106,7 +106,7 @@ class CacheEmbedding(Embeddings): return text_embeddings - def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: """Embed file documents.""" # use doc embedding cache or store if not exists multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))] @@ -232,7 +232,7 @@ class CacheEmbedding(Embeddings): return embedding_results # type: ignore - def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: """Embed multimodal documents.""" # use doc embedding cache or store if not exists file_id = multimodel_document["file_id"] diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py index 1be55bda80..7ae5c09ab7 100644 --- a/api/core/rag/embedding/embedding_base.py +++ b/api/core/rag/embedding/embedding_base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Any class Embeddings(ABC): @@ -10,7 +11,7 @@ class Embeddings(ABC): raise NotImplementedError @abstractmethod - def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]: + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: """Embed file documents.""" raise NotImplementedError @@ -20,7 +21,7 @@ class Embeddings(ABC): raise NotImplementedError @abstractmethod - def embed_multimodal_query(self, multimodel_document: dict) -> list[float]: + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: """Embed multimodal query.""" raise NotImplementedError diff --git a/api/core/rag/entities/__init__.py b/api/core/rag/entities/__init__.py index 63c6708704..373b68894b 100644 --- a/api/core/rag/entities/__init__.py +++ b/api/core/rag/entities/__init__.py @@ -4,7 +4,12 @@ from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEve from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation -from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig +from core.rag.entities.retrieval_settings import ( + KeywordSetting, + RerankingModelConfig, + VectorSetting, + WeightedScoreConfig, +) __all__ = [ "Condition", @@ -19,6 +24,7 @@ __all__ = [ 
"MetadataFilteringCondition", "ParentMode", "PreProcessingRule", + "RerankingModelConfig", "RetrievalSourceMetadata", "Rule", "Segmentation", diff --git a/api/core/rag/entities/retrieval_settings.py b/api/core/rag/entities/retrieval_settings.py index a0c6512c9c..8d40ab68fd 100644 --- a/api/core/rag/entities/retrieval_settings.py +++ b/api/core/rag/entities/retrieval_settings.py @@ -1,4 +1,27 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field + + +class RerankingModelConfig(BaseModel): + """ + Canonical reranking model configuration. + + Accepts both naming conventions: + - reranking_provider_name / reranking_model_name (services layer) + - provider / model (workflow layer via validation_alias) + """ + + model_config = ConfigDict(populate_by_name=True) + + reranking_provider_name: str = Field(validation_alias="provider") + reranking_model_name: str = Field(validation_alias="model") + + @property + def provider(self) -> str: + return self.reranking_provider_name + + @property + def model(self) -> str: + return self.reranking_model_name class VectorSetting(BaseModel): diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 3bfae9d6bd..19bc9cec84 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,6 +1,7 @@ """Abstract interface for document loader implementations.""" import csv +from typing import Any import pandas as pd @@ -23,7 +24,7 @@ class CSVExtractor(BaseExtractor): encoding: str | None = None, autodetect_encoding: bool = False, source_column: str | None = None, - csv_args: dict | None = None, + csv_args: dict[str, Any] | None = None, ): """Initialize with file path.""" self._file_path = file_path diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 449be6a448..b679edab36 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -94,16 +94,18 @@ class ExtractProcessor: cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE: + upload_file = extract_setting.upload_file with tempfile.TemporaryDirectory() as temp_dir: + upload_file = extract_setting.upload_file if not file_path: - assert extract_setting.upload_file is not None, "upload_file is required" - upload_file: UploadFile = extract_setting.upload_file + assert upload_file is not None, "upload_file is required" suffix = Path(upload_file.key).suffix # FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore storage.download(upload_file.key, file_path) input_file = Path(file_path) file_extension = input_file.suffix.lower() + assert upload_file is not None, "upload_file is required" etl_type = dify_config.ETL_TYPE extractor: BaseExtractor | None = None if etl_type == "Unstructured": @@ -113,6 +115,7 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": + assert upload_file is not None extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = ( @@ -123,6 +126,7 @@ class ExtractProcessor: elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == 
".docx": + assert upload_file is not None extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension == ".doc": extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key) @@ -149,12 +153,14 @@ class ExtractProcessor: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": + assert upload_file is not None extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension in {".md", ".markdown", ".mdx"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": + assert upload_file is not None extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 89bdd56a6c..556158cf00 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -174,21 +174,25 @@ class FirecrawlApp: return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}" def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> httpx.Response: + response: httpx.Response | None = None for attempt in range(retries): response = httpx.post(url, headers=headers, json=data) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: return response + assert response is not None, "retries must be at least 1" return response def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> httpx.Response: + response: httpx.Response | None = None for attempt in range(retries): response = httpx.get(url, headers=headers) if response.status_code == 502: time.sleep(backoff_factor * (2**attempt)) else: return response + assert response is not None, "retries must be at least 1" return response def _handle_error(self, response, action): diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 7b4a388df9..d1ce142dbd 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -54,8 +54,8 @@ class BaseAPIClient: self, method: str, endpoint: str, - query_params: dict | None = None, - data: dict | None = None, + query_params: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, **kwargs, ) -> Response: stream = kwargs.pop("stream", False) @@ -66,19 +66,25 @@ class BaseAPIClient: return self.session.request(method, url, params=query_params, json=data, **kwargs) - def _get(self, endpoint: str, query_params: dict | None = None, **kwargs): + def _get(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs): return self._request("GET", endpoint, query_params=query_params, **kwargs) - def _post(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): + def _post( + self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs + ): return self._request("POST", endpoint, query_params=query_params, data=data, **kwargs) - def _put(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): + def _put( + self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = 
None, **kwargs + ): return self._request("PUT", endpoint, query_params=query_params, data=data, **kwargs) - def _delete(self, endpoint: str, query_params: dict | None = None, **kwargs): + def _delete(self, endpoint: str, query_params: dict[str, Any] | None = None, **kwargs): return self._request("DELETE", endpoint, query_params=query_params, **kwargs) - def _patch(self, endpoint: str, query_params: dict | None = None, data: dict | None = None, **kwargs): + def _patch( + self, endpoint: str, query_params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, **kwargs + ): return self._request("PATCH", endpoint, query_params=query_params, data=data, **kwargs) @@ -99,7 +105,7 @@ class WaterCrawlAPIClient(BaseAPIClient): finally: response.close() - def process_response(self, response: Response) -> dict | bytes | list | None | Generator: + def process_response(self, response: Response) -> dict[str, Any] | bytes | list[Any] | None | Generator: if response.status_code == 401: raise WaterCrawlAuthenticationError(response) @@ -186,7 +192,7 @@ class WaterCrawlAPIClient(BaseAPIClient): yield from generator def get_crawl_request_results( - self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict | None = None + self, item_id: str, page: int = 1, page_size: int = 25, query_params: dict[str, Any] | None = None ): query_params = query_params or {} query_params.update({"page": page or 1, "page_size": page_size or 25}) @@ -210,7 +216,7 @@ class WaterCrawlAPIClient(BaseAPIClient): if event_data["type"] == "result": return event_data["data"] - def download_result(self, result_object: dict): + def download_result(self, result_object: dict[str, Any]): response = httpx.get(result_object["result"], timeout=None) try: response.raise_for_status() diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index 2a9403eda0..ae7bebcb9b 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -120,7 +120,7 @@ class WaterCrawlProvider: } def _get_results( - self, crawl_request_id: str, query_params: dict | None = None + self, crawl_request_id: str, query_params: dict[str, Any] | None = None ) -> Generator[WatercrawlDocumentData, None, None]: page = 0 page_size = 100 diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 052fca930d..0330a43b28 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -3,6 +3,7 @@ Supports local file paths and remote URLs (downloaded via `core.helper.ssrf_proxy`). """ +import inspect import logging import mimetypes import os @@ -36,8 +37,11 @@ class WordExtractor(BaseExtractor): file_path: Path to the file to load. 
""" + _closed: bool + def __init__(self, file_path: str, tenant_id: str, user_id: str): """Initialize with file path.""" + self._closed = False self.file_path = file_path self.tenant_id = tenant_id self.user_id = user_id @@ -65,9 +69,27 @@ class WordExtractor(BaseExtractor): elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") + def close(self) -> None: + """Best-effort cleanup for downloaded temporary files.""" + if getattr(self, "_closed", False): + return + + self._closed = True + temp_file = getattr(self, "temp_file", None) + if temp_file is None: + return + + try: + close_result = temp_file.close() + if inspect.isawaitable(close_result): + close_awaitable = getattr(close_result, "close", None) + if callable(close_awaitable): + close_awaitable() + except Exception: + logger.debug("Failed to cleanup downloaded word temp file", exc_info=True) + def __del__(self): - if hasattr(self, "temp_file"): - self.temp_file.close() + self.close() def extract(self) -> list[Document]: """Load given path as single page.""" diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py index 813a84cbbd..aded5315bd 100644 --- a/api/core/rag/index_processor/index_processor.py +++ b/api/core/rag/index_processor/index_processor.py @@ -6,7 +6,7 @@ from collections.abc import Mapping from typing import Any from flask import current_app -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -63,11 +63,11 @@ class IndexProcessor: summary_index_setting: SummaryIndexSettingDict | None = None, ) -> IndexingResultDict: with session_factory.create_session() as session: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not document: raise KnowledgeIndexNodeError(f"Document {document_id} not found.") - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") @@ -104,12 +104,12 @@ class IndexProcessor: document.indexing_status = "completed" document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.word_count = ( - session.query(func.sum(DocumentSegment.word_count)) - .where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, + session.scalar( + select(func.sum(DocumentSegment.word_count)).where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) ) - .scalar() ) or 0 # Update need_summary based on dataset's summary_index_setting if summary_index_setting and summary_index_setting.get("enable") is True: @@ -118,15 +118,17 @@ class IndexProcessor: document.need_summary = False session.add(document) # update document segment status - session.query(DocumentSegment).where( - DocumentSegment.document_id == document_id, - DocumentSegment.dataset_id == dataset_id, - ).update( - { - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } + session.execute( + update(DocumentSegment) + .where( + DocumentSegment.document_id == document_id, + 
DocumentSegment.dataset_id == dataset_id, + ) + .values( + status="completed", + enabled=True, + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + ) ) result: IndexingResultDict = { @@ -151,11 +153,11 @@ class IndexProcessor: doc_language = None with session_factory.create_session() as session: if document_id: - document = session.query(Document).filter_by(id=document_id).first() + document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) else: document = None - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index a487c49053..7ffa9afafd 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -7,16 +7,6 @@ from typing import Any, TypedDict, cast logger = logging.getLogger(__name__) -from graphon.file import File, FileTransferMethod, FileType, file_manager -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController @@ -43,6 +33,16 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from libs import helper from models import UploadFile from models.account import Account @@ -609,11 +609,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): try: # Create File object directly (similar to DatasetRetrieval) file_obj = File( - id=upload_file.id, + file_id=upload_file.id, filename=upload_file.name, extension="." 
+ upload_file.extension, mime_type=upload_file.mime_type, - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, reference=build_file_reference( diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 2db233874a..ba277d5018 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -159,14 +159,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor): if node_ids: # Find segments by index_node_id with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment) - .filter( + segments = session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index b0f7928092..d3f311b08e 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -8,6 +8,7 @@ from typing import Any, TypedDict import pandas as pd from flask import Flask, current_app +from sqlalchemy import select from werkzeug.datastructures import FileStorage from core.db.session_factory import session_factory @@ -163,14 +164,12 @@ class QAIndexProcessor(BaseIndexProcessor): if node_ids: # Find segments by index_node_id with session_factory.create_session() as session: - segments = ( - session.query(DocumentSegment) - .filter( + segments = session.scalars( + select(DocumentSegment).where( DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(node_ids), ) - .all() - ) + ).all() segment_ids = [segment.id for segment in segments] if segment_ids: SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 087736d0b0..4ebf095904 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any -from graphon.file import File from pydantic import BaseModel, Field +from graphon.file import File + class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index a8d37845a5..bce08f998f 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,8 +1,5 @@ import base64 -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult - from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType @@ -10,6 +7,8 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import 
ModelType
+from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
 from models.model import UploadFile
diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py
index 49123e13d0..d0732b269a 100644
--- a/api/core/rag/rerank/weight_rerank.py
+++ b/api/core/rag/rerank/weight_rerank.py
@@ -2,7 +2,6 @@ import math
 from collections import Counter
 
 import numpy as np
-from graphon.model_runtime.entities.model_entities import ModelType
 
 from core.model_manager import ModelManager
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@@ -12,6 +11,7 @@ from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.rerank.entity.weight import VectorSetting, Weights
 from core.rag.rerank.rerank_base import BaseRerankRunner
+from graphon.model_runtime.entities.model_entities import ModelType
 
 
 class WeightRerankRunner(BaseRerankRunner):
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 0f3351fd68..5631b3a921 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -9,12 +9,7 @@ from collections.abc import Generator, Mapping
 from typing import Any, Union, cast
 
 from flask import Flask, current_app
-from graphon.file import File, FileTransferMethod, FileType
-from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
-from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from sqlalchemy import and_, func, literal, or_, select
+from sqlalchemy import and_, func, literal, or_, select, update
 from sqlalchemy.orm import sessionmaker
 
 from core.app.app_config.entities import (
@@ -69,6 +64,11 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import (
 )
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from graphon.file import File, FileTransferMethod, FileType
+from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
+from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
+from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
 from libs.helper import parse_uuid_str_or_none
 from libs.json_in_md_parser import parse_and_check_json_markdown
 from models import UploadFile
@@ -276,8 +276,8 @@ class DatasetRetrieval:
         document_ids = [i.segment.document_id for i in records]
 
         with session_factory.create_session() as session:
-            datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
-            documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
+            datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
+            documents = session.scalars(select(DatasetDocument).where(DatasetDocument.id.in_(document_ids))).all()
 
             dataset_map = {i.id: i for i in datasets}
             document_map = {i.id: i for i in documents}
@@ -517,11 +517,11 @@
         if attachments_with_bindings:
             for _, upload_file in attachments_with_bindings:
                 attachment_info = File(
-                    id=upload_file.id,
+                    file_id=upload_file.id,
                     filename=upload_file.name,
                     extension="." + upload_file.extension,
                     mime_type=upload_file.mime_type,
-                    type=FileType.IMAGE,
+                    file_type=FileType.IMAGE,
                     transfer_method=FileTransferMethod.LOCAL_FILE,
                     remote_url=upload_file.source_url,
                     reference=build_file_reference(
@@ -875,7 +875,11 @@ class DatasetRetrieval:
         return retrieval_resource_list
 
     def _on_retrieval_end(
-        self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None
+        self,
+        flask_app: Flask,
+        documents: list[Document],
+        message_id: str | None = None,
+        timer: dict[str, Any] | None = None,
     ):
         """Handle retrieval end."""
         with flask_app.app_context():
@@ -971,14 +975,16 @@
 
             # Batch update hit_count for all segments
             if segment_ids_to_update:
-                session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids_to_update)).update(
-                    {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                    synchronize_session=False,
+                session.execute(
+                    update(DocumentSegment)
+                    .where(DocumentSegment.id.in_(segment_ids_to_update))
+                    .values(hit_count=DocumentSegment.hit_count + 1)
+                    .execution_options(synchronize_session=False)
                 )
 
         self._send_trace_task(message_id, documents, timer)
 
-    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict | None):
+    def _send_trace_task(self, message_id: str | None, documents: list[Document], timer: dict[str, Any] | None):
         """Send trace task if trace manager is available."""
         trace_manager: TraceQueueManager | None = (
             self.application_generate_entity.trace_manager if self.application_generate_entity else None
@@ -1140,7 +1146,7 @@
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
         user_id: str,
-        inputs: dict,
+        inputs: dict[str, Any],
     ) -> list[DatasetRetrieverBaseTool] | None:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
@@ -1335,7 +1341,7 @@
         metadata_filtering_mode: str,
         metadata_model_config: ModelConfig,
         metadata_filtering_conditions: MetadataFilteringCondition | None,
-        inputs: dict,
+        inputs: dict[str, Any],
     ) -> tuple[dict[str, list[str]] | None, MetadataFilteringCondition | None]:
         document_query = select(DatasetDocument).where(
             DatasetDocument.dataset_id.in_(dataset_ids),
@@ -1415,7 +1421,7 @@
                 metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
 
         return metadata_filter_document_ids, metadata_condition
 
-    def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str:
+    def _replace_metadata_filter_value(self, text: str, inputs: dict[str, Any]) -> str:
         if not inputs:
             return text
@@ -1822,7 +1828,7 @@
     def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
         with session_factory.create_session() as session:
             subquery = (
-                session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
+                select(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
                 .where(
                     DocumentModel.indexing_status == "completed",
                     DocumentModel.enabled == True,
@@ -1834,13 +1840,12 @@
                 .subquery()
             )
 
-            results = (
-                session.query(Dataset)
+            results = session.scalars(
+                select(Dataset)
                 .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
                 .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
                 .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
-                .all()
-            )
+            ).all()
 
             available_datasets = []
             for dataset in results:
diff --git a/api/core/rag/retrieval/output_parser/react_output.py b/api/core/rag/retrieval/output_parser/react_output.py
index 9a14d41716..29abae4280 100644
--- a/api/core/rag/retrieval/output_parser/react_output.py
+++ b/api/core/rag/retrieval/output_parser/react_output.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import NamedTuple, Union
+from typing import Any, NamedTuple, Union
 
 
 @dataclass
@@ -10,7 +10,7 @@ class ReactAction:
     tool: str
     """The name of the Tool to execute."""
-    tool_input: Union[str, dict]
+    tool_input: Union[str, dict[str, Any]]
     """The input to pass in to the Tool."""
     log: str
     """Additional information to log about the action."""
@@ -19,7 +19,7 @@ class ReactFinish(NamedTuple):
    """The final return value of an ReactFinish."""
 
-    return_values: dict
+    return_values: dict[str, Any]
     """Dictionary of return values."""
     log: str
     """Additional information to log about the return value"""
diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
index dce7b6226c..dd17545c86 100644
--- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
+++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py
@@ -1,10 +1,9 @@
 from typing import Union
 
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
-
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
 
 
 class FunctionCallMultiDatasetRouter:
@@ -29,10 +28,10 @@
             SystemPromptMessage(content="You are a helpful AI assistant."),
             UserPromptMessage(content=query),
         ]
-        result: LLMResult = model_instance.invoke_llm(
+        result: LLMResult = model_instance.invoke_llm(  # pyright: ignore[reportCallIssue, reportArgumentType]
             prompt_messages=prompt_messages,
             tools=dataset_tools,
-            stream=False,
+            stream=False,  # pyright: ignore[reportArgumentType]
             model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
         )
         usage = result.usage or LLMUsage.empty_usage()
diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py
index dd280cdf6a..21a9d04f7f 100644
--- a/api/core/rag/retrieval/router/multi_dataset_react_route.py
+++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py
@@ -1,9 +1,5 @@
 from collections.abc import Generator, Sequence
-from typing import Union
-
-from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
-from graphon.model_runtime.entities.model_entities import ModelType
+from typing import Any, Union
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.app.llm import deduct_llm_quota
@@ -12,6 +8,9 @@
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
+from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
+from graphon.model_runtime.entities.model_entities import ModelType
 
 PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -139,7 +138,7 @@
 
     def _invoke_llm(
         self,
-        completion_param: dict,
+        completion_param: dict[str, Any],
         model_instance: ModelInstance,
         prompt_messages: list[PromptMessage],
         stop: list[str],
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 3383c7f3bd..52c9a02f97 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -4,13 +4,12 @@
 
 import codecs
 import re
-from collections.abc import Collection
+from collections.abc import Set as AbstractSet
 from typing import Any, Literal
 
-from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
-
 from core.model_manager import ModelInstance
 from core.rag.splitter.text_splitter import RecursiveCharacterTextSplitter
+from graphon.model_runtime.model_providers.base.tokenizers.gpt2_tokenizer import GPT2Tokenizer
 
 
 class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
@@ -22,8 +21,8 @@
     def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
         cls: type[T],
         embedding_model_instance: ModelInstance | None,
-        allowed_special: Literal["all"] | set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
        **kwargs: Any,
     ) -> T:
         def _token_encoder(texts: list[str]) -> list[int]:
@@ -41,6 +40,7 @@
 
             return [len(text) for text in texts]
 
+        _ = _token_encoder  # kept for future token-length wiring
         return cls(length_function=_character_encoder, **kwargs)
 
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 8977611f93..a8d9013fbc 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -4,7 +4,8 @@
 import copy
 import logging
 import re
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Collection, Iterable, Sequence, Set
+from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Set as AbstractSet
 from dataclasses import dataclass
 from typing import Any, Literal
@@ -63,7 +64,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
     def split_text(self, text: str) -> list[str]:
         """Split text into multiple components."""
 
-    def create_documents(self, texts: list[str], metadatas: list[dict] | None = None) -> list[Document]:
+    def create_documents(self, texts: list[str], metadatas: list[dict[str, Any]] | None = None) -> list[Document]:
         """Create documents from a list of texts."""
         _metadatas = metadatas or [{}] * len(texts)
         documents = []
@@ -187,8 +188,8 @@
         self,
         encoding_name: str = "gpt2",
         model_name: str | None = None,
-        allowed_special: Literal["all"] | Set[str] = set(),
-        disallowed_special: Literal["all"] | Collection[str] = "all",
+        allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
+        disallowed_special: Literal["all"] | AbstractSet[str] = "all",
         **kwargs: Any,
     ):
         """Create a new TextSplitter."""
@@ -207,8 +208,8 @@
         else:
             enc = tiktoken.get_encoding(encoding_name)
         self._tokenizer = enc
-        self._allowed_special = allowed_special
-        self._disallowed_special = disallowed_special
+        self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
+        self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
 
     def split_text(self, text: str) -> list[str]:
         def _encode(_text: str) -> list[int]:
diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py
index 6f120bd471..bff5f85dec 100644
--- a/api/core/rag/summary_index/summary_index.py
+++ b/api/core/rag/summary_index/summary_index.py
@@ -1,6 +1,8 @@
 import concurrent.futures
 import logging
 
+from sqlalchemy import select
+
 from core.db.session_factory import session_factory
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
@@ -21,7 +23,7 @@
     ) -> None:
         if is_preview:
             with session_factory.create_session() as session:
-                dataset = session.query(Dataset).filter_by(id=dataset_id).first()
+                dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
                 if not dataset or dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
                     return
 
@@ -34,32 +36,31 @@
             if not document_id:
                 return
 
-            document = session.query(Document).filter_by(id=document_id).first()
+            document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
             # Skip qa_model documents
             if document is None or document.doc_form == "qa_model":
                 return
 
-            query = session.query(DocumentSegment).filter_by(
-                dataset_id=dataset_id,
-                document_id=document_id,
-                status="completed",
-                enabled=True,
-            )
-            segments = query.all()
+            segments = session.scalars(
+                select(DocumentSegment).where(
+                    DocumentSegment.dataset_id == dataset_id,
+                    DocumentSegment.document_id == document_id,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.enabled == True,
+                )
+            ).all()
             segment_ids = [segment.id for segment in segments]
             if not segment_ids:
                 return
 
-            existing_summaries = (
-                session.query(DocumentSegmentSummary)
-                .filter(
+            existing_summaries = session.scalars(
+                select(DocumentSegmentSummary).where(
                     DocumentSegmentSummary.chunk_id.in_(segment_ids),
                     DocumentSegmentSummary.dataset_id == dataset_id,
                     DocumentSegmentSummary.status == "completed",
                 )
-                .all()
-            )
+            ).all()
             completed_summary_segment_ids = {i.chunk_id for i in existing_summaries}
             # Preview mode should process segments that are MISSING completed summaries
             pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids]
 
@@ -73,7 +74,7 @@
         def process_segment(segment_id: str) -> None:
             """Process a single segment in a thread with a fresh DB session."""
             with session_factory.create_session() as session:
-                segment = session.query(DocumentSegment).filter_by(id=segment_id).first()
+                segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
                 if segment is None:
                     return
                 try:
diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py
index b07c63fdf0..e87d1cd6b2 100644
--- a/api/core/repositories/celery_workflow_execution_repository.py
+++ b/api/core/repositories/celery_workflow_execution_repository.py
@@ -7,11 +7,11 @@ providing improved performance by offloading database operations to background w
 
 import logging
 
-from graphon.entities import WorkflowExecution
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
 
 from core.repositories.factory import WorkflowExecutionRepository
+from graphon.entities import WorkflowExecution
 from libs.helper import extract_tenant_id
 from models import Account, CreatorUserRole, EndUser
 from models.enums import WorkflowRunTriggeredFrom
diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py
index cdb3af01a8..2451563317 100644
--- a/api/core/repositories/celery_workflow_node_execution_repository.py
+++ b/api/core/repositories/celery_workflow_node_execution_repository.py
@@ -8,7 +8,6 @@ providing improved performance by offloading database operations to background w
 import logging
 from collections.abc import Sequence
 
-from graphon.entities import WorkflowNodeExecution
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
 
@@ -16,6 +15,7 @@ from core.repositories.factory import (
     OrderConfig,
     WorkflowNodeExecutionRepository,
 )
+from graphon.entities import WorkflowNodeExecution
 from libs.helper import extract_tenant_id
 from models import Account, CreatorUserRole, EndUser
 from models.workflow import WorkflowNodeExecutionTriggeredFrom
diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py
index ce3ad15759..4e83e70799 100644
--- a/api/core/repositories/factory.py
+++ b/api/core/repositories/factory.py
@@ -9,11 +9,11 @@ from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Literal, Protocol
 
-from graphon.entities import WorkflowExecution, WorkflowNodeExecution
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
 
 from configs import dify_config
+from graphon.entities import WorkflowExecution, WorkflowNodeExecution
 from libs.module_loading import import_string
 from models import Account, EndUser
 from models.enums import WorkflowRunTriggeredFrom
diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py
index 72d9394149..740d727e26 100644
--- a/api/core/repositories/human_input_repository.py
+++ b/api/core/repositories/human_input_repository.py
@@ -4,13 +4,11 @@ from collections.abc import Mapping, Sequence
 from datetime import datetime
 from typing import Any, Protocol
 
-from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData
-from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
 from sqlalchemy import select
 from sqlalchemy.orm import Session, selectinload
 
 from core.db.session_factory import session_factory
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
     BoundRecipient,
     DeliveryChannelConfig,
     EmailDeliveryMethod,
@@ -19,6 +17,8 @@
     InteractiveSurfaceDeliveryMethod,
     is_human_input_webapp_enabled,
 )
+from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData
+from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
 from libs.datetime_utils import naive_utc_now
 from libs.uuid_utils import uuidv7
 from models.account import Account, TenantAccountJoin
diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py
index d74cc8f231..6be3902317 100644
--- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py
+++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py
@@ -5,13 +5,13 @@ SQLAlchemy implementation of the WorkflowExecutionRepository.
 import json
 import logging
 
-from graphon.entities import WorkflowExecution
-from graphon.enums import WorkflowExecutionStatus, WorkflowType
-from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from sqlalchemy.engine import Engine
 from sqlalchemy.orm import sessionmaker
 
 from core.repositories.factory import WorkflowExecutionRepository
+from graphon.entities import WorkflowExecution
+from graphon.enums import WorkflowExecutionStatus, WorkflowType
+from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.helper import extract_tenant_id
 from models import (
     Account,
diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
index 13e885672a..b036687bc9 100644
--- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
+++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
@@ -10,10 +10,6 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 
 import psycopg2.errors
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from sqlalchemy import UnaryExpression, asc, desc, select
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import IntegrityError
@@ -23,6 +19,10 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att
 from configs import dify_config
 from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository
 from extensions.ext_storage import storage
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from graphon.model_runtime.utils.encoders import jsonable_encoder
+from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.helper import extract_tenant_id
 from libs.uuid_utils import uuidv7
 from models import (
diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py
index 6e26664ac2..e267c1abd9 100644
--- a/api/core/schemas/resolver.py
+++ b/api/core/schemas/resolver.py
@@ -254,7 +254,7 @@
     return resolver.resolve(schema)
 
 
-def _remove_metadata_fields(schema: dict) -> dict:
+def _remove_metadata_fields(schema: dict[str, Any]) -> dict[str, Any]:
     """
     Remove metadata fields from schema that shouldn't be included in resolved output
 
diff --git a/api/core/telemetry/gateway.py b/api/core/telemetry/gateway.py
index 7b013d0563..812edeeb14 100644
--- a/api/core/telemetry/gateway.py
+++ b/api/core/telemetry/gateway.py
@@ -89,7 +89,7 @@
     return _case_routing
 
 
-def __getattr__(name: str) -> dict:
+def __getattr__(name: str) -> Any:
     """Lazy module-level access to routing tables."""
     if name == "CASE_ROUTING":
         return _get_case_routing()
diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py
index 7bb2cdb876..ab0f73a9a2 100644
--- a/api/core/tools/__base/tool.py
+++ b/api/core/tools/__base/tool.py
@@ -198,7 +198,7 @@ class Tool(ABC):
             message=ToolInvokeMessage.TextMessage(text=text),
         )
 
-    def create_blob_message(self, blob: bytes, meta: dict | None = None) -> ToolInvokeMessage:
+    def create_blob_message(self, blob: bytes, meta: dict[str, Any] | None = None) -> ToolInvokeMessage:
         """
         create a blob message
 
@@ -212,7 +212,7 @@
             meta=meta,
         )
 
-    def create_json_message(self, object: dict, suppress_output: bool = False) -> ToolInvokeMessage:
+    def create_json_message(self, object: dict[str, Any], suppress_output: bool = False) -> ToolInvokeMessage:
         """
         create a json message
         """
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
index e539074303..95660ab93b 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
@@ -2,15 +2,14 @@ import io
 from collections.abc import Generator
 from typing import Any
 
-from graphon.file import FileType
-from graphon.file.file_manager import download
-from graphon.model_runtime.entities.model_entities import ModelType
-
 from core.model_manager import ModelManager
 from core.plugin.entities.parameters import PluginParameterOption
 from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from graphon.file import FileType
+from graphon.file.file_manager import download
+from graphon.model_runtime.entities.model_entities import ModelType
 from services.model_provider_service import ModelProviderService
 
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
index f49c669fe0..ac3820f1ab 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
@@ -2,13 +2,12 @@ import io
 from collections.abc import Generator
 from typing import Any
 
-from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-
 from core.model_manager import ModelManager
 from core.plugin.entities.parameters import PluginParameterOption
 from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from services.model_provider_service import ModelProviderService
 
diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py
index 14af63a962..d41503e1e6 100644
--- a/api/core/tools/builtin_tool/tool.py
+++ b/api/core/tools/builtin_tool/tool.py
@@ -1,12 +1,11 @@
 from __future__ import annotations
 
-from graphon.model_runtime.entities.llm_entities import LLMResult
-from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
-
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.utils.model_invocation_utils import ModelInvocationUtils
+from graphon.model_runtime.entities.llm_entities import LLMResult
+from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
 
 
 _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py
index 0a2c37c563..168e5f4493 100644
--- a/api/core/tools/custom_tool/tool.py
+++ b/api/core/tools/custom_tool/tool.py
@@ -6,7 +6,6 @@ from typing import Any, Union
 from urllib.parse import urlencode
 
 import httpx
-from graphon.file.file_manager import download
 
 from core.helper import ssrf_proxy
 from core.tools.__base.tool import Tool
@@ -14,6 +13,7 @@ from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
 from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
+from graphon.file.file_manager import download
 
 API_TOOL_DEFAULT_TIMEOUT = (
     int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py
index 410ec72baf..42a88c0003 100644
--- a/api/core/tools/entities/api_entities.py
+++ b/api/core/tools/entities/api_entities.py
@@ -2,7 +2,6 @@ from collections.abc import Mapping
 from datetime import datetime
 from typing import Any, Literal
 
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, Field, field_validator
 
 from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
@@ -10,6 +9,7 @@ from core.plugin.entities.plugin_daemon import CredentialType
 from core.tools.__base.tool import ToolParameter
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolProviderType
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 
 class ToolApiEntity(BaseModel):
diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py
index 10710c4376..4e07b7157a 100644
--- a/api/core/tools/entities/tool_bundle.py
+++ b/api/core/tools/entities/tool_bundle.py
@@ -1,4 +1,5 @@
 from collections.abc import Mapping
+from typing import Any
 
 from pydantic import BaseModel, Field
 
@@ -26,6 +27,6 @@ class ApiToolBundle(BaseModel):
     # icon
     icon: str | None = None
     # openapi operation
-    openapi: dict
+    openapi: dict[str, Any]
     # output schema
     output_schema: Mapping[str, object] = Field(default_factory=dict)
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 31e879add2..0c77693dde 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -149,7 +149,7 @@ class ToolInvokeMessage(BaseModel):
         text: str
 
     class JsonMessage(BaseModel):
-        json_object: dict | list
+        json_object: dict[str, Any] | list[Any]
         suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
 
     class BlobMessage(BaseModel):
@@ -337,7 +337,7 @@ class ToolParameter(PluginParameter):
     form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
     llm_description: str | None = None
     # MCP object and array type parameters use this field to store the schema
-    input_schema: dict | None = None
+    input_schema: dict[str, Any] | None = None
 
     @classmethod
     def get_simple_instance(
@@ -450,6 +450,12 @@ class WorkflowToolParameterConfiguration(BaseModel):
     form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
 
 
+class ToolInvokeMetaDict(TypedDict):
+    time_cost: float
+    error: str | None
+    tool_config: dict[str, Any] | None
+
+
 class ToolInvokeMeta(BaseModel):
     """
     Tool invoke meta
@@ -457,7 +463,7 @@
 
     time_cost: float = Field(..., description="The time cost of the tool invoke")
     error: str | None = None
-    tool_config: dict | None = None
+    tool_config: dict[str, Any] | None = None
 
@@ -473,12 +479,13 @@
         """
         return cls(time_cost=0.0, error=error, tool_config={})
 
-    def to_dict(self):
-        return {
+    def to_dict(self) -> ToolInvokeMetaDict:
+        result: ToolInvokeMetaDict = {
             "time_cost": self.time_cost,
             "error": self.error,
             "tool_config": self.tool_config,
         }
+        return result
 
 
 class ToolLabel(BaseModel):
diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py
index 4c3efd6ff9..2b26832b44 100644
--- a/api/core/tools/errors.py
+++ b/api/core/tools/errors.py
@@ -38,6 +38,17 @@ class ToolCredentialPolicyViolationError(ValueError):
     pass
 
 
+class ApiToolProviderNotFoundError(ValueError):
+    error_code = "api_tool_provider_not_found"
+    provider_name: str
+    tenant_id: str
+
+    def __init__(self, provider_name: str, tenant_id: str):
+        self.provider_name = provider_name
+        self.tenant_id = tenant_id
+        super().__init__(f"api provider {provider_name} does not exist")
+
+
 class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
     error_code = "workflow_tool_human_input_not_supported"
     description = "Workflow with Human Input nodes cannot be published as a workflow tool."
diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py
index f6d09472b3..00fc8a8282 100644
--- a/api/core/tools/mcp_tool/tool.py
+++ b/api/core/tools/mcp_tool/tool.py
@@ -6,8 +6,6 @@ import logging
 from collections.abc import Generator, Mapping
 from typing import Any, cast
 
-from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
-
 from core.mcp.auth_client import MCPClientWithAuthRetry
 from core.mcp.error import MCPConnectionError
 from core.mcp.types import (
@@ -23,6 +21,7 @@ from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
 from core.tools.errors import ToolInvokeError
+from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 
 logger = logging.getLogger(__name__)
 
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 685d687d8c..3caacb8706 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -7,7 +7,6 @@ from datetime import UTC, datetime
 from mimetypes import guess_type
 from typing import Any, Union, cast
 
-from graphon.file import FileTransferMethod, FileType
 from yarl import URL
 
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -33,6 +32,7 @@ from core.tools.errors import (
 from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value
 from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
+from graphon.file import FileTransferMethod, FileType
 from models.enums import CreatorUserRole, MessageFileBelongsTo
 from models.model import Message, MessageFile
 
@@ -47,7 +47,7 @@ class ToolEngine:
     @staticmethod
     def agent_invoke(
         tool: Tool,
-        tool_parameters: Union[str, dict],
+        tool_parameters: Union[str, dict[str, Any]],
         user_id: str,
         tenant_id: str,
         message: Message,
@@ -85,7 +85,8 @@
         invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
 
         def message_callback(
-            invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
+            invocation_meta_dict: dict[str, ToolInvokeMeta],
+            messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None],
         ):
             for message in messages:
                 if isinstance(message, ToolInvokeMeta):
@@ -200,7 +201,7 @@
     @staticmethod
     def _invoke(
         tool: Tool,
-        tool_parameters: dict,
+        tool_parameters: dict[str, Any],
         user_id: str,
         conversation_id: str | None = None,
         app_id: str | None = None,
@@ -262,6 +263,8 @@
                         ensure_ascii=False,
                     )
                 )
+            elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
+                continue
             else:
                 parts.append(str(response.message))
 
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index a59d167a0a..c87e8a3ae0 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -6,11 +6,9 @@ import os
 import time
 from collections.abc import Generator
 from mimetypes import guess_extension, guess_type
-from typing import Union
 from uuid import uuid4
 
 import httpx
-from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
 from sqlalchemy import select
 
 from configs import dify_config
@@ -18,6 +16,7 @@ from core.db.session_factory import session_factory
 from core.helper import ssrf_proxy
 from core.workflow.file_reference import build_file_reference
 from extensions.ext_storage import storage
+from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type
 from models.model import MessageFile
 from models.tools import ToolFile
 
@@ -29,7 +28,7 @@ class ToolFileManager:
     def _build_graph_file_reference(tool_file: ToolFile) -> File:
         extension = guess_extension(tool_file.mimetype) or ".bin"
         return File(
-            type=get_file_type_by_mime_type(tool_file.mimetype),
+            file_type=get_file_type_by_mime_type(tool_file.mimetype),
             transfer_method=FileTransferMethod.TOOL_FILE,
             remote_url=tool_file.original_url,
             reference=build_file_reference(record_id=str(tool_file.id)),
@@ -158,7 +157,7 @@
 
         return tool_file
 
-    def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
+    def get_file_binary(self, id: str) -> tuple[bytes, str] | None:
         """
         get file binary
 
@@ -176,7 +175,7 @@
 
         return blob, tool_file.mimetype
 
-    def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
+    def get_file_binary_by_message_file_id(self, id: str) -> tuple[bytes, str] | None:
         """
         get file binary
 
diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py
index 58190d1089..d8969a3391 100644
--- a/api/core/tools/tool_label_manager.py
+++ b/api/core/tools/tool_label_manager.py
@@ -1,4 +1,5 @@
 from sqlalchemy import delete, select
+from sqlalchemy.orm import Session, sessionmaker
 
 from core.tools.__base.tool_provider import ToolProviderController
 from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -19,10 +20,18 @@ class ToolLabelManager:
         return list(set(tool_labels))
 
     @classmethod
-    def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
+    def update_tool_labels(
+        cls, controller: ToolProviderController, labels: list[str], session: Session | None = None
+    ) -> None:
        """
        Update tool labels
+
+        :param controller: tool provider controller
+        :param labels: list of tool labels
+        :param session: database session, if None, a new session will be created
+        :return: None
         """
+
         labels = cls.filter_tool_labels(labels)
 
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
@@ -30,26 +39,46 @@
         else:
             raise ValueError("Unsupported tool type")
 
+        if session is not None:
+            cls._update_tool_labels_logics(session, provider_id, controller, labels)
+        else:
+            with sessionmaker(db.engine).begin() as _session:
+                cls._update_tool_labels_logics(_session, provider_id, controller, labels)
+
+    @classmethod
+    def _update_tool_labels_logics(
+        cls, session: Session, provider_id: str, controller: ToolProviderController, labels: list[str]
+    ) -> None:
+        """
+        Update tool labels logics
+
+        :param session: database session
+        :param provider_id: tool provider ID
+        :param controller: tool provider controller
+        :param labels: list of tool labels
+        :return: None
+        """
+
         # delete old labels
-        db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
+        _ = session.execute(
+            delete(ToolLabelBinding).where(
+                ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_type == controller.provider_type
+            )
+        )
 
         # insert new labels
         for label in labels:
-            db.session.add(
-                ToolLabelBinding(
-                    tool_id=provider_id,
-                    tool_type=controller.provider_type,
-                    label_name=label,
-                )
-            )
-
-        db.session.commit()
+            session.add(ToolLabelBinding(tool_id=provider_id, tool_type=controller.provider_type, label_name=label))
 
     @classmethod
     def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
         """
         Get tool labels
+
+        :param controller: tool provider controller
+        :return: list of tool labels (str)
         """
+
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
             provider_id = controller.provider_id
         elif isinstance(controller, BuiltinToolProviderController):
@@ -60,9 +89,11 @@
             ToolLabelBinding.tool_id == provider_id,
             ToolLabelBinding.tool_type == controller.provider_type,
         )
 
-        labels = db.session.scalars(stmt).all()
-        return list(labels)
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+            labels: list[str] = list(_session.scalars(stmt).all())
+
+        return labels
 
     @classmethod
     def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
@@ -78,16 +109,22 @@
         if not tool_providers:
             return {}
 
+        provider_ids: list[str] = []
+        provider_types: set[str] = set()
+
         for controller in tool_providers:
             if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
                 raise ValueError("Unsupported tool type")
-
-        provider_ids = []
-        for controller in tool_providers:
-            assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
             provider_ids.append(controller.provider_id)
+            provider_types.add(controller.provider_type)
 
-        labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
+        labels: list[ToolLabelBinding] = []
+
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+            stmt = select(ToolLabelBinding).where(
+                ToolLabelBinding.tool_id.in_(provider_ids), ToolLabelBinding.tool_type.in_(list(provider_types))
+            )
+            labels = list(_session.scalars(stmt).all())
 
         tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index f8f07369d0..0a7811bb53 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -8,7 +8,6 @@ from threading import Lock
 from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
 
 import sqlalchemy as sa
-from graphon.runtime import VariablePool
 from pydantic import TypeAdapter
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -29,14 +28,13 @@ from core.tools.plugin_tool.tool import PluginTool
 from core.tools.utils.uuid_utils import is_valid_uuid
 from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
 from extensions.ext_database import db
+from graphon.runtime import VariablePool
 from models.provider_ids import ToolProviderID
 from services.tools.mcp_tools_manage_service import MCPToolManageService
 
 if TYPE_CHECKING:
     pass
 
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-
 from core.agent.entities import AgentToolEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.helper.module_import_helper import load_single_subclass_from_source
@@ -62,6 +60,7 @@
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.configuration import ToolParameterConfigurationManager
 from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
 from core.tools.workflow_as_tool.tool import WorkflowTool
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
 from services.tools.tools_transform_service import ToolTransformService
@@ -682,7 +681,7 @@
 
         with Session(db.engine, autoflush=False) as session:
             ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
-            return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
+            return list(session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids))))
 
     @classmethod
     def list_providers_from_api(
@@ -993,7 +992,7 @@
             return {"background": "#252525", "content": "\ud83d\ude01"}
 
     @classmethod
-    def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str:
+    def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | str:
         try:
             with Session(db.engine) as session:
                 mcp_service = MCPToolManageService(session=session)
@@ -1001,7 +1000,7 @@
                 mcp_provider = mcp_service.get_provider_entity(
                     provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
                 )
-            return mcp_provider.provider_icon
+            return cast(EmojiIconDict | str, mcp_provider.provider_icon)
         except ValueError:
             raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
         except Exception:
@@ -1013,7 +1012,7 @@
         tenant_id: str,
         provider_type: ToolProviderType,
         provider_id: str,
-    ) -> str | EmojiIconDict | dict[str, str]:
+    ) -> str | EmojiIconDict:
         """
         get the tool icon
 
@@ -1079,11 +1078,23 @@
             if parameter.form == ToolParameter.ToolParameterForm.FORM:
                 if variable_pool:
                     config = tool_configurations.get(parameter.name, {})
+
+                    selector_value = cls._extract_runtime_selector_value(parameter, config)
+                    if selector_value is not None:
+                        # Selector parameters carry structured dictionaries, not scalar ToolInput values.
+                        runtime_parameters[parameter.name] = selector_value
+                        continue
+
                     if not (config and isinstance(config, dict) and config.get("value") is not None):
                         continue
                     tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
                     if tool_input.type == "variable":
-                        variable = variable_pool.get(tool_input.value)
+                        variable_selector = tool_input.value
+                        if not isinstance(variable_selector, list) or not all(
+                            isinstance(selector_part, str) for selector_part in variable_selector
+                        ):
+                            raise ToolParameterError("Variable tool input must be a variable selector")
+                        variable = variable_pool.get(variable_selector)
                         if variable is None:
                             raise ToolParameterError(f"Variable {tool_input.value} does not exist")
                         parameter_value = variable.value
@@ -1101,5 +1112,39 @@
                 runtime_parameters[parameter.name] = value
 
         return runtime_parameters
 
+    @classmethod
+    def _extract_runtime_selector_value(cls, parameter: ToolParameter, config: Any) -> dict[str, Any] | None:
+        if parameter.type not in {
+            ToolParameter.ToolParameterType.MODEL_SELECTOR,
+            ToolParameter.ToolParameterType.APP_SELECTOR,
+        }:
+            return None
+        if not isinstance(config, dict):
+            return None
+
+        input_value = config.get("value")
+        if isinstance(input_value, dict) and cls._is_selector_value(parameter, input_value):
+            return cast("dict[str, Any]", parameter.init_frontend_parameter(input_value))
+
+        if cls._is_selector_value(parameter, config):
+            selector_value = dict(config)
+            selector_value.pop("type", None)
+            selector_value.pop("value", None)
+            return cast("dict[str, Any]", parameter.init_frontend_parameter(selector_value))
+
+        return None
+
+    @classmethod
+    def _is_selector_value(cls, parameter: ToolParameter, value: Mapping[str, Any]) -> bool:
+        if parameter.type == ToolParameter.ToolParameterType.MODEL_SELECTOR:
+            return (
+                isinstance(value.get("provider"), str)
+                and isinstance(value.get("model"), str)
+                and isinstance(value.get("model_type"), str)
+            )
+        if parameter.type == ToolParameter.ToolParameterType.APP_SELECTOR:
+            return isinstance(value.get("app_id"), str)
+        return False
+
 
 ToolManager.load_hardcoded_providers_cache()
diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
index 03e3c5918d..b6890b2611 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
@@ -1,7 +1,6 @@
 import threading
 
 from flask import Flask, current_app
-from graphon.model_runtime.entities.model_entities import ModelType
 from pydantic import BaseModel, Field
 from sqlalchemy import select
 
@@ -15,6 +14,7 @@ from core.rag.rerank.rerank_model import RerankModelRunner
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
+from graphon.model_runtime.entities.model_entities import ModelType
 from models.dataset import Dataset, Document, DocumentSegment
 
 default_retrieval_model: DefaultRetrievalModelDict = {
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index 6a189fa6aa..0d1dc7273b 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import Any, cast
 
 from pydantic import BaseModel, Field
 from sqlalchemy import select
@@ -39,7 +39,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
     dataset_id: str
     user_id: str | None = None
     retrieve_config: DatasetRetrieveConfigEntity
-    inputs: dict
+    inputs: dict[str, Any]
 
     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py
index fca6e6f1c7..0bdc3df869 100644
--- a/api/core/tools/utils/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever_tool.py
@@ -33,7 +33,7 @@ class DatasetRetrieverTool(Tool):
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
         user_id: str,
-        inputs: dict,
+        inputs: dict[str, Any],
     ) -> list["DatasetRetrieverTool"]:
         """
         get dataset tool
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index 2264981abd..5679466cbc 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -4,15 +4,16 @@ from collections.abc import Generator
 from datetime import date, datetime
 from decimal import Decimal
 from mimetypes import guess_extension
+from typing import Any
 from uuid import UUID
 
 import numpy as np
 import pytz
-from graphon.file import File, FileTransferMethod, FileType
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.file_reference import parse_file_reference
+from graphon.file import File, FileTransferMethod, FileType
 from libs.login import current_user
 from models import Account
 
@@ -40,6 +41,10 @@ def safe_json_value(v):
         return v.hex()
     elif isinstance(v, memoryview):
         return v.tobytes().hex()
+    elif isinstance(v, np.integer):
+        return int(v)
+    elif isinstance(v, np.floating):
+        return float(v)
     elif isinstance(v, np.ndarray):
         return v.tolist()
     elif isinstance(v, dict):
@@ -50,7 +55,7 @@
     return v
 
 
-def safe_json_dict(d: dict):
+def safe_json_dict(d: dict[str, Any]):
     if not isinstance(d, dict):
         raise TypeError("safe_json_dict() expects a dictionary (dict) as input")
     return {k: safe_json_value(v) for k, v in d.items()}
@@ -196,11 +201,11 @@
 
     @staticmethod
     def _with_tool_file_meta(
-        meta: dict | None,
+        meta: dict[str, Any] | None,
         *,
         tool_file_id: str | None = None,
         url: str | None = None,
-    ) -> dict:
+    ) -> dict[str, Any]:
         normalized_meta = meta.copy() if meta is not None else {}
         resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url)
         if resolved_tool_file_id and "tool_file_id" not in normalized_meta:
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 8d6f83dc07..a3623d4ecd 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -8,6 +8,9 @@ import json
 from decimal import Decimal
 from typing import cast
 
+from core.model_manager import ModelManager
+from core.tools.entities.tool_entities import ToolProviderType
+from extensions.ext_database import db
 from graphon.model_runtime.entities.llm_entities import LLMResult
 from graphon.model_runtime.entities.message_entities import PromptMessage
 from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
@@ -18,12 +21,8 @@ from graphon.model_runtime.errors.invoke import (
     InvokeRateLimitError,
     InvokeServerUnavailableError,
 )
-from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
 from graphon.model_runtime.utils.encoders import jsonable_encoder
-
-from core.model_manager import ModelManager
-from core.tools.entities.tool_entities import ToolProviderType
-from extensions.ext_database import db
 from models.tools import ToolModelInvoke
 
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index f7484b93fb..434af55583 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -32,7 +32,7 @@ class OpenAPISpecDict(TypedDict):
 class ApiBasedToolSchemaParser:
     @staticmethod
     def parse_openapi_to_tool_bundle(
-        openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None
+        openapi: Mapping[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> list[ApiToolBundle]:
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}
@@ -236,7 +236,7 @@
         return value
 
     @staticmethod
-    def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
+    def _get_tool_parameter_type(parameter: dict[str, Any]) -> ToolParameter.ToolParameterType | None:
         parameter = parameter or {}
         typ: str | None = None
         if parameter.get("format") == "binary":
@@ -265,7 +265,7 @@
 
     @staticmethod
     def parse_openapi_yaml_to_tool_bundle(
-        yaml: str, extra_info: dict | None = None, warning: dict | None = None
+        yaml: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> list[ApiToolBundle]:
         """
         parse openapi yaml to tool bundle
@@ -278,14 +278,14 @@
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}
 
-        openapi: dict = safe_load(yaml)
+        openapi: dict[str, Any] = safe_load(yaml)
         if openapi is None:
             raise ToolApiSchemaError("Invalid openapi yaml.")
         return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
 
     @staticmethod
     def parse_swagger_to_openapi(
-        swagger: dict, extra_info: dict | None = None, warning: dict | None = None
+        swagger: dict[str, Any], extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> OpenAPISpecDict:
         warning = warning or {}
         """
@@ -351,7 +351,7 @@
 
     @staticmethod
     def parse_openai_plugin_json_to_tool_bundle(
-        json: str, extra_info: dict | None = None, warning: dict | None = None
+        json: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> list[ApiToolBundle]:
         """
         parse openapi plugin yaml to tool bundle
@@ -392,7 +392,7 @@
 
     @staticmethod
     def auto_parse_to_tool_bundle(
-        content: str, extra_info: dict | None = None, warning: dict | None = None
+        content: str, extra_info: dict[str, Any] | None = None, warning: dict[str, Any] | None = None
     ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
         """
         auto parse to tool bundle
diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_encryption.py
similarity index 57%
rename from api/core/tools/utils/system_oauth_encryption.py
rename to api/core/tools/utils/system_encryption.py
index 6b7007842d..ca7e6a13fe 100644
--- a/api/core/tools/utils/system_oauth_encryption.py
+++ b/api/core/tools/utils/system_encryption.py
@@ -14,23 +14,23 @@ from configs import dify_config
 logger = logging.getLogger(__name__)
 
 
-class OAuthEncryptionError(Exception):
-    """OAuth encryption/decryption specific error"""
+class EncryptionError(Exception):
+    """Encryption/decryption specific error"""
 
     pass
 
 
-class SystemOAuthEncrypter:
+class SystemEncrypter:
     """
-    A simple OAuth parameters encrypter using AES-CBC encryption.
+    A simple parameters encrypter using AES-CBC encryption.
 
-    This class provides methods to encrypt and decrypt OAuth parameters
+    This class provides methods to encrypt and decrypt parameters
     using AES-CBC mode with a key derived from the application's SECRET_KEY.
     """
 
     def __init__(self, secret_key: str | None = None):
         """
-        Initialize the OAuth encrypter.
+        Initialize the encrypter.
 
         Args:
             secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@@ -43,19 +43,19 @@
         # Generate a fixed 256-bit key using SHA-256
         self.key = hashlib.sha256(secret_key.encode()).digest()
 
-    def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
+    def encrypt_params(self, params: Mapping[str, Any]) -> str:
         """
-        Encrypt OAuth parameters.
+        Encrypt parameters.
 
         Args:
-            oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
+            params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
 
         Returns:
             Base64-encoded encrypted string
 
         Raises:
-            OAuthEncryptionError: If encryption fails
-            ValueError: If oauth_params is invalid
+            EncryptionError: If encryption fails
+            ValueError: If params is invalid
         """
 
         try:
@@ -66,7 +66,7 @@
             cipher = AES.new(self.key, AES.MODE_CBC, iv)
 
             # Encrypt data
-            padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
+            padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
             encrypted_data = cipher.encrypt(padded_data)
 
             # Combine IV and encrypted data
@@ -76,20 +76,20 @@
             return base64.b64encode(combined).decode()
 
         except Exception as e:
-            raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
+            raise EncryptionError(f"Encryption failed: {str(e)}") from e
 
-    def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
+    def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
         """
-        Decrypt OAuth parameters.
+        Decrypt parameters.
 
         Args:
             encrypted_data: Base64-encoded encrypted string
 
         Returns:
-            Decrypted OAuth parameters dictionary
+            Decrypted parameters dictionary
 
         Raises:
-            OAuthEncryptionError: If decryption fails
+            EncryptionError: If decryption fails
             ValueError: If encrypted_data is invalid
         """
         if not isinstance(encrypted_data, str):
@@ -118,70 +118,70 @@
                 unpadded_data = unpad(decrypted_data, AES.block_size)
 
                 # Parse JSON
-                oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
+                params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
 
-                if not isinstance(oauth_params, dict):
+                if not isinstance(params, dict):
                     raise ValueError("Decrypted data is not a valid dictionary")
 
-                return oauth_params
+                return params
 
         except Exception as e:
-            raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
+            raise EncryptionError(f"Decryption failed: {str(e)}") from e
 
 
 # Factory function for creating encrypter instances
-def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
+def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
     """
-    Create an OAuth encrypter instance.
+    Create an encrypter instance.
 
     Args:
         secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
 
     Returns:
-        SystemOAuthEncrypter instance
+        SystemEncrypter instance
     """
-    return SystemOAuthEncrypter(secret_key=secret_key)
+    return SystemEncrypter(secret_key=secret_key)
 
 
 # Global encrypter instance (for backward compatibility)
-_oauth_encrypter: SystemOAuthEncrypter | None = None
+_encrypter: SystemEncrypter | None = None
 
 
-def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
+def get_system_encrypter() -> SystemEncrypter:
     """
-    Get the global OAuth encrypter instance.
+    Get the global encrypter instance.
 
     Returns:
-        SystemOAuthEncrypter instance
+        SystemEncrypter instance
     """
-    global _oauth_encrypter
-    if _oauth_encrypter is None:
-        _oauth_encrypter = SystemOAuthEncrypter()
-    return _oauth_encrypter
+    global _encrypter
+    if _encrypter is None:
+        _encrypter = SystemEncrypter()
+    return _encrypter
 
 
 # Convenience functions for backward compatibility
-def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
+def encrypt_system_params(params: Mapping[str, Any]) -> str:
     """
-    Encrypt OAuth parameters using the global encrypter.
+    Encrypt parameters using the global encrypter.
 
     Args:
-        oauth_params: OAuth parameters dictionary
+        params: Parameters dictionary
 
     Returns:
         Base64-encoded encrypted string
     """
-    return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
+    return get_system_encrypter().encrypt_params(params)
 
 
-def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
+def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
     """
-    Decrypt OAuth parameters using the global encrypter.
+    Decrypt parameters using the global encrypter.
 
     Args:
         encrypted_data: Base64-encoded encrypted string
 
     Returns:
-        Decrypted OAuth parameters dictionary
+        Decrypted parameters dictionary
     """
-    return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
+    return get_system_encrypter().decrypt_params(encrypted_data)
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index ed3ed3e0de..94a2c0427b 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -105,7 +105,7 @@ class Article:
 
 
 def extract_using_readabilipy(html: str):
-    json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
+    json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False)
     article = Article(
         title=json_article.get("title") or "",
         author=json_article.get("byline") or "",
diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py
index c4b7d57449..45718cadb6 100644
--- a/api/core/tools/utils/workflow_configuration_sync.py
+++ b/api/core/tools/utils/workflow_configuration_sync.py
@@ -1,13 +1,12 @@
 from collections.abc import Mapping, Sequence
 from typing import Any
 
+from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
+from core.tools.errors import WorkflowToolHumanInputNotSupportedError
 from graphon.enums import BuiltinNodeTypes
 from graphon.nodes.base.entities import OutputVariableEntity
 from graphon.variables.input_entities import VariableEntity
 
-from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
-from core.tools.errors import WorkflowToolHumanInputNotSupportedError
-
 
 class WorkflowToolConfigurationUtils:
     @classmethod
@@ -17,10 +16,8 @@
         """
         nodes = graph.get("nodes", [])
         start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
-
         if not start_node:
             return []
-
         return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
 
     @classmethod
diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py
index a01004448a..5905fd919e 100644
--- a/api/core/tools/workflow_as_tool/provider.py
+++ b/api/core/tools/workflow_as_tool/provider.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 from collections.abc import Mapping
 
-from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from pydantic import Field
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -25,6 +24,7 @@ from core.tools.entities.tool_entities import (
 )
 from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
 from core.tools.workflow_as_tool.tool import WorkflowTool
 from extensions.ext_database import db
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from models.account import Account
 from models.model import App, AppMode
 from models.tools import WorkflowToolProvider
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index a17b7f108d..cd8c6352b5 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -5,8 +5,6 @@ import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import Any, cast
 
-from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
-from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from sqlalchemy import select
 
 from core.app.file_access import DatabaseFileAccessController
@@ -22,6 +20,8 @@ from core.tools.entities.tool_entities import (
 )
 from core.tools.errors import ToolInvokeError
 from core.workflow.file_reference import resolve_file_record_id
 from factories.file_factory import build_from_mapping
+from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
+from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from models import Account, Tenant
 from models.model import App, EndUser
 from models.utils.file_input_compat import build_file_from_stored_mapping
@@ -277,7 +277,7 @@ class WorkflowTool(Tool):
                 session.expunge(app)
             return app
 
-    def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
+    def _transform_args(self, tool_parameters: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, str | None]]]:
         """
         transform the tool parameters
 
@@ -323,7 +323,7 @@
 
         return parameters_result, files
 
-    def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
+    def _extract_files(self, outputs: dict[str, Any]) -> tuple[dict[str, Any], list[File]]:
         """
         extract files from the result
 
@@ -355,9 +355,12 @@
 
         return result, files
 
-    def _update_file_mapping(self, file_dict: dict):
+    def _update_file_mapping(self, file_dict: dict[str, Any]) -> dict[str, Any]:
         file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id"))
-        transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
+        transfer_method_value = file_dict.get("transfer_method")
+        if not isinstance(transfer_method_value, str):
+            raise ValueError("Workflow file mapping is missing a valid transfer_method")
+        transfer_method = FileTransferMethod.value_of(transfer_method_value)
         match transfer_method:
             case FileTransferMethod.TOOL_FILE:
                 file_dict["tool_file_id"] = file_id
diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py
index 61d1cd8540..24c1271488 100644
--- a/api/core/trigger/debug/event_selectors.py
+++ b/api/core/trigger/debug/event_selectors.py
@@ -8,7 +8,6 @@ from collections.abc import Mapping
 from datetime import datetime
 from typing import Any
 
-from graphon.entities.graph_config import NodeConfigDict
 from pydantic import BaseModel
 
 from core.plugin.entities.request import TriggerInvokeEventResponse
@@ -28,6 +27,7 @@ from core.trigger.debug.events import (
 )
 from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
 from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
 from extensions.ext_redis import redis_client
+from graphon.entities.graph_config import NodeConfigDict
 from libs.datetime_utils import ensure_naive_utc, naive_utc_now
 from libs.schedule_utils import calculate_next_run_at
 from models.model import App
diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_adapter.py
similarity index 67%
rename from api/core/workflow/human_input_compat.py
rename to api/core/workflow/human_input_adapter.py
index c95516a240..731ae2b858 100644
--- a/api/core/workflow/human_input_compat.py
+++ b/api/core/workflow/human_input_adapter.py
@@ -1,8 +1,8 @@
-"""Workflow-layer adapters for legacy human-input payload keys.
+"""Workflow-to-Graphon adapters for persisted node payloads.
 
-Stored workflow graphs and editor payloads may still use Dify-specific human
-input recipient keys. Normalize them here before handing configs to
-`graphon` so graph-owned models only see graph-neutral field names.
+Stored workflow graphs and editor payloads still contain a small set of
+Dify-owned field spellings and value shapes. Adapt them here before handing the
+payload to Graphon so Graphon-owned models only see current contracts.
 """
 
 from __future__ import annotations
@@ -14,12 +14,13 @@ from typing import Annotated, Any, ClassVar, Literal
 
 import bleach
 import markdown
+from markdown.extensions.tables import TableExtension
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
+
 from graphon.enums import BuiltinNodeTypes
 from graphon.nodes.base.variable_template_parser import VariableTemplateParser
 from graphon.runtime import VariablePool
 from graphon.variables.consts import SELECTORS_LENGTH
-from markdown.extensions.tables import TableExtension
-from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
 
 
 class DeliveryMethodType(enum.StrEnum):
@@ -184,7 +185,7 @@ def _copy_mapping(value: object) -> dict[str, Any] | None:
     return None
 
 
-def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
     normalized = _copy_mapping(node_data)
     if normalized is None:
         raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")
@@ -214,7 +215,7 @@
 
 def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
-    normalized = normalize_human_input_node_data_for_graph(node_data)
+    normalized = adapt_human_input_node_data_for_graph(node_data)
     raw_delivery_methods = normalized.get("delivery_methods")
     if not isinstance(raw_delivery_methods, list):
         return []
@@ -228,17 +229,20 @@ def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> b
     return False
 
 
-def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
     normalized = _copy_mapping(node_data)
     if normalized is None:
         raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")
-    if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT:
-        return normalized
-    return normalize_human_input_node_data_for_graph(normalized)
+    node_type = normalized.get("type")
+    if node_type == BuiltinNodeTypes.HUMAN_INPUT:
+        return adapt_human_input_node_data_for_graph(normalized)
+    if node_type == BuiltinNodeTypes.TOOL:
+        return _adapt_tool_node_data_for_graph(normalized)
+    return normalized
 
 
-def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
+def adapt_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
     normalized = _copy_mapping(node_config)
     if normalized is None:
         raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")
@@ -247,10 +251,95 @@
 
     if data_mapping is None:
         return normalized
 
-    normalized["data"] = normalize_node_data_for_graph(data_mapping)
+    normalized["data"] = adapt_node_data_for_graph(data_mapping)
     return normalized
+
+
+def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, Any]:
+    normalized = dict(node_data)
+
+    raw_tool_configurations = 
normalized.get("tool_configurations") + if not isinstance(raw_tool_configurations, Mapping): + return normalized + + existing_tool_parameters = normalized.get("tool_parameters") + normalized_tool_parameters = dict(existing_tool_parameters) if isinstance(existing_tool_parameters, Mapping) else {} + normalized_tool_configurations: dict[str, Any] = {} + found_legacy_tool_inputs = False + + for name, value in raw_tool_configurations.items(): + if not isinstance(value, Mapping): + normalized_tool_configurations[name] = value + continue + + selector_value = _extract_selector_configuration(value) + if selector_value is not None: + # Model/app selectors are dictionaries even when they come through the legacy tool configuration path. + # Move them to tool_parameters so graph validation does not flatten them as primitive constants. + found_legacy_tool_inputs = True + normalized_tool_parameters.setdefault(name, {"type": "constant", "value": selector_value}) + continue + + input_type = value.get("type") + input_value = value.get("value") + if input_type not in {"mixed", "variable", "constant"}: + normalized_tool_configurations[name] = value + continue + + found_legacy_tool_inputs = True + normalized_tool_parameters.setdefault(name, dict(value)) + + flattened_value = _flatten_legacy_tool_configuration_value( + input_type=input_type, + input_value=input_value, + ) + if flattened_value is not None: + normalized_tool_configurations[name] = flattened_value + + if not found_legacy_tool_inputs: + return normalized + + normalized["tool_parameters"] = normalized_tool_parameters + normalized["tool_configurations"] = normalized_tool_configurations + return normalized + + +def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: Any) -> str | int | float | bool | None: + if input_type in {"mixed", "constant"} and isinstance(input_value, str | int | float | bool): + return input_value + + if ( + input_type == "variable" + and isinstance(input_value, list) + and all(isinstance(item, str) for item in input_value) + ): + return "{{#" + ".".join(input_value) + "#}}" + + return None + + +def _extract_selector_configuration(value: Mapping[str, Any]) -> dict[str, Any] | None: + input_value = value.get("value") + if isinstance(input_value, Mapping) and _is_selector_configuration(input_value): + return dict(input_value) + + if _is_selector_configuration(value): + selector_value = dict(value) + selector_value.pop("type", None) + selector_value.pop("value", None) + return selector_value + + return None + + +def _is_selector_configuration(value: Mapping[str, Any]) -> bool: + return ( + isinstance(value.get("provider"), str) + and isinstance(value.get("model"), str) + and isinstance(value.get("model_type"), str) + ) or isinstance(value.get("app_id"), str) + + def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: normalized = dict(recipients) @@ -290,9 +379,9 @@ __all__ = [ "MemberRecipient", "WebAppDeliveryMethod", "_WebAppDeliveryConfig", + "adapt_human_input_node_data_for_graph", + "adapt_node_config_for_graph", + "adapt_node_data_for_graph", "is_human_input_webapp_enabled", - "normalize_human_input_node_data_for_graph", - "normalize_node_config_for_graph", - "normalize_node_data_for_graph", "parse_human_input_delivery_methods", ] diff --git a/api/core/workflow/human_input_forms.py b/api/core/workflow/human_input_forms.py index f124b321d4..b02f69ec33 100644 --- a/api/core/workflow/human_input_forms.py +++ b/api/core/workflow/human_input_forms.py @@ -12,20 +12,16 @@ from 
collections.abc import Sequence from sqlalchemy import select from sqlalchemy.orm import Session +from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token from extensions.ext_database import db from models.human_input import HumanInputFormRecipient, RecipientType -_FORM_TOKEN_PRIORITY = { - RecipientType.BACKSTAGE: 0, - RecipientType.CONSOLE: 1, - RecipientType.STANDALONE_WEB_APP: 2, -} - def load_form_tokens_by_form_id( form_ids: Sequence[str], *, session: Session | None = None, + surface: HumanInputSurface | None = None, ) -> dict[str, str]: """Load the preferred access token for each human input form.""" unique_form_ids = list(dict.fromkeys(form_ids)) @@ -33,23 +29,43 @@ def load_form_tokens_by_form_id( return {} if session is not None: - return _load_form_tokens_by_form_id(session, unique_form_ids) + return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface) with Session(bind=db.engine, expire_on_commit=False) as new_session: - return _load_form_tokens_by_form_id(new_session, unique_form_ids) + return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface) -def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]: - tokens_by_form_id: dict[str, tuple[int, str]] = {} +def _load_form_tokens_by_form_id( + session: Session, + form_ids: Sequence[str], + *, + surface: HumanInputSurface | None = None, +) -> dict[str, str]: + recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {} stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) for recipient in session.scalars(stmt): - priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type) - if priority is None or not recipient.access_token: + if not recipient.access_token: continue + recipients_by_form_id.setdefault(recipient.form_id, []).append( + (recipient.recipient_type, recipient.access_token) + ) - candidate = (priority, recipient.access_token) - current = tokens_by_form_id.get(recipient.form_id) - if current is None or candidate[0] < current[0]: - tokens_by_form_id[recipient.form_id] = candidate + tokens_by_form_id: dict[str, str] = {} + for form_id, recipients in recipients_by_form_id.items(): + token = _get_surface_form_token(recipients, surface=surface) + if token is not None: + tokens_by_form_id[form_id] = token + return tokens_by_form_id - return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()} + +def _get_surface_form_token( + recipients: Sequence[tuple[RecipientType, str]], + *, + surface: HumanInputSurface | None, +) -> str | None: + if surface == HumanInputSurface.SERVICE_API: + for recipient_type, token in recipients: + if recipient_type == RecipientType.STANDALONE_WEB_APP and token: + return token + + return get_preferred_form_token(recipients) diff --git a/api/core/workflow/human_input_policy.py b/api/core/workflow/human_input_policy.py new file mode 100644 index 0000000000..798eb8723f --- /dev/null +++ b/api/core/workflow/human_input_policy.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any + +from graphon.entities.pause_reason import PauseReasonType +from models.human_input import RecipientType + + +class HumanInputSurface(StrEnum): + SERVICE_API = "service_api" + CONSOLE = "console" + + +# Service API is intentionally narrower than other surfaces: app-token callers +# should only be able to act on end-user web forms, not internal console 
flows. +_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = { + HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}), + HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}), +} + +# A single HITL form can have multiple recipient records; this shared priority +# keeps every API surface consistent about which resume token to expose. +_RECIPIENT_TOKEN_PRIORITY: dict[RecipientType, int] = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, +} + + +def is_recipient_type_allowed_for_surface( + recipient_type: RecipientType | None, + surface: HumanInputSurface, +) -> bool: + if recipient_type is None: + return False + return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface] + + +def get_preferred_form_token( + recipients: Sequence[tuple[RecipientType, str]], +) -> str | None: + chosen_token: str | None = None + chosen_priority: int | None = None + for recipient_type, token in recipients: + priority = _RECIPIENT_TOKEN_PRIORITY.get(recipient_type) + if priority is None or not token: + continue + if chosen_priority is None or priority < chosen_priority: + chosen_priority = priority + chosen_token = token + return chosen_token + + +def enrich_human_input_pause_reasons( + reasons: Sequence[Mapping[str, Any]], + *, + form_tokens_by_form_id: Mapping[str, str], + expiration_times_by_form_id: Mapping[str, int], +) -> list[dict[str, Any]]: + enriched: list[dict[str, Any]] = [] + for reason in reasons: + updated = dict(reason) + if updated.get("type") == PauseReasonType.HUMAN_INPUT_REQUIRED: + form_id = updated.get("form_id") + if isinstance(form_id, str): + updated["form_token"] = form_tokens_by_form_id.get(form_id) + expiration_time = expiration_times_by_form_id.get(form_id) + if expiration_time is not None: + updated["expiration_time"] = expiration_time + enriched.append(updated) + return enriched diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index b04ac7da3d..895953a3c1 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -5,22 +5,6 @@ from dataclasses import dataclass from functools import lru_cache from typing import TYPE_CHECKING, Any, cast, final, override -from graphon.entities.base_node_data import BaseNodeData -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.enums import BuiltinNodeTypes, NodeType -from graphon.file.file_manager import file_manager -from graphon.graph.graph import NodeFactory -from graphon.model_runtime.memory import PromptMessageMemory -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.nodes.base.node import Node -from graphon.nodes.code.code_node import WorkflowCodeExecutor -from graphon.nodes.code.entities import CodeLanguage -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.nodes.document_extractor import UnstructuredApiConfig -from graphon.nodes.http_request import build_http_request_config -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from sqlalchemy import select from sqlalchemy.orm import Session @@ -31,12 +15,12 @@ from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, ) -from core.helper.ssrf_proxy import ssrf_proxy +from
core.helper.ssrf_proxy import graphon_ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.trigger.constants import TRIGGER_NODE_TYPES -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.node_runtime import ( DifyFileReferenceFactory, DifyHumanInputNodeRuntime, @@ -56,6 +40,22 @@ from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from models.model import Conversation if TYPE_CHECKING: @@ -121,6 +121,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + """Resolve the production node class for the requested type/version.""" node_mapping = get_node_type_classes_mapping().get(node_type) if not node_mapping: raise ValueError(f"No class mapping found for node type: {node_type}") @@ -297,7 +298,7 @@ class DifyNodeFactory(NodeFactory): ) self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH - self._http_request_http_client = ssrf_proxy + self._http_request_http_client = graphon_ssrf_proxy self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( self._dify_context, conversation_id_getter=self._conversation_id, @@ -364,10 +365,20 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + adapted_node_config = adapt_node_config_for_graph(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(adapted_node_config) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) + # Graph configs are initially validated against permissive shared node data. 
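The comment here describes a two-pass validation: a permissive shared model first, then the concrete schema owned by the resolved node class. A minimal sketch of that idea, assuming pydantic v2 (`PermissiveNodeData` and `SketchLLMNodeData` are invented names, not the real graphon models):

```python
from pydantic import BaseModel


class PermissiveNodeData(BaseModel, extra="allow"):
    """First pass: shared shape that accepts any node payload."""

    type: str
    version: str = "1"


class SketchLLMNodeData(PermissiveNodeData):
    """Second pass: concrete schema declared by the resolved node class."""

    model: str


raw = {"type": "llm", "version": "1", "model": "gpt-4o"}
loose = PermissiveNodeData.model_validate(raw)
# Re-validate with the node-specific model so the constructor gets typed data.
strict = SketchLLMNodeData.model_validate(loose.model_dump())
print(strict.model)
```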
+ # Re-validate using the resolved node class so workflow-local node schemas + # stay explicit and constructors receive the concrete typed payload. + resolved_node_data = self._validate_resolved_node_data(node_class, node_data) + config_for_node_init: BaseNodeData | dict[str, Any] + if isinstance(resolved_node_data, BaseNodeData): + config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True) + else: + config_for_node_init = resolved_node_data node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { BuiltinNodeTypes.CODE: lambda: { @@ -391,7 +402,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -405,7 +416,7 @@ class DifyNodeFactory(NodeFactory): }, BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=True, include_llm_file_saver=True, @@ -415,7 +426,7 @@ class DifyNodeFactory(NodeFactory): ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, - node_data=node_data, + node_data=resolved_node_data, wrap_model_instance=True, include_http_client=False, include_llm_file_saver=False, @@ -436,8 +447,8 @@ class DifyNodeFactory(NodeFactory): } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() return node_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=config_for_node_init, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, **node_init_kwargs, @@ -448,7 +459,10 @@ class DifyNodeFactory(NodeFactory): """ Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class. 
""" - return node_class.validate_node_data(node_data) + validate_node_data = getattr(node_class, "validate_node_data", None) + if callable(validate_node_data): + return cast("BaseNodeData", validate_node_data(node_data)) + return node_data @staticmethod def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: @@ -466,10 +480,7 @@ class DifyNodeFactory(NodeFactory): include_retriever_attachment_loader: bool, include_jinja2_template_renderer: bool, ) -> dict[str, object]: - validated_node_data = cast( - LLMCompatibleNodeData, - self._validate_resolved_node_data(node_class=node_class, node_data=node_data), - ) + validated_node_data = cast(LLMCompatibleNodeData, node_data) model_instance = self._build_model_instance_for_llm_node(validated_node_data) node_init_kwargs: dict[str, object] = { "credentials_provider": self._llm_credentials_provider, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 19cb3a7b0a..c1d3a856fb 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -2,38 +2,8 @@ from __future__ import annotations from collections.abc import Callable, Generator, Mapping, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities import LLMMode -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from graphon.model_runtime.entities.model_entities import AIModelEntity -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.nodes.llm.runtime_protocols import ( - PreparedLLMProtocol, - PromptMessageSerializerProtocol, - RetrieverAttachmentLoaderProtocol, -) -from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol -from graphon.nodes.runtime import ( - HumanInputFormStateProtocol, - HumanInputNodeRuntimeProtocol, - ToolNodeRuntimeProtocol, -) -from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError -from graphon.nodes.tool_runtime_entities import ( - ToolRuntimeHandle, - ToolRuntimeMessage, - ToolRuntimeParameter, -) from sqlalchemy import select from sqlalchemy.orm import Session @@ -60,11 +30,41 @@ from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories import file_factory +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities import LLMMode +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, 
+ PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) from models.dataset import SegmentAttachmentBinding from models.model import UploadFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from .human_input_compat import ( +from .human_input_adapter import ( BoundRecipient, DeliveryChannelConfig, DeliveryMethodType, @@ -76,13 +76,12 @@ from .human_input_compat import ( from .system_variables import SystemVariableKey, get_system_text if TYPE_CHECKING: + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage from graphon.file import File from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.tool.entities import ToolNodeData - from core.tools.__base.tool import Tool - from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage - _file_access_controller = DatabaseFileAccessController() @@ -174,6 +173,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: return self._model_instance.get_llm_num_tokens(prompt_messages) + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + def invoke_llm( self, *, @@ -191,6 +212,28 @@ class DifyPreparedLLM(PreparedLLMProtocol): stream=stream, ) + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... 
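The `@overload` stubs above use `Literal`-typed `stream` flags so callers get a precise return type from a single runtime implementation. A minimal, self-contained sketch of the same pattern (names are illustrative, not the real `invoke_llm` signature):

```python
from collections.abc import Generator
from typing import Literal, overload


@overload
def invoke(*, prompt: str, stream: Literal[False]) -> str: ...


@overload
def invoke(*, prompt: str, stream: Literal[True]) -> Generator[str, None, None]: ...


def invoke(*, prompt: str, stream: bool) -> str | Generator[str, None, None]:
    # One runtime implementation; the overloads exist only for type checkers,
    # which pick the return type from the Literal value passed for `stream`.
    if stream:
        return (token for token in prompt.split())
    return prompt


whole: str = invoke(prompt="hello world", stream=False)
chunks = list(invoke(prompt="hello world", stream=True))
```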
+ def invoke_llm_with_structured_output( self, *, @@ -458,11 +501,15 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): @staticmethod def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec: + tool_configurations = dict(node_data.tool_configurations) + tool_configurations.update( + {name: tool_input.model_dump(mode="python") for name, tool_input in node_data.tool_parameters.items()} + ) return _WorkflowToolRuntimeSpec( provider_type=CoreToolProviderType(node_data.provider_type.value), provider_id=node_data.provider_id, tool_name=node_data.tool_name, - tool_configurations=dict(node_data.tool_configurations), + tool_configurations=tool_configurations, credential_id=node_data.credential_id, ) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index bfd5536e4a..68a24e86b1 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,15 +3,13 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.workflow.system_variables import SystemVariableKey, get_system_text - from .entities import AgentNodeData from .exceptions import ( AgentInvocationError, @@ -36,18 +34,18 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: AgentNodeData, + *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, - *, strategy_resolver: AgentStrategyResolver, presentation_provider: AgentStrategyPresentationProvider, runtime_support: AgentRuntimeSupport, message_transformer: AgentMessageTransformer, ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index c52aad150b..51452c29a3 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index db74590ed7..f44681377d 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -3,6 +3,14 @@ from __future__ import annotations from collections.abc import 
Generator, Mapping from typing import Any, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import DatabaseFileAccessController +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from extensions.ext_database import db +from factories import file_factory from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata @@ -15,14 +23,6 @@ from graphon.node_events import ( StreamCompletedEvent, ) from graphon.variables.segments import ArrayFileSegment -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from extensions.ext_database import db -from factories import file_factory from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index be50edbc4d..a872774c98 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -4,8 +4,6 @@ import json from collections.abc import Sequence from typing import Any, cast -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.runtime import VariablePool from packaging.version import Version from pydantic import ValidationError from sqlalchemy import select @@ -21,6 +19,8 @@ from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolP from core.tools.tool_manager import ToolManager from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from models.model import Conversation from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index d9247b2593..f3006c4242 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,7 +1,12 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.datasource.datasource_manager import DatasourceManager +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment from graphon.enums import ( BuiltinNodeTypes, NodeExecutionType, @@ -12,13 +17,6 @@ from graphon.node_events import NodeRunResult, StreamCompletedEvent from graphon.nodes.base.node import Node from graphon.nodes.base.variable_template_parser import VariableTemplateParser -from 
core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities import DatasourceProviderType -from core.plugin.impl.exc import PluginDaemonClientSideError -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.system_variables import SystemVariableKey, get_system_segment - from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError @@ -37,13 +35,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: DatasourceNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index cad32f8d5b..28966f2392 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,9 +1,10 @@ from typing import Any, Literal, Union +from pydantic import BaseModel, field_validator +from pydantic_core.core_schema import ValidationInfo + from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo class DatasourceEntity(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index f8e239d250..260881e49c 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,22 +1,13 @@ from typing import Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel -from core.rag.entities.retrieval_settings import WeightedScoreConfig +from core.rag.entities import RerankingModelConfig, WeightedScoreConfig from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE - - -class RerankingModelConfig(BaseModel): - """ - Reranking Model Config. 
- """ - - reranking_provider_name: str - reranking_model_name: str +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class RetrievalSetting(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index bb72fe3881..9c1b7ab2c4 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,17 +2,15 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from graphon.entities.graph_config import NodeConfigDict -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node -from graphon.nodes.base.template import Template - from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( @@ -33,12 +31,18 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeIndexNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: - super().__init__(id, config, graph_init_params, graph_runtime_state) + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) self.index_processor = IndexProcessor() self.summary_index_service = SummaryIndex() diff --git a/api/core/workflow/nodes/knowledge_index/protocols.py b/api/core/workflow/nodes/knowledge_index/protocols.py index 6668f0c98e..d04e79c2a8 100644 --- a/api/core/workflow/nodes/knowledge_index/protocols.py +++ b/api/core/workflow/nodes/knowledge_index/protocols.py @@ -43,15 +43,20 @@ class IndexProcessorProtocol(Protocol): original_document_id: str, chunks: Mapping[str, Any], batch: Any, - summary_index_setting: dict | None = None, + summary_index_setting: dict[str, Any] | None = None, ) -> IndexingResultDict: ... def get_preview_output( - self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None + self, + chunks: Any, + dataset_id: str, + document_id: str, + chunk_structure: str, + summary_index_setting: dict[str, Any] | None, ) -> Preview: ... class SummaryIndexServiceProtocol(Protocol): def generate_and_vectorize_summary( - self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None + self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict[str, Any] | None = None ) -> None: ... 
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index f4bc3fb9d3..3825f526a2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,24 +1,15 @@ from typing import Literal +from pydantic import BaseModel, Field + +from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.llm.entities import ModelConfig, VisionConfig -from pydantic import BaseModel, Field - -from core.rag.entities import Condition, MetadataFilteringCondition, WeightedScoreConfig __all__ = ["Condition"] -class RerankingModelConfig(BaseModel): - """ - Reranking Model Config. - """ - - provider: str - model: str - - class MultipleRetrievalConfig(BaseModel): """ Multiple Retrieval Config. diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 13624b27b3..25f73e446d 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,8 +8,12 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDict from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, @@ -27,12 +31,6 @@ from graphon.variables import ( ) from graphon.variables.segments import ArrayObjectSegment -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file_reference import parse_file_reference - from .entities import ( Condition, KnowledgeRetrievalNodeData, @@ -51,6 +49,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _normalize_metadata_filter_scalar(value: object) -> str | int | float | None: + if value is None or isinstance(value, (str, float)): + return value + if isinstance(value, int) and not isinstance(value, bool): + return value + return str(value) + + +def _normalize_metadata_filter_sequence_item(value: object) -> str: + return value if isinstance(value, str) else str(value) + + class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL @@ -60,13 +70,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, - id: str, - config: NodeConfigDict, + node_id: str, + config: KnowledgeRetrievalNodeData, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - ): + ) -> None: super().__init__( - id=id, + node_id=node_id, 
config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, @@ -283,18 +294,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD resolved_conditions: list[Condition] = [] for cond in conditions.conditions or []: value = cond.value + resolved_value: str | Sequence[str] | int | float | None if isinstance(value, str): segment_group = variable_pool.convert_template(value) if len(segment_group.value) == 1: - resolved_value = segment_group.value[0].to_object() + resolved_value = _normalize_metadata_filter_scalar(segment_group.value[0].to_object()) else: resolved_value = segment_group.text elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): - resolved_values = [] - for v in value: # type: ignore + resolved_values: list[str] = [] + for v in value: segment_group = variable_pool.convert_template(v) if len(segment_group.value) == 1: - resolved_values.append(segment_group.value[0].to_object()) + resolved_values.append( + _normalize_metadata_filter_sequence_item(segment_group.value[0].to_object()) + ) else: resolved_values.append(segment_group.text) resolved_value = resolved_values diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index 39e2008a2c..ea45dcf5c2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Protocol -from graphon.model_runtime.entities import LLMUsage -from graphon.nodes.llm.entities import ModelConfig from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index bf5be2379a..23ed2cd408 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index e50de11bb9..c848a86255 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,13 +1,12 @@ from collections.abc import Mapping from typing import Any +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node -from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID - from .entities 
import TriggerEventNodeData diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index f14ca893c9..683c8d420f 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ -from typing import Literal, Union +from typing import Any, Literal, Union -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): @@ -16,7 +16,7 @@ class TriggerScheduleNodeData(BaseNodeData): mode: str = Field(default="visual", description="Schedule mode: visual or cron") frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly") cron_expression: str | None = Field(default=None, description="Cron expression for cron mode") - visual_config: dict | None = Field(default=None, description="Visual configuration details") + visual_config: dict[str, Any] | None = Field(default=None, description="Visual configuration details") timezone: str = Field(default="UTC", description="Timezone for schedule execution") diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index a9753ab387..b46cc76a6e 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,10 @@ from collections.abc import Mapping -from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult -from graphon.nodes.base.node import Node - from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node from .entities import TriggerScheduleNodeData diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index a30f877e4b..b261039448 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum -from graphon.entities.base_node_data import BaseNodeData -from graphon.enums import NodeType -from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, field_validator from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType +from graphon.variables.types import SegmentType _WEBHOOK_HEADER_ALLOWED_TYPES: frozenset[SegmentType] = frozenset((SegmentType.STRING,)) diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 8c866aea81..13c4f05bfd 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,6 +2,10 @@ import logging from collections.abc import Mapping from typing import Any +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.file_reference import 
resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment_with_type from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus from graphon.file import FileTransferMethod from graphon.node_events import NodeRunResult @@ -10,11 +14,6 @@ from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.variables.types import SegmentType from graphon.variables.variables import FileVariable -from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from core.workflow.file_reference import resolve_file_record_id -from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID -from factories.variable_factory import build_segment_with_type - from .entities import ContentType, WebhookData logger = logging.getLogger(__name__) @@ -75,7 +74,7 @@ class TriggerWebhookNode(Node[WebhookData]): outputs=outputs, ) - def generate_file_var(self, param_name: str, file: dict): + def generate_file_var(self, param_name: str, file: dict[str, Any]): file_id = resolve_file_record_id(file.get("reference") or file.get("related_id")) transfer_method_value = file.get("transfer_method") if transfer_method_value: @@ -147,7 +146,7 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = str(webhook_data.get("body", {}).get("raw", "")) continue elif self.node_data.content_type == ContentType.BINARY: - raw_data: dict = webhook_data.get("body", {}).get("raw", {}) + raw_data: dict[str, Any] = webhook_data.get("body", {}).get("raw", {}) file_var = self.generate_file_var(param_name, raw_data) if file_var: outputs[param_name] = file_var diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py index d51cfadd09..b4ffb37549 100644 --- a/api/core/workflow/template_rendering.py +++ b/api/core/workflow/template_rendering.py @@ -3,11 +3,10 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from graphon.nodes.code.entities import CodeLanguage from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor - class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index f0a5fbb400..4e2f603e5b 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -3,20 +3,6 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import Any, TypedDict -from graphon.entities import GraphInitParams -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.errors import WorkflowNodeRunFailedError -from graphon.file import File -from graphon.graph import Graph -from graphon.graph_engine import GraphEngine, GraphEngineConfig -from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel -from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.base.node import Node -from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from graphon.variable_loader import DUMMY_VARIABLE_LOADER, 
VariableLoader, load_into_variable_pool - from configs import dify_config from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError @@ -40,6 +26,19 @@ from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow logger = logging.getLogger(__name__) diff --git a/api/dev/generate_swagger_specs.py b/api/dev/generate_swagger_specs.py new file mode 100644 index 0000000000..7e9688bfb4 --- /dev/null +++ b/api/dev/generate_swagger_specs.py @@ -0,0 +1,172 @@ +"""Generate Flask-RESTX Swagger 2.0 specs without booting the full backend. + +This helper intentionally avoids `app_factory.create_app()`. The normal backend +startup eagerly initializes database, Redis, Celery, and storage extensions, +which is unnecessary when the goal is only to serialize the Flask-RESTX +`/swagger.json` documents. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +from dataclasses import dataclass +from pathlib import Path + +from flask import Flask +from flask_restx.swagger import Swagger + +logger = logging.getLogger(__name__) + +API_ROOT = Path(__file__).resolve().parents[1] +if str(API_ROOT) not in sys.path: + sys.path.insert(0, str(API_ROOT)) + + +@dataclass(frozen=True) +class SpecTarget: + route: str + filename: str + + +SPEC_TARGETS: tuple[SpecTarget, ...] = ( + SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json"), + SpecTarget(route="/api/swagger.json", filename="web-swagger.json"), + SpecTarget(route="/v1/swagger.json", filename="service-swagger.json"), +) + +_ORIGINAL_REGISTER_MODEL = Swagger.register_model +_ORIGINAL_REGISTER_FIELD = Swagger.register_field + + +def _apply_runtime_defaults() -> None: + """Force the small config surface required for Swagger generation.""" + + os.environ.setdefault("SECRET_KEY", "spec-export") + os.environ.setdefault("STORAGE_TYPE", "local") + os.environ.setdefault("STORAGE_LOCAL_PATH", "/tmp/dify-storage") + os.environ.setdefault("SWAGGER_UI_ENABLED", "true") + + from configs import dify_config + + dify_config.SECRET_KEY = os.environ["SECRET_KEY"] + dify_config.STORAGE_TYPE = "local" + dify_config.STORAGE_LOCAL_PATH = os.environ["STORAGE_LOCAL_PATH"] + dify_config.SWAGGER_UI_ENABLED = os.environ["SWAGGER_UI_ENABLED"].lower() == "true" + + +def _patch_swagger_for_inline_nested_dicts() -> None: + """Teach Flask-RESTX Swagger generation to tolerate inline nested field maps. 
+ + Some existing controllers use `fields.Nested({...})` with a raw field mapping + instead of a named `api.model(...)`. Flask-RESTX crashes on those anonymous + dicts during schema registration, so this helper upgrades them into temporary + named models at export time. + """ + + if getattr(Swagger, "_dify_inline_nested_dict_patch", False): + return + + def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object: + anonymous_models = getattr(self, "_anonymous_inline_models", None) + if anonymous_models is None: + anonymous_models = {} + self._anonymous_inline_models = anonymous_models + + anonymous_name = anonymous_models.get(id(nested_fields)) + if anonymous_name is None: + anonymous_name = f"_AnonymousInlineModel{len(anonymous_models) + 1}" + anonymous_models[id(nested_fields)] = anonymous_name + self.api.model(anonymous_name, nested_fields) + + return self.api.models[anonymous_name] + + def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]: + if isinstance(model, dict): + model = get_or_create_inline_model(self, model) + + return _ORIGINAL_REGISTER_MODEL(self, model) + + def register_field_with_inline_dict_support(self: Swagger, field: object) -> None: + nested = getattr(field, "nested", None) + if isinstance(nested, dict): + field.model = get_or_create_inline_model(self, nested) # type: ignore + + _ORIGINAL_REGISTER_FIELD(self, field) + + Swagger.register_model = register_model_with_inline_dict_support + Swagger.register_field = register_field_with_inline_dict_support + Swagger._dify_inline_nested_dict_patch = True + + +def create_spec_app() -> Flask: + """Build a minimal Flask app that only mounts the Swagger-producing blueprints.""" + + _apply_runtime_defaults() + _patch_swagger_for_inline_nested_dicts() + + app = Flask(__name__) + + from controllers.console import bp as console_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + app.register_blueprint(console_bp) + app.register_blueprint(web_bp) + app.register_blueprint(service_api_bp) + + return app + + +def generate_specs(output_dir: Path) -> list[Path]: + """Write all Swagger specs to `output_dir` and return the written paths.""" + + output_dir.mkdir(parents=True, exist_ok=True) + + app = create_spec_app() + client = app.test_client() + + written_paths: list[Path] = [] + for target in SPEC_TARGETS: + response = client.get(target.route) + if response.status_code != 200: + raise RuntimeError(f"failed to fetch {target.route}: {response.status_code}") + + payload = response.get_json() + if not isinstance(payload, dict): + raise RuntimeError(f"unexpected response payload for {target.route}") + + output_path = output_dir / target.filename + output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + written_paths.append(output_path) + + return written_paths + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "-o", + "--output-dir", + type=Path, + default=Path("openapi"), + help="Directory where the Swagger JSON files will be written.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + written_paths = generate_specs(args.output_dir) + + for path in written_paths: + logger.debug(path) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 6b904b7d0d..fc118df5bc 100755 --- 
a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_publisher,trigger_refresh_executor,retention,workflow_based_app_execution" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" @@ -119,14 +119,16 @@ elif [[ "${MODE}" == "job" ]]; then else if [[ "${DEBUG}" == "true" ]]; then - exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug + export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0} + export PORT=${DIFY_PORT:-5001} + exec python -m app else exec gunicorn \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --workers ${SERVER_WORKER_AMOUNT:-1} \ - --worker-class ${SERVER_WORKER_CLASS:-gevent} \ + --worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \ --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \ --timeout ${GUNICORN_TIMEOUT:-200} \ - app:app + app:socketio_app fi fi diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py index 5a8d0ee6f4..dff558988c 100644 --- a/api/enterprise/telemetry/draft_trace.py +++ b/api/enterprise/telemetry/draft_trace.py @@ -3,10 +3,9 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any -from graphon.enums import WorkflowNodeExecutionMetadataKey - from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName from core.telemetry import emit as telemetry_emit +from graphon.enums import WorkflowNodeExecutionMetadataKey from models.workflow import WorkflowNodeExecutionModel diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py index 9cda0bf90a..c564ace584 100644 --- a/api/enterprise/telemetry/metric_handler.py +++ b/api/enterprise/telemetry/metric_handler.py @@ -329,7 +329,7 @@ class EnterpriseMetricHandler: return include_content = exporter.include_content - attrs: dict = { + attrs: dict[str, Any] = { "dify.message.id": payload.get("message_id"), "dify.tenant_id": envelope.tenant_id, "dify.event.id": envelope.event_id, diff --git 
a/api/enums/quota_type.py b/api/enums/quota_type.py index 9f511b88ef..a10ac21f69 100644 --- a/api/enums/quota_type.py +++ b/api/enums/quota_type.py @@ -1,56 +1,17 @@ -import logging -from dataclasses import dataclass from enum import StrEnum, auto -logger = logging.getLogger(__name__) - - -@dataclass -class QuotaCharge: - """ - Result of a quota consumption operation. - - Attributes: - success: Whether the quota charge succeeded - charge_id: UUID for refund, or None if failed/disabled - """ - - success: bool - charge_id: str | None - _quota_type: "QuotaType" - - def refund(self) -> None: - """ - Refund this quota charge. - - Safe to call even if charge failed or was disabled. - This method guarantees no exceptions will be raised. - """ - if self.charge_id: - self._quota_type.refund(self.charge_id) - logger.info("Refunded quota for %s with charge_id: %s", self._quota_type.value, self.charge_id) - class QuotaType(StrEnum): """ Supported quota types for tenant feature usage. - - Add additional types here whenever new billable features become available. """ - # Trigger execution quota TRIGGER = auto() - - # Workflow execution quota WORKFLOW = auto() - UNLIMITED = auto() @property def billing_key(self) -> str: - """ - Get the billing key for the feature. - """ match self: case QuotaType.TRIGGER: return "trigger_event" @@ -58,152 +19,3 @@ class QuotaType(StrEnum): return "api_rate_limit" case _: raise ValueError(f"Invalid quota type: {self}") - - def consume(self, tenant_id: str, amount: int = 1) -> QuotaCharge: - """ - Consume quota for the feature. - - Args: - tenant_id: The tenant identifier - amount: Amount to consume (default: 1) - - Returns: - QuotaCharge with success status and charge_id for refund - - Raises: - QuotaExceededError: When quota is insufficient - """ - from configs import dify_config - from services.billing_service import BillingService - from services.errors.app import QuotaExceededError - - if not dify_config.BILLING_ENABLED: - logger.debug("Billing disabled, allowing request for %s", tenant_id) - return QuotaCharge(success=True, charge_id=None, _quota_type=self) - - logger.info("Consuming %d %s quota for tenant %s", amount, self.value, tenant_id) - - if amount <= 0: - raise ValueError("Amount to consume must be greater than 0") - - try: - response = BillingService.update_tenant_feature_plan_usage(tenant_id, self.billing_key, delta=amount) - - if response.get("result") != "success": - logger.warning( - "Failed to consume quota for %s, feature %s details: %s", - tenant_id, - self.value, - response.get("detail"), - ) - raise QuotaExceededError(feature=self.value, tenant_id=tenant_id, required=amount) - - charge_id = response.get("history_id") - logger.debug( - "Successfully consumed %d %s quota for tenant %s, charge_id: %s", - amount, - self.value, - tenant_id, - charge_id, - ) - return QuotaCharge(success=True, charge_id=charge_id, _quota_type=self) - - except QuotaExceededError: - raise - except Exception: - # fail-safe: allow request on billing errors - logger.exception("Failed to consume quota for %s, feature %s", tenant_id, self.value) - return unlimited() - - def check(self, tenant_id: str, amount: int = 1) -> bool: - """ - Check if tenant has sufficient quota without consuming. 
- - Args: - tenant_id: The tenant identifier - amount: Amount to check (default: 1) - - Returns: - True if quota is sufficient, False otherwise - """ - from configs import dify_config - - if not dify_config.BILLING_ENABLED: - return True - - if amount <= 0: - raise ValueError("Amount to check must be greater than 0") - - try: - remaining = self.get_remaining(tenant_id) - return remaining >= amount if remaining != -1 else True - except Exception: - logger.exception("Failed to check quota for %s, feature %s", tenant_id, self.value) - # fail-safe: allow request on billing errors - return True - - def refund(self, charge_id: str) -> None: - """ - Refund quota using charge_id from consume(). - - This method guarantees no exceptions will be raised. - All errors are logged but silently handled. - - Args: - charge_id: The UUID returned from consume() - """ - try: - from configs import dify_config - from services.billing_service import BillingService - - if not dify_config.BILLING_ENABLED: - return - - if not charge_id: - logger.warning("Cannot refund: charge_id is empty") - return - - logger.info("Refunding %s quota with charge_id: %s", self.value, charge_id) - - response = BillingService.refund_tenant_feature_plan_usage(charge_id) - if response.get("result") == "success": - logger.debug("Successfully refunded %s quota, charge_id: %s", self.value, charge_id) - else: - logger.warning("Refund failed for charge_id: %s", charge_id) - - except Exception: - # Catch ALL exceptions - refund must never fail - logger.exception("Failed to refund quota for charge_id: %s", charge_id) - # Don't raise - refund is best-effort and must be silent - - def get_remaining(self, tenant_id: str) -> int: - """ - Get remaining quota for the tenant. - - Args: - tenant_id: The tenant identifier - - Returns: - Remaining quota amount - """ - from services.billing_service import BillingService - - try: - usage_info = BillingService.get_tenant_feature_plan_usage(tenant_id, self.billing_key) - # Assuming the API returns a dict with 'remaining' or 'limit' and 'used' - if isinstance(usage_info, dict): - return usage_info.get("remaining", 0) - # If it returns a simple number, treat it as remaining - return int(usage_info) if usage_info else 0 - except Exception: - logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, self.value) - return -1 - - -def unlimited() -> QuotaCharge: - """ - Return a quota charge for unlimited quota. - - This is useful for features that are not subject to quota limits, such as the UNLIMITED quota type. 
- """ - return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index b7e7a6e60f..0c535a1c5b 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -6,9 +6,9 @@ import click from sqlalchemy import select from werkzeug.exceptions import NotFound +from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.document_index_event import document_index_created -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document from models.enums import IndexingStatus @@ -22,24 +22,25 @@ def handle(sender, **kwargs): document_ids = kwargs.get("document_ids", []) documents = [] start_at = time.perf_counter() - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) + with session_factory.create_session() as session: + for document_id in document_ids: + logger.info(click.style(f"Start process document: {document_id}", fg="green")) - document = db.session.scalar( - select(Document).where( - Document.id == document_id, - Document.dataset_id == dataset_id, + document = session.scalar( + select(Document).where( + Document.id == document_id, + Document.dataset_id == dataset_id, + ) ) - ) - if not document: - raise NotFound("Document not found") + if not document: + raise NotFound("Document not found") - document.indexing_status = IndexingStatus.PARSING - document.processing_started_at = naive_utc_now() - documents.append(document) - db.session.add(document) - db.session.commit() + document.indexing_status = IndexingStatus.PARSING + document.processing_started_at = naive_utc_now() + documents.append(document) + session.add(document) + session.commit() with contextlib.suppress(Exception): try: diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py index 57412cc4ad..38e102d5fd 100644 --- a/api/events/event_handlers/create_installed_app_when_app_created.py +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.model import InstalledApp @@ -12,5 +12,6 @@ def handle(sender, **kwargs): app_id=app.id, app_owner_tenant_id=app.tenant_id, ) - db.session.add(installed_app) - db.session.commit() + with session_factory.create_session() as session: + session.add(installed_app) + session.commit() diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 84be592b1a..5e2a456dce 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,5 @@ +from core.db.session_factory import session_factory from events.app_event import app_was_created -from extensions.ext_database import db from models.enums import CustomizeTokenStrategy from models.model import Site @@ -22,6 +22,6 @@ def handle(sender, **kwargs): created_by=app.created_by, updated_by=app.updated_by, ) - - db.session.add(site) - db.session.commit() + with session_factory.create_session() as session: + 
session.add(site) + session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 7bd8e88231..f1196445ed 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,12 +1,12 @@ import logging -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.tool.entities import ToolEntity - from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager +from core.workflow.human_input_adapter import adapt_node_config_for_graph from events.app_event import app_draft_workflow_was_synced +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) @@ -20,7 +20,8 @@ def handle(sender, **kwargs): for node_data in synced_draft_workflow.graph_dict.get("nodes", []): if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: - tool_entity = ToolEntity.model_validate(node_data["data"]) + adapted_node_data = adapt_node_config_for_graph(node_data) + tool_entity = ToolEntity.model_validate(adapted_node_data["data"]) provider_type = ToolProviderType(tool_entity.provider_type.value) tool_runtime = ToolManager.get_tool_runtime( provider_type=provider_type, diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 86b5b2bbf0..6769b94cde 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,11 +1,11 @@ from typing import cast -from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db +from graphon.nodes import BuiltinNodeTypes from models.dataset import AppDatasetJoin from models.workflow import Workflow diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 86b0550187..340f514fcc 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -9,6 +9,7 @@ from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import normalize_redis_key_prefix class _CelerySentinelKwargsDict(TypedDict): @@ -16,9 +17,10 @@ class _CelerySentinelKwargsDict(TypedDict): password: str | None -class CelerySentinelTransportDict(TypedDict): +class CelerySentinelTransportDict(TypedDict, total=False): master_name: str | None sentinel_kwargs: _CelerySentinelKwargsDict + global_keyprefix: str class CelerySSLOptionsDict(TypedDict): @@ -61,15 +63,31 @@ def get_celery_ssl_options() -> CelerySSLOptionsDict | None: def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]: """Get broker transport options (e.g. 
Redis Sentinel) for Celery connections.""" + transport_options: CelerySentinelTransportDict | dict[str, Any] if dify_config.CELERY_USE_SENTINEL: - return CelerySentinelTransportDict( + transport_options = CelerySentinelTransportDict( master_name=dify_config.CELERY_SENTINEL_MASTER_NAME, sentinel_kwargs=_CelerySentinelKwargsDict( socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, password=dify_config.CELERY_SENTINEL_PASSWORD, ), ) - return {} + else: + transport_options = {} + + global_keyprefix = get_celery_redis_global_keyprefix() + if global_keyprefix: + transport_options["global_keyprefix"] = global_keyprefix + + return transport_options + + +def get_celery_redis_global_keyprefix() -> str | None: + """Return the Redis transport prefix for Celery when namespace isolation is enabled.""" + normalized_prefix = normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + if not normalized_prefix: + return None + return f"{normalized_prefix}:" def init_app(app: DifyApp) -> Celery: diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index b9e592cadb..9f7f73765e 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -3,7 +3,7 @@ import logging import ssl from collections.abc import Callable from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import Any, Union, cast import redis from redis import RedisError @@ -14,20 +14,30 @@ from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.retry import Retry from redis.sentinel import Sentinel +from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp +from extensions.redis_names import ( + normalize_redis_key_prefix, + serialize_redis_name, + serialize_redis_name_arg, + serialize_redis_name_args, +) from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel -if TYPE_CHECKING: - from redis.lock import Lock - logger = logging.getLogger(__name__) +_normalize_redis_key_prefix = normalize_redis_key_prefix +_serialize_redis_name = serialize_redis_name +_serialize_redis_name_arg = serialize_redis_name_arg +_serialize_redis_name_args = serialize_redis_name_args + + class RedisClientWrapper: """ A wrapper class for the Redis client that addresses the issue where the global @@ -58,74 +68,189 @@ class RedisClientWrapper: if self._client is None: self._client = client - if TYPE_CHECKING: - # Type hints for IDE support and static analysis - # These are not executed at runtime but provide type information - def get(self, name: str | bytes) -> Any: ... - - def set( - self, - name: str | bytes, - value: Any, - ex: int | None = None, - px: int | None = None, - nx: bool = False, - xx: bool = False, - keepttl: bool = False, - get: bool = False, - exat: int | None = None, - pxat: int | None = None, - ) -> Any: ... - - def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: ... - def setnx(self, name: str | bytes, value: Any) -> Any: ... - def delete(self, *names: str | bytes) -> Any: ... - def incr(self, name: str | bytes, amount: int = 1) -> Any: ... 
- def expire( - self, - name: str | bytes, - time: int | timedelta, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def lock( - self, - name: str, - timeout: float | None = None, - sleep: float = 0.1, - blocking: bool = True, - blocking_timeout: float | None = None, - thread_local: bool = True, - ) -> Lock: ... - def zadd( - self, - name: str | bytes, - mapping: dict[str | bytes | int | float, float | int | str | bytes], - nx: bool = False, - xx: bool = False, - ch: bool = False, - incr: bool = False, - gt: bool = False, - lt: bool = False, - ) -> Any: ... - def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ... - def zcard(self, name: str | bytes) -> Any: ... - def getdel(self, name: str | bytes) -> Any: ... - def pubsub(self) -> PubSub: ... - def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: ... - - def __getattr__(self, item: str) -> Any: + def _require_client(self) -> redis.Redis | RedisCluster: if self._client is None: raise RuntimeError("Redis client is not initialized. Call init_app first.") - return getattr(self._client, item) + return self._client + + def _get_prefix(self) -> str: + return dify_config.REDIS_KEY_PREFIX + + def get(self, name: str | bytes) -> Any: + return self._require_client().get(_serialize_redis_name_arg(name, self._get_prefix())) + + def set( + self, + name: str | bytes, + value: Any, + ex: int | None = None, + px: int | None = None, + nx: bool = False, + xx: bool = False, + keepttl: bool = False, + get: bool = False, + exat: int | None = None, + pxat: int | None = None, + ) -> Any: + return self._require_client().set( + _serialize_redis_name_arg(name, self._get_prefix()), + value, + ex=ex, + px=px, + nx=nx, + xx=xx, + keepttl=keepttl, + get=get, + exat=exat, + pxat=pxat, + ) + + def setex(self, name: str | bytes, time: int | timedelta, value: Any) -> Any: + return self._require_client().setex(_serialize_redis_name_arg(name, self._get_prefix()), time, value) + + def setnx(self, name: str | bytes, value: Any) -> Any: + return self._require_client().setnx(_serialize_redis_name_arg(name, self._get_prefix()), value) + + def delete(self, *names: str | bytes) -> Any: + return self._require_client().delete(*_serialize_redis_name_args(names, self._get_prefix())) + + def incr(self, name: str | bytes, amount: int = 1) -> Any: + return self._require_client().incr(_serialize_redis_name_arg(name, self._get_prefix()), amount) + + def expire( + self, + name: str | bytes, + time: int | timedelta, + nx: bool = False, + xx: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().expire( + _serialize_redis_name_arg(name, self._get_prefix()), + time, + nx=nx, + xx=xx, + gt=gt, + lt=lt, + ) + + def exists(self, *names: str | bytes) -> Any: + return self._require_client().exists(*_serialize_redis_name_args(names, self._get_prefix())) + + def ttl(self, name: str | bytes) -> Any: + return self._require_client().ttl(_serialize_redis_name_arg(name, self._get_prefix())) + + def getdel(self, name: str | bytes) -> Any: + return self._require_client().getdel(_serialize_redis_name_arg(name, self._get_prefix())) + + def lock( + self, + name: str, + timeout: float | None = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: float | None = None, + thread_local: bool = True, + ) -> Any: + return self._require_client().lock( + _serialize_redis_name(name, self._get_prefix()), + timeout=timeout, + sleep=sleep, + 
blocking=blocking, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def hset(self, name: str | bytes, *args: Any, **kwargs: Any) -> Any: + return self._require_client().hset(_serialize_redis_name_arg(name, self._get_prefix()), *args, **kwargs) + + def hgetall(self, name: str | bytes) -> Any: + return self._require_client().hgetall(_serialize_redis_name_arg(name, self._get_prefix())) + + def hdel(self, name: str | bytes, *keys: str | bytes) -> Any: + return self._require_client().hdel(_serialize_redis_name_arg(name, self._get_prefix()), *keys) + + def hlen(self, name: str | bytes) -> Any: + return self._require_client().hlen(_serialize_redis_name_arg(name, self._get_prefix())) + + def zadd( + self, + name: str | bytes, + mapping: dict[str | bytes | int | float, float | int | str | bytes], + nx: bool = False, + xx: bool = False, + ch: bool = False, + incr: bool = False, + gt: bool = False, + lt: bool = False, + ) -> Any: + return self._require_client().zadd( + _serialize_redis_name_arg(name, self._get_prefix()), + cast(Any, mapping), + nx=nx, + xx=xx, + ch=ch, + incr=incr, + gt=gt, + lt=lt, + ) + + def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: + return self._require_client().zremrangebyscore(_serialize_redis_name_arg(name, self._get_prefix()), min, max) + + def zcard(self, name: str | bytes) -> Any: + return self._require_client().zcard(_serialize_redis_name_arg(name, self._get_prefix())) + + def pubsub(self) -> PubSub: + return self._require_client().pubsub() + + def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> Any: + return self._require_client().pipeline(transaction=transaction, shard_hint=shard_hint) + + def __getattr__(self, item: str) -> Any: + return getattr(self._require_client(), item) redis_client: RedisClientWrapper = RedisClientWrapper() _pubsub_redis_client: redis.Redis | RedisCluster | None = None +class RedisSSLParamsDict(TypedDict): + ssl_cert_reqs: int + ssl_ca_certs: str | None + ssl_certfile: str | None + ssl_keyfile: str | None + + +class RedisHealthParamsDict(TypedDict): + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + health_check_interval: int | None + + +class RedisClusterHealthParamsDict(TypedDict): + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + + +class RedisBaseParamsDict(TypedDict): + username: str | None + password: str | None + db: int + encoding: str + encoding_errors: str + decode_responses: bool + protocol: int + cache_config: CacheConfig | None + retry: Retry + socket_timeout: float | None + socket_connect_timeout: float | None + health_check_interval: int | None + + def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]: """Get SSL configuration for Redis connection.""" if not dify_config.REDIS_USE_SSL: @@ -171,17 +296,17 @@ def _get_retry_policy() -> Retry: ) -def _get_connection_health_params() -> dict[str, Any]: +def _get_connection_health_params() -> RedisHealthParamsDict: """Get connection health and retry parameters for standalone and Sentinel Redis clients.""" - return { - "retry": _get_retry_policy(), - "socket_timeout": dify_config.REDIS_SOCKET_TIMEOUT, - "socket_connect_timeout": dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, - "health_check_interval": dify_config.REDIS_HEALTH_CHECK_INTERVAL, - } + return RedisHealthParamsDict( + retry=_get_retry_policy(), + socket_timeout=dify_config.REDIS_SOCKET_TIMEOUT, + 
socket_connect_timeout=dify_config.REDIS_SOCKET_CONNECT_TIMEOUT, + health_check_interval=dify_config.REDIS_HEALTH_CHECK_INTERVAL, + ) -def _get_cluster_connection_health_params() -> dict[str, Any]: +def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict: """Get retry and timeout parameters for Redis Cluster clients. RedisCluster does not support ``health_check_interval`` as a constructor @@ -189,26 +314,31 @@ def _get_cluster_connection_health_params() -> dict[str, Any]: here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout`` are passed through. """ - params = _get_connection_health_params() - return {k: v for k, v in params.items() if k != "health_check_interval"} - - -def _get_base_redis_params() -> dict[str, Any]: - """Get base Redis connection parameters including retry and health policy.""" - return { - "username": dify_config.REDIS_USERNAME, - "password": dify_config.REDIS_PASSWORD or None, - "db": dify_config.REDIS_DB, - "encoding": "utf-8", - "encoding_errors": "strict", - "decode_responses": False, - "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, - "cache_config": _get_cache_configuration(), - **_get_connection_health_params(), + health_params = _get_connection_health_params() + result: RedisClusterHealthParamsDict = { + "retry": health_params["retry"], + "socket_timeout": health_params["socket_timeout"], + "socket_connect_timeout": health_params["socket_connect_timeout"], } + return result -def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: +def _get_base_redis_params() -> RedisBaseParamsDict: + """Get base Redis connection parameters including retry and health policy.""" + return RedisBaseParamsDict( + username=dify_config.REDIS_USERNAME, + password=dify_config.REDIS_PASSWORD or None, + db=dify_config.REDIS_DB, + encoding="utf-8", + encoding_errors="strict", + decode_responses=False, + protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, + cache_config=_get_cache_configuration(), + **_get_connection_health_params(), + ) + + +def _create_sentinel_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create Redis client using Sentinel configuration.""" if not dify_config.REDIS_SENTINELS: raise ValueError("REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True") @@ -232,7 +362,8 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, sentinel_kwargs=sentinel_kwargs, ) - master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + params: dict[str, Any] = {**redis_params} + master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **params) return master @@ -259,18 +390,16 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: return cluster -def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis, RedisCluster]: +def _create_standalone_client(redis_params: RedisBaseParamsDict) -> Union[redis.Redis, RedisCluster]: """Create standalone Redis client.""" connection_class, ssl_kwargs = _get_ssl_configuration() - params = {**redis_params} - params.update( - { - "host": dify_config.REDIS_HOST, - "port": dify_config.REDIS_PORT, - "connection_class": connection_class, - } - ) + params: dict[str, Any] = { + **redis_params, + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } if dify_config.REDIS_MAX_CONNECTIONS: params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS @@ -293,8 
+422,8 @@ def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | kwargs["max_connections"] = max_conns return RedisCluster.from_url(pubsub_url, **kwargs) - health_params = _get_connection_health_params() - kwargs = {**health_params} + standalone_health_params: dict[str, Any] = dict(_get_connection_health_params()) + kwargs = {**standalone_health_params} if max_conns: kwargs["max_connections"] = max_conns return redis.Redis.from_url(pubsub_url, **kwargs) diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 5cc58f27c4..69d1f1ab07 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,11 +5,12 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk - from graphon.model_runtime.errors.invoke import InvokeRateLimitError from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException + from graphon.model_runtime.errors.invoke import InvokeRateLimitError + try: from langfuse._utils import parse_error diff --git a/api/extensions/ext_session_factory.py b/api/extensions/ext_session_factory.py index 0eb43d66f4..e19ccd11e5 100644 --- a/api/extensions/ext_session_factory.py +++ b/api/extensions/ext_session_factory.py @@ -1,7 +1,9 @@ +from flask import Flask + from core.db.session_factory import configure_session_factory from extensions.ext_database import db -def init_app(app): +def init_app(app: Flask): with app.app_context(): configure_session_factory(db.engine) diff --git a/api/extensions/ext_socketio.py b/api/extensions/ext_socketio.py new file mode 100644 index 0000000000..5ed82bac8d --- /dev/null +++ b/api/extensions/ext_socketio.py @@ -0,0 +1,5 @@ +import socketio # type: ignore[reportMissingTypeStubs] + +from configs import dify_config + +sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS) diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index db599c5d49..64ff0f0674 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from collections.abc import Sequence from datetime import datetime from typing import Any -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value +from graphon.enums import WorkflowNodeExecutionStatus from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 2745141431..7f77a0437a 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,12 +20,12 @@ from collections.abc import Sequence 
from datetime import datetime from typing import Any, cast -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string +from graphon.enums import WorkflowExecutionStatus from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index d0f3e2e244..544109276d 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -3,14 +3,14 @@ import logging import os import time -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 37952d6464..dc7654a25c 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -13,10 +13,6 @@ from collections.abc import Sequence from datetime import datetime from typing import Any -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -26,6 +22,10 @@ from core.repositories.factory import OrderConfig, WorkflowNodeExecutionReposito from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 1dd92caeae..ad83826427 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -37,12 +37,7 @@ def trace_span[**P, R](handler_class: type[SpanHandler] | None = None) -> Callab handler = 
_get_handler_instance(handler_class or SpanHandler) tracer = get_tracer(__name__) - return handler.wrapper( - tracer=tracer, - wrapped=func, - args=args, - kwargs=kwargs, - ) + return handler.wrapper(tracer, func, *args, **kwargs) return cast(Callable[P, R], wrapper) diff --git a/api/extensions/otel/decorators/handler.py b/api/extensions/otel/decorators/handler.py index e465a615a6..b0d9fa7af6 100644 --- a/api/extensions/otel/decorators/handler.py +++ b/api/extensions/otel/decorators/handler.py @@ -1,8 +1,8 @@ import inspect -from collections.abc import Callable, Mapping +from collections.abc import Callable from typing import Any -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer class SpanHandler: @@ -16,9 +16,9 @@ class SpanHandler: exceptions. Handlers can override the wrapper method to customize behavior. """ - _signature_cache: dict[Callable[..., Any], inspect.Signature] = {} + _signature_cache: dict[Callable[..., object], inspect.Signature] = {} - def _build_span_name(self, wrapped: Callable[..., Any]) -> str: + def _build_span_name[**P, R](self, wrapped: Callable[P, R]) -> str: """ Build the span name from the wrapped function. @@ -29,11 +29,11 @@ class SpanHandler: """ return f"{wrapped.__module__}.{wrapped.__qualname__}" - def _extract_arguments[T]( + def _extract_arguments[**P, R]( self, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, ) -> dict[str, Any] | None: """ Extract function arguments using inspect.signature. @@ -59,13 +59,13 @@ class SpanHandler: except Exception: return None - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: """ Fully control the wrapper behavior. 
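The decorator change above now passes `tracer`, `wrapped`, and the call arguments positionally into `SpanHandler.wrapper`, whose signature is rewritten to be generic over a PEP 695 ParamSpec: `*args: P.args, **kwargs: P.kwargs` bind to the wrapped callable's own parameters instead of opaque `tuple`/`Mapping` containers, so type checkers can validate call sites end to end. A minimal, self-contained sketch of the same typing pattern, assuming Python 3.12+ and the `opentelemetry-api` package; `SketchSpanHandler`, `fetch`, and `demo_tracer` are illustrative names, not part of this diff:

    from collections.abc import Callable

    from opentelemetry import trace
    from opentelemetry.trace import Tracer


    class SketchSpanHandler:
        def wrapper[**P, R](
            self,
            tracer: Tracer,
            wrapped: Callable[P, R],
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> R:
            # *args/**kwargs bind to the wrapped callable's ParamSpec, so a bad
            # call such as wrapper(t, fetch, retries="3") fails type checking
            # while the runtime body stays a plain pass-through.
            with tracer.start_as_current_span(f"{wrapped.__module__}.{wrapped.__qualname__}"):
                return wrapped(*args, **kwargs)


    def fetch(url: str, retries: int = 3) -> str:
        return f"GET {url} (retries={retries})"


    demo_tracer = trace.get_tracer(__name__)
    print(SketchSpanHandler().wrapper(demo_tracer, fetch, "https://example.com", retries=1))

The payoff is that handlers no longer need to unpack untyped containers before forwarding, and `_extract_arguments` can reuse the same ParamSpec to stay in sync with the wrapped signature.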
diff --git a/api/extensions/otel/decorators/handlers/generate_handler.py b/api/extensions/otel/decorators/handlers/generate_handler.py index cc6c75304f..df5142c310 100644 --- a/api/extensions/otel/decorators/handlers/generate_handler.py +++ b/api/extensions/otel/decorators/handlers/generate_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -15,15 +14,15 @@ logger = logging.getLogger(__name__) class AppGenerateHandler(SpanHandler): """Span handler for ``AppGenerateService.generate``.""" - def wrapper[T]( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., T], - args: tuple[object, ...], - kwargs: Mapping[str, object], - ) -> T: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py index 8abd60197c..6b2112ceb2 100644 --- a/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py +++ b/api/extensions/otel/decorators/handlers/workflow_app_runner_handler.py @@ -1,8 +1,7 @@ import logging -from collections.abc import Callable, Mapping -from typing import Any +from collections.abc import Callable -from opentelemetry.trace import SpanKind, Status, StatusCode +from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer from opentelemetry.util.types import AttributeValue from extensions.otel.decorators.handler import SpanHandler @@ -14,15 +13,15 @@ logger = logging.getLogger(__name__) class WorkflowAppRunnerHandler(SpanHandler): """Span handler for ``WorkflowAppRunner.run``.""" - def wrapper( + def wrapper[**P, R]( self, - tracer: Any, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - ) -> Any: + tracer: Tracer, + wrapped: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> R: try: - arguments = self._extract_arguments(wrapped, args, kwargs) + arguments = self._extract_arguments(wrapped, *args, **kwargs) if not arguments: return wrapped(*args, **kwargs) diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 23d324f9ea..fbf379b3e5 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -10,17 +10,17 @@ Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when import json from typing import Any, Protocol -from graphon.enums import BuiltinNodeTypes -from graphon.file import File -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables 
import Segment def should_include_content() -> bool: diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 335c5cc29e..ec3c78a12d 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,12 +6,12 @@ import logging from collections.abc import Mapping from typing import Any -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 6df5f62c15..56672d1fd4 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,13 +6,13 @@ import logging from collections.abc import Sequence from typing import Any -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.variables import Segment from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment logger = logging.getLogger(__name__) diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index b9fdd9e1ca..75ddbba448 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,14 +2,14 @@ Parser for tool nodes that captures tool-specific metadata. 
""" -from graphon.enums import WorkflowNodeExecutionMetadataKey -from graphon.graph_events import GraphNodeEventBase -from graphon.nodes.base.node import Node -from graphon.nodes.tool.entities import ToolNodeData from opentelemetry.trace import Span from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.nodes.tool.entities import ToolNodeData class ToolNodeOTelParser: diff --git a/api/extensions/redis_names.py b/api/extensions/redis_names.py new file mode 100644 index 0000000000..9e63416daf --- /dev/null +++ b/api/extensions/redis_names.py @@ -0,0 +1,32 @@ +from configs import dify_config + + +def normalize_redis_key_prefix(prefix: str | None) -> str: + """Normalize the configured Redis key prefix for consistent runtime use.""" + if prefix is None: + return "" + return prefix.strip() + + +def get_redis_key_prefix() -> str: + """Read and normalize the current Redis key prefix from config.""" + return normalize_redis_key_prefix(dify_config.REDIS_KEY_PREFIX) + + +def serialize_redis_name(name: str, prefix: str | None = None) -> str: + """Convert a logical Redis name into the physical name used in Redis.""" + normalized_prefix = get_redis_key_prefix() if prefix is None else normalize_redis_key_prefix(prefix) + if not normalized_prefix: + return name + return f"{normalized_prefix}:{name}" + + +def serialize_redis_name_arg(name: str | bytes, prefix: str | None = None) -> str | bytes: + """Prefix string Redis names while preserving bytes inputs unchanged.""" + if isinstance(name, bytes): + return name + return serialize_redis_name(name, prefix) + + +def serialize_redis_name_args(names: tuple[str | bytes, ...], prefix: str | None = None) -> tuple[str | bytes, ...]: + return tuple(serialize_redis_name_arg(name, prefix) for name in names) diff --git a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py index 18eed4e481..05492327c8 100644 --- a/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +++ b/api/extensions/storage/clickzetta_volume/clickzetta_volume_storage.py @@ -10,6 +10,7 @@ import tempfile from collections.abc import Generator from io import BytesIO from pathlib import Path +from typing import Any import clickzetta from pydantic import BaseModel, model_validator @@ -39,7 +40,7 @@ class ClickZettaVolumeConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): """Validate the configuration values. 
This method will first try to use CLICKZETTA_VOLUME_* environment variables, diff --git a/api/extensions/storage/clickzetta_volume/file_lifecycle.py b/api/extensions/storage/clickzetta_volume/file_lifecycle.py index 86b1bba544..1cb940b797 100644 --- a/api/extensions/storage/clickzetta_volume/file_lifecycle.py +++ b/api/extensions/storage/clickzetta_volume/file_lifecycle.py @@ -65,7 +65,7 @@ class FileMetadata: return data @classmethod - def from_dict(cls, data: dict) -> FileMetadata: + def from_dict(cls, data: dict[str, Any]) -> FileMetadata: """Create instance from dictionary""" data = data.copy() data["created_at"] = datetime.fromisoformat(data["created_at"]) @@ -459,7 +459,7 @@ class FileLifecycleManager: newest_file=None, ) - def _create_version_backup(self, filename: str, metadata: dict): + def _create_version_backup(self, filename: str, metadata: dict[str, Any]): """Create version backup""" try: # Read current file content @@ -487,7 +487,7 @@ class FileLifecycleManager: logger.warning("Failed to load metadata: %s", e) return {} - def _save_metadata(self, metadata_dict: dict): + def _save_metadata(self, metadata_dict: dict[str, Any]): """Save metadata file""" try: metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False) diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 96f5915ff0..cd7f7db295 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -2,6 +2,7 @@ import logging import os from collections.abc import Generator from pathlib import Path +from typing import Any import opendal from dotenv import dotenv_values @@ -19,7 +20,7 @@ def _get_opendal_kwargs(*, scheme: str, env_file_path: str = ".env", prefix: str if key.startswith(config_prefix): kwargs[key[len(config_prefix) :].lower()] = value - file_env_vars: dict = dotenv_values(env_file_path) or {} + file_env_vars: dict[str, Any] = dotenv_values(env_file_path) or {} for key, value in file_env_vars.items(): if key.startswith(config_prefix) and key[len(config_prefix) :].lower() not in kwargs and value: kwargs[key[len(config_prefix) :].lower()] = value diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py index 7516d18c8e..4fb976f0e7 100644 --- a/api/factories/file_factory/builders.py +++ b/api/factories/file_factory/builders.py @@ -7,12 +7,12 @@ import uuid from collections.abc import Mapping, Sequence from typing import Any -from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from sqlalchemy import select from core.app.file_access import FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.workflow.file_reference import build_file_reference -from extensions.ext_database import db +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type from models import ToolFile, UploadFile from .common import resolve_mapping_file_id @@ -135,29 +135,30 @@ def _build_from_local_file( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if row is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + row = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type(extension="." 
+ row.extension, mime_type=row.mime_type) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type", "custom"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - reference=build_file_reference(record_id=str(row.id)), - size=row.size, - storage_key=row.key, - ) + return File( + file_id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + file_type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) def _build_from_remote_url( @@ -179,32 +180,33 @@ def _build_from_remote_url( UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) - upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) - if upload_file is None: - raise ValueError("Invalid upload file") + with session_factory.create_session() as session: + upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") - detected_file_type = standardize_file_type( - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - ) - file_type = _resolve_file_type( - detected_file_type=detected_file_type, - specified_type=mapping.get("type"), - strict_type_validation=strict_type_validation, - ) + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - reference=build_file_reference(record_id=str(upload_file.id)), - size=upload_file.size, - storage_key=upload_file.key, - ) + return File( + file_id=mapping.get("id"), + filename=upload_file.name, + extension="." 
diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py
index 5582b85c95..4b3d514238 100644
--- a/api/factories/file_factory/message_files.py
+++ b/api/factories/file_factory/message_files.py
@@ -4,9 +4,8 @@ from __future__ import annotations
 
 from collections.abc import Sequence
 
-from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig
-
 from core.app.file_access import FileAccessControllerProtocol
+from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig
 from models import MessageFile
 
 from .builders import build_from_mapping
diff --git a/api/factories/file_factory/remote.py b/api/factories/file_factory/remote.py
index e5a7186007..9b8f94b1f3 100644
--- a/api/factories/file_factory/remote.py
+++ b/api/factories/file_factory/remote.py
@@ -19,8 +19,13 @@ from werkzeug.http import parse_options_header
 from core.helper import ssrf_proxy
 
 
-def extract_filename(url_path: str, content_disposition: str | None) -> str | None:
-    """Extract a safe filename from Content-Disposition or the request URL path."""
+def extract_filename(url_or_path: str, content_disposition: str | None) -> str | None:
+    """Extract a safe filename from Content-Disposition or the request URL path.
+
+    Handles full URLs, paths with query strings, hash fragments, and percent-encoded segments.
+    Query strings and hash fragments are stripped from the URL before extracting the basename.
+    Percent-encoded characters in the path are decoded safely.
+    """
     filename: str | None = None
     if content_disposition:
         filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition)
@@ -47,8 +52,13 @@ def extract_filename(url_path: str, content_disposition: str | None) -> str | No
         filename = urllib.parse.unquote(raw)
 
     if not filename:
-        candidate = os.path.basename(url_path)
-        filename = urllib.parse.unquote(candidate) if candidate else None
+        # Parse the URL to extract just the path, stripping query strings and fragments
+        # This handles both full URLs and bare paths
+        parsed = urllib.parse.urlparse(url_or_path)
+        path = parsed.path
+        candidate = os.path.basename(path)
+        # Decode percent-encoded characters, with safe fallback for malformed input
+        filename = urllib.parse.unquote(candidate, errors="replace") if candidate else None
 
     if filename:
         filename = os.path.basename(filename)
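The new URL-path fallback in `extract_filename` relies entirely on stdlib behavior; a short illustration of what the added `urlparse`/`unquote` steps do:

```python
import os
import urllib.parse

url = "https://example.com/files/report%20v2.pdf?token=abc#page=3"

# urlparse splits the query string and fragment into separate attributes,
# so os.path.basename only ever sees the clean path component.
parsed = urllib.parse.urlparse(url)
candidate = os.path.basename(parsed.path)  # "report%20v2.pdf"

# errors="replace" keeps malformed percent-escapes from raising.
filename = urllib.parse.unquote(candidate, errors="replace")
print(filename)  # "report v2.pdf"
```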
+ """ filename: str | None = None if content_disposition: filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) @@ -47,8 +52,13 @@ def extract_filename(url_path: str, content_disposition: str | None) -> str | No filename = urllib.parse.unquote(raw) if not filename: - candidate = os.path.basename(url_path) - filename = urllib.parse.unquote(candidate) if candidate else None + # Parse the URL to extract just the path, stripping query strings and fragments + # This handles both full URLs and bare paths + parsed = urllib.parse.urlparse(url_or_path) + path = parsed.path + candidate = os.path.basename(path) + # Decode percent-encoded characters, with safe fallback for malformed input + filename = urllib.parse.unquote(candidate, errors="replace") if candidate else None if filename: filename = os.path.basename(filename) diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py index db3a7f3015..dba4c84407 100644 --- a/api/factories/file_factory/storage_keys.py +++ b/api/factories/file_factory/storage_keys.py @@ -5,12 +5,12 @@ from __future__ import annotations import uuid from collections.abc import Mapping, Sequence -from graphon.file import File, FileTransferMethod from sqlalchemy import select from sqlalchemy.orm import Session from core.app.file_access import FileAccessControllerProtocol from core.workflow.file_reference import build_file_reference, parse_file_reference +from graphon.file import File, FileTransferMethod from models import ToolFile, UploadFile diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 57205b5739..fd7acb14d3 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -8,6 +8,11 @@ shared conversion functions for legacy callers and tests. 
from collections.abc import Mapping, Sequence from typing import Any, cast +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) from graphon.variables.exc import VariableError from graphon.variables.factory import ( TypeMismatchError, @@ -31,12 +36,6 @@ from graphon.variables.variables import ( VariableBase, ) -from configs import dify_config -from core.workflow.variable_prefixes import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, -) - __all__ = [ "TypeMismatchError", "UnsupportedSegmentTypeError", diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index b5acbbbcb4..d518114777 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -10,9 +10,9 @@ class _VarTypedDict(TypedDict, total=False): def serialize_value_type(v: _VarTypedDict | Segment) -> str: if isinstance(v, Segment): - return v.value_type.exposed_type().value + return str(v.value_type.exposed_type()) else: value_type = v.get("value_type") if value_type is None: raise ValueError("value_type is required but not provided") - return value_type.exposed_type().value + return str(value_type.exposed_type()) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 1afcbdb5b9..bf5c9ffcb1 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,10 +3,10 @@ from __future__ import annotations from datetime import datetime from typing import Any -from graphon.file import File from pydantic import Field, field_validator, model_validator from fields.base import ResponseModel +from graphon.file import File type JSONValue = Any @@ -96,7 +96,7 @@ class ConversationAnnotation(ResponseModel): class ConversationAnnotationHitHistory(ResponseModel): - annotation_id: str + annotation_id: str = Field(validation_alias="id") annotation_create_account: SimpleAccount | None = None created_at: int | None = None @@ -143,7 +143,7 @@ class MessageDetail(ResponseModel): query: str message: JSONValue message_tokens: int - answer: str + answer: str = Field(validation_alias="re_sign_file_url_answer") answer_tokens: int provider_response_latency: float from_source: str @@ -156,7 +156,7 @@ class MessageDetail(ResponseModel): created_at: int | None = None agent_thoughts: list[AgentThought] message_files: list[MessageFile] - metadata: JSONValue + metadata: JSONValue = Field(validation_alias="message_metadata_dict") status: str error: str | None = None parent_message_id: str | None = None @@ -196,7 +196,7 @@ class ModelConfig(ResponseModel): class SimpleModelConfig(ResponseModel): - model: JSONValue | None = None + model: JSONValue | None = Field(default=None, validation_alias="model_dict") pre_prompt: str | None = None @@ -211,6 +211,11 @@ class SimpleMessageDetail(ResponseModel): def _normalize_inputs(cls, value: JSONValue) -> JSONValue: return format_files_contained(value) + @field_validator("message", mode="before") + @classmethod + def _normalize_message(cls, value: JSONValue) -> str: + return message_text(value) + class Conversation(ResponseModel): id: str @@ -227,15 +232,22 @@ class Conversation(ResponseModel): model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config") user_feedback_stats: FeedbackStat | None = None admin_feedback_stats: FeedbackStat | None = None - message: SimpleMessageDetail | None = None + message: SimpleMessageDetail | None = Field(default=None, 
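The `.value` → `str(...)` rewrite here (and in `workflow_fields.py` below) is safe when the exposed type is string-backed. A minimal sketch, assuming the exposed type behaves like a `StrEnum` (the real `SegmentType.exposed_type()` return type is not shown in this diff):

```python
from enum import StrEnum


class ExposedType(StrEnum):  # hypothetical stand-in for the exposed type
    STRING = "string"


# For a StrEnum, str(member) and member.value are the same string,
# so str(...) also covers plain-string inputs without an AttributeError.
assert str(ExposedType.STRING) == "string" == ExposedType.STRING.value
```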
validation_alias="first_message") + + @field_validator("read_at", "created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value class ConversationPagination(ResponseModel): page: int - limit: int + limit: int = Field(validation_alias="per_page") total: int - has_more: bool - data: list[Conversation] + has_more: bool = Field(validation_alias="has_next") + data: list[Conversation] = Field(validation_alias="items") class ConversationMessageDetail(ResponseModel): @@ -246,7 +258,14 @@ class ConversationMessageDetail(ResponseModel): from_account_id: str | None = None created_at: int | None = None model_config_: ModelConfig | None = Field(default=None, alias="model_config") - message: MessageDetail | None = None + message: MessageDetail | None = Field(default=None, validation_alias="first_message") + + @field_validator("created_at", mode="before") + @classmethod + def _normalize_created_at(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value class ConversationWithSummary(ResponseModel): @@ -258,7 +277,7 @@ class ConversationWithSummary(ResponseModel): from_account_id: str | None = None from_account_name: str | None = None name: str - summary: str + summary: str = Field(validation_alias="summary_or_query") read_at: int | None = None created_at: int | None = None updated_at: int | None = None @@ -269,13 +288,20 @@ class ConversationWithSummary(ResponseModel): admin_feedback_stats: FeedbackStat | None = None status_count: StatusCount | None = None + @field_validator("read_at", "created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + class ConversationWithSummaryPagination(ResponseModel): page: int - limit: int + limit: int = Field(validation_alias="per_page") total: int - has_more: bool - data: list[ConversationWithSummary] + has_more: bool = Field(validation_alias="has_next") + data: list[ConversationWithSummary] = Field(validation_alias="items") class ConversationDetail(ResponseModel): @@ -293,6 +319,13 @@ class ConversationDetail(ResponseModel): user_feedback_stats: FeedbackStat | None = None admin_feedback_stats: FeedbackStat | None = None + @field_validator("created_at", "updated_at", mode="before") + @classmethod + def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: + if isinstance(value, datetime): + return to_timestamp(value) + return value + def to_timestamp(value: datetime | None) -> int | None: if value is None: diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index c55014a368..e4219ba1ee 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -1,5 +1,13 @@ -from flask_restx import Namespace, fields +from __future__ import annotations +from datetime import datetime +from typing import Any + +from flask_restx import Namespace, fields +from pydantic import field_validator + +from fields.base import ResponseModel +from graphon.variables.types import SegmentType from libs.helper import TimestampField from ._value_type_serializer import serialize_value_type @@ -29,6 +37,74 @@ conversation_variable_infinite_scroll_pagination_fields = { } +def _to_timestamp(value: datetime | int | None) -> int 
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index c55014a368..e4219ba1ee 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -1,5 +1,13 @@
-from flask_restx import Namespace, fields
+from __future__ import annotations
 
+from datetime import datetime
+from typing import Any
+
+from flask_restx import Namespace, fields
+from pydantic import field_validator
+
+from fields.base import ResponseModel
+from graphon.variables.types import SegmentType
 from libs.helper import TimestampField
 
 from ._value_type_serializer import serialize_value_type
@@ -29,6 +37,74 @@ conversation_variable_infinite_scroll_pagination_fields = {
 }
 
 
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
+
+
+class ConversationVariableResponse(ResponseModel):
+    id: str
+    name: str
+    value_type: str
+    value: str | None = None
+    description: str | None = None
+    created_at: int | None = None
+    updated_at: int | None = None
+
+    @field_validator("value_type", mode="before")
+    @classmethod
+    def _normalize_value_type(cls, value: Any) -> str:
+        exposed_type = getattr(value, "exposed_type", None)
+        if callable(exposed_type):
+            return str(exposed_type())
+        if isinstance(value, str):
+            try:
+                return str(SegmentType(value).exposed_type())
+            except ValueError:
+                return value
+        try:
+            return serialize_value_type(value)
+        except (AttributeError, TypeError, ValueError):
+            pass
+
+        try:
+            return serialize_value_type({"value_type": value})
+        except (AttributeError, TypeError, ValueError):
+            value_attr = getattr(value, "value", None)
+            if value_attr is not None:
+                return str(value_attr)
+            return str(value)
+
+    @field_validator("value", mode="before")
+    @classmethod
+    def _normalize_value(cls, value: Any | None) -> str | None:
+        if value is None:
+            return None
+        if isinstance(value, str):
+            return value
+        return str(value)
+
+    @field_validator("created_at", "updated_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class PaginatedConversationVariableResponse(ResponseModel):
+    page: int
+    limit: int
+    total: int
+    has_more: bool
+    data: list[ConversationVariableResponse]
+
+
+class ConversationVariableInfiniteScrollPaginationResponse(ResponseModel):
+    limit: int
+    has_more: bool
+    data: list[ConversationVariableResponse]
+
+
 def build_conversation_variable_model(api_or_ns: Namespace):
     """Build the conversation variable model for the API or Namespace."""
     return api_or_ns.model("ConversationVariable", conversation_variable_fields)
diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py
index cfe0015918..67b320beaa 100644
--- a/api/fields/member_fields.py
+++ b/api/fields/member_fields.py
@@ -3,10 +3,10 @@ from __future__ import annotations
 from datetime import datetime
 
 from flask_restx import fields
-from graphon.file import helpers as file_helpers
 from pydantic import computed_field, field_validator
 
 from fields.base import ResponseModel
+from graphon.file import helpers as file_helpers
 
 simple_account_fields = {
     "id": fields.String,
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index 1a871204a0..ca18f1c203 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -3,12 +3,12 @@ from __future__ import annotations
 from datetime import datetime
 from uuid import uuid4
 
-from graphon.file import File
 from pydantic import Field, field_validator
 
 from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
 from fields.base import ResponseModel
 from fields.conversation_fields import AgentThought, JSONValue, MessageFile
+from graphon.file import File
 
 type JSONValueType = JSONValue
diff --git a/api/fields/online_user_fields.py b/api/fields/online_user_fields.py
new file mode 100644
index 0000000000..bdbe19679c
--- /dev/null
+++ b/api/fields/online_user_fields.py
@@ -0,0 +1,16 @@
+from flask_restx import fields
+
+online_user_partial_fields = {
+    "user_id": fields.String,
+    "username": fields.String,
+    "avatar": fields.String,
+}
+
+workflow_online_users_fields = {
+    "app_id": fields.String,
+    "users": fields.List(fields.Nested(online_user_partial_fields)),
+}
+
+online_user_list_fields = {
+    "data": fields.List(fields.Nested(workflow_online_users_fields)),
+}
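These field dicts are flask-restx marshalling templates rather than Pydantic models; `marshal` projects any mapping through them, dropping unknown keys. A quick sketch of how the new `online_user_partial_fields` would be consumed:

```python
from flask_restx import fields, marshal

online_user_partial_fields = {
    "user_id": fields.String,
    "username": fields.String,
    "avatar": fields.String,
}

payload = {"user_id": "u1", "username": "alice", "avatar": None, "extra": 1}
print(marshal(payload, online_user_partial_fields))
# {'user_id': 'u1', 'username': 'alice', 'avatar': None} — "extra" is dropped
```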
diff --git a/api/fields/raws.py b/api/fields/raws.py
index 4c65cdab7a..ee6f53b360 100644
--- a/api/fields/raws.py
+++ b/api/fields/raws.py
@@ -1,4 +1,5 @@
 from flask_restx import fields
+
 from graphon.file import File
diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py
index d0e762f62b..1b2c71255d 100644
--- a/api/fields/workflow_app_log_fields.py
+++ b/api/fields/workflow_app_log_fields.py
@@ -1,8 +1,17 @@
-from flask_restx import Namespace, fields
+from __future__ import annotations
 
-from fields.end_user_fields import simple_end_user_fields
-from fields.member_fields import simple_account_fields
+from datetime import datetime
+from typing import Any
+
+from flask_restx import Namespace, fields
+from pydantic import field_validator
+
+from fields.base import ResponseModel
+from fields.end_user_fields import SimpleEndUser, simple_end_user_fields
+from fields.member_fields import SimpleAccount, simple_account_fields
 from fields.workflow_run_fields import (
+    WorkflowRunForArchivedLogResponse,
+    WorkflowRunForLogResponse,
     build_workflow_run_for_archived_log_model,
     build_workflow_run_for_log_model,
     workflow_run_for_archived_log_fields,
@@ -85,3 +94,55 @@ def build_workflow_archived_log_pagination_model(api_or_ns: Namespace):
     copied_fields = workflow_archived_log_pagination_fields.copy()
     copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model))
     return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields)
+
+
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
+
+
+class WorkflowAppLogPartialResponse(ResponseModel):
+    id: str
+    workflow_run: WorkflowRunForLogResponse | None = None
+    details: Any = None
+    created_from: str | None = None
+    created_by_role: str | None = None
+    created_by_account: SimpleAccount | None = None
+    created_by_end_user: SimpleEndUser | None = None
+    created_at: int | None = None
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class WorkflowArchivedLogPartialResponse(ResponseModel):
+    id: str
+    workflow_run: WorkflowRunForArchivedLogResponse | None = None
+    trigger_metadata: Any = None
+    created_by_account: SimpleAccount | None = None
+    created_by_end_user: SimpleEndUser | None = None
+    created_at: int | None = None
+
+    @field_validator("created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class WorkflowAppLogPaginationResponse(ResponseModel):
+    page: int
+    limit: int
+    total: int
+    has_more: bool
+    data: list[WorkflowAppLogPartialResponse]
+
+
+class WorkflowArchivedLogPaginationResponse(ResponseModel):
+    page: int
+    limit: int
+    total: int
+    has_more: bool
+    data: list[WorkflowArchivedLogPartialResponse]
diff --git a/api/fields/workflow_comment_fields.py b/api/fields/workflow_comment_fields.py
new file mode 100644
index 0000000000..c708dd3460
--- /dev/null
+++ b/api/fields/workflow_comment_fields.py
@@ -0,0 +1,96 @@
+from flask_restx import fields
+
+from libs.helper import AvatarUrlField, TimestampField
+
+# basic account fields for comments
+account_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "email": fields.String,
+    "avatar_url": AvatarUrlField,
+}
+
+# Comment mention fields
+workflow_comment_mention_fields = {
+    "mentioned_user_id": fields.String,
+    "mentioned_user_account": fields.Nested(account_fields, allow_null=True),
+    "reply_id": fields.String,
+}
+
+# Comment reply fields
+workflow_comment_reply_fields = {
+    "id": fields.String,
+    "content": fields.String,
+    "created_by": fields.String,
+    "created_by_account": fields.Nested(account_fields, allow_null=True),
+    "created_at": TimestampField,
+}
+
+# Basic comment fields (for list views)
+workflow_comment_basic_fields = {
+    "id": fields.String,
+    "position_x": fields.Float,
+    "position_y": fields.Float,
+    "content": fields.String,
+    "created_by": fields.String,
+    "created_by_account": fields.Nested(account_fields, allow_null=True),
+    "created_at": TimestampField,
+    "updated_at": TimestampField,
+    "resolved": fields.Boolean,
+    "resolved_at": TimestampField,
+    "resolved_by": fields.String,
+    "resolved_by_account": fields.Nested(account_fields, allow_null=True),
+    "reply_count": fields.Integer,
+    "mention_count": fields.Integer,
+    "participants": fields.List(fields.Nested(account_fields)),
+}
+
+# Detailed comment fields (for single comment view)
+workflow_comment_detail_fields = {
+    "id": fields.String,
+    "position_x": fields.Float,
+    "position_y": fields.Float,
+    "content": fields.String,
+    "created_by": fields.String,
+    "created_by_account": fields.Nested(account_fields, allow_null=True),
+    "created_at": TimestampField,
+    "updated_at": TimestampField,
+    "resolved": fields.Boolean,
+    "resolved_at": TimestampField,
+    "resolved_by": fields.String,
+    "resolved_by_account": fields.Nested(account_fields, allow_null=True),
+    "replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
+    "mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
+}
+
+# Comment creation response fields (simplified)
+workflow_comment_create_fields = {
+    "id": fields.String,
+    "created_at": TimestampField,
+}
+
+# Comment update response fields (simplified)
+workflow_comment_update_fields = {
+    "id": fields.String,
+    "updated_at": TimestampField,
+}
+
+# Comment resolve response fields
+workflow_comment_resolve_fields = {
+    "id": fields.String,
+    "resolved": fields.Boolean,
+    "resolved_at": TimestampField,
+    "resolved_by": fields.String,
+}
+
+# Reply creation response fields (simplified)
+workflow_comment_reply_create_fields = {
+    "id": fields.String,
+    "created_at": TimestampField,
+}
+
+# Reply update response fields
+workflow_comment_reply_update_fields = {
+    "id": fields.String,
+    "updated_at": TimestampField,
+}
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index b0b6cc0b48..6e947858ba 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,8 +1,8 @@
 from flask_restx import fields
-from graphon.variables import SecretVariable, SegmentType, VariableBase
 
 from core.helper import encrypter
 from fields.member_fields import simple_account_fields
+from graphon.variables import SecretVariable, SegmentType, VariableBase
 from libs.helper import TimestampField
 
 from ._value_type_serializer import serialize_value_type
@@ -26,7 +26,7 @@ class EnvironmentVariableField(fields.Raw):
                 "id": value.id,
                 "name": value.name,
                 "value": value.value,
-                "value_type": value.value_type.exposed_type().value,
+                "value_type": str(value.value_type.exposed_type()),
                 "description": value.description,
             }
         if isinstance(value, dict):
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 35bb442c59..8c659086ed 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -1,7 +1,14 @@
-from flask_restx import Namespace, fields
+from __future__ import annotations
 
-from fields.end_user_fields import simple_end_user_fields
-from fields.member_fields import simple_account_fields
+from datetime import datetime
+from typing import Any
+
+from flask_restx import Namespace, fields
+from pydantic import Field, field_validator
+
+from fields.base import ResponseModel
+from fields.end_user_fields import SimpleEndUser, simple_end_user_fields
+from fields.member_fields import SimpleAccount, simple_account_fields
 from libs.helper import TimestampField
 
 workflow_run_for_log_fields = {
@@ -147,3 +154,174 @@ workflow_run_node_execution_fields = {
 workflow_run_node_execution_list_fields = {
     "data": fields.List(fields.Nested(workflow_run_node_execution_fields)),
 }
+
+
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
+
+
+class WorkflowRunForLogResponse(ResponseModel):
+    id: str
+    version: str | None = None
+    status: str | None = None
+    triggered_from: str | None = None
+    error: str | None = None
+    elapsed_time: float | None = None
+    total_tokens: int | None = None
+    total_steps: int | None = None
+    created_at: int | None = None
+    finished_at: int | None = None
+    exceptions_count: int | None = None
+
+    @field_validator("status", mode="before")
+    @classmethod
+    def _normalize_status(cls, value: Any) -> str | None:
+        if value is None or isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @field_validator("created_at", "finished_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class WorkflowRunForArchivedLogResponse(ResponseModel):
+    id: str
+    status: str | None = None
+    triggered_from: str | None = None
+    elapsed_time: float | None = None
+    total_tokens: int | None = None
+
+    @field_validator("status", mode="before")
+    @classmethod
+    def _normalize_status(cls, value: Any) -> str | None:
+        if value is None or isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+
+class WorkflowRunForListResponse(ResponseModel):
+    id: str
+    version: str | None = None
+    status: str | None = None
+    elapsed_time: float | None = None
+    total_tokens: int | None = None
+    total_steps: int | None = None
+    created_by_account: SimpleAccount | None = None
+    created_at: int | None = None
+    finished_at: int | None = None
+    exceptions_count: int | None = None
+    retry_index: int | None = None
+
+    @field_validator("status", mode="before")
+    @classmethod
+    def _normalize_status(cls, value: Any) -> str | None:
+        if value is None or isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @field_validator("created_at", "finished_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class AdvancedChatWorkflowRunForListResponse(WorkflowRunForListResponse):
+    conversation_id: str | None = None
+    message_id: str | None = None
+
+
+class AdvancedChatWorkflowRunPaginationResponse(ResponseModel):
+    limit: int
+    has_more: bool
+    data: list[AdvancedChatWorkflowRunForListResponse]
+
+
+class WorkflowRunPaginationResponse(ResponseModel):
+    limit: int
+    has_more: bool
+    data: list[WorkflowRunForListResponse]
+
+
+class WorkflowRunCountResponse(ResponseModel):
+    total: int
+    running: int
+    succeeded: int
+    failed: int
+    stopped: int
+    partial_succeeded: int = Field(validation_alias="partial-succeeded")
+
+
+class WorkflowRunDetailResponse(ResponseModel):
+    id: str
+    version: str | None = None
+    graph: Any = Field(validation_alias="graph_dict")
+    inputs: Any = Field(validation_alias="inputs_dict")
+    status: str | None = None
+    outputs: Any = Field(validation_alias="outputs_dict")
+    error: str | None = None
+    elapsed_time: float | None = None
+    total_tokens: int | None = None
+    total_steps: int | None = None
+    created_by_role: str | None = None
+    created_by_account: SimpleAccount | None = None
+    created_by_end_user: SimpleEndUser | None = None
+    created_at: int | None = None
+    finished_at: int | None = None
+    exceptions_count: int | None = None
+
+    @field_validator("status", mode="before")
+    @classmethod
+    def _normalize_status(cls, value: Any) -> str | None:
+        if value is None or isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @field_validator("created_at", "finished_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class WorkflowRunNodeExecutionResponse(ResponseModel):
+    id: str
+    index: int | None = None
+    predecessor_node_id: str | None = None
+    node_id: str | None = None
+    node_type: str | None = None
+    title: str | None = None
+    inputs: Any = Field(default=None, validation_alias="inputs_dict")
+    process_data: Any = Field(default=None, validation_alias="process_data_dict")
+    outputs: Any = Field(default=None, validation_alias="outputs_dict")
+    status: str | None = None
+    error: str | None = None
+    elapsed_time: float | None = None
+    execution_metadata: Any = Field(default=None, validation_alias="execution_metadata_dict")
+    extras: Any = None
+    created_at: int | None = None
+    created_by_role: str | None = None
+    created_by_account: SimpleAccount | None = None
+    created_by_end_user: SimpleEndUser | None = None
+    finished_at: int | None = None
+    inputs_truncated: bool | None = None
+    outputs_truncated: bool | None = None
+    process_data_truncated: bool | None = None
+
+    @field_validator("status", mode="before")
+    @classmethod
+    def _normalize_status(cls, value: Any) -> str | None:
+        if value is None or isinstance(value, str):
+            return value
+        return str(getattr(value, "value", value))
+
+    @field_validator("created_at", "finished_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class WorkflowRunNodeExecutionListResponse(ResponseModel):
+    data: list[WorkflowRunNodeExecutionResponse]
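The repeated `_normalize_status` validator accepts plain strings, `None`, and enum members alike. A self-contained illustration, using a hypothetical enum as a stand-in for the real workflow status type:

```python
from enum import Enum
from typing import Any


class _Status(Enum):  # hypothetical stand-in for the ORM status enum
    SUCCEEDED = "succeeded"


def normalize_status(value: Any) -> str | None:
    if value is None or isinstance(value, str):
        return value
    # getattr falls back to the object itself when there is no .value,
    # so arbitrary inputs still stringify instead of raising.
    return str(getattr(value, "value", value))


assert normalize_status(_Status.SUCCEEDED) == "succeeded"
assert normalize_status("failed") == "failed"
assert normalize_status(None) is None
```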
diff --git a/api/libs/broadcast_channel/redis/_subscription.py b/api/libs/broadcast_channel/redis/_subscription.py
index 40027bc424..4db79a15a9 100644
--- a/api/libs/broadcast_channel/redis/_subscription.py
+++ b/api/libs/broadcast_channel/redis/_subscription.py
@@ -3,7 +3,7 @@ import queue
 import threading
 import types
 from collections.abc import Generator, Iterator
-from typing import Self
+from typing import Any, Self
 
 from libs.broadcast_channel.channel import Subscription
 from libs.broadcast_channel.exc import SubscriptionClosedError
@@ -221,7 +221,7 @@ class RedisSubscriptionBase(Subscription):
         """Unsubscribe from the Redis topic using the appropriate command."""
         raise NotImplementedError
 
-    def _get_message(self) -> dict | None:
+    def _get_message(self) -> dict[str, Any] | None:
         """Get a message from Redis using the appropriate method."""
         raise NotImplementedError
diff --git a/api/libs/broadcast_channel/redis/channel.py b/api/libs/broadcast_channel/redis/channel.py
index bd6d58c53f..b76a23eb3c 100644
--- a/api/libs/broadcast_channel/redis/channel.py
+++ b/api/libs/broadcast_channel/redis/channel.py
@@ -1,5 +1,8 @@
 from __future__ import annotations
 
+from typing import Any
+
+from extensions.redis_names import serialize_redis_name
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from redis import Redis, RedisCluster
 
@@ -30,12 +33,13 @@ class Topic:
     def __init__(self, redis_client: Redis | RedisCluster, topic: str):
         self._client = redis_client
         self._topic = topic
+        self._redis_topic = serialize_redis_name(topic)
 
     def as_producer(self) -> Producer:
         return self
 
     def publish(self, payload: bytes) -> None:
-        self._client.publish(self._topic, payload)
+        self._client.publish(self._redis_topic, payload)
 
     def as_subscriber(self) -> Subscriber:
         return self
@@ -44,7 +48,7 @@ class Topic:
         return _RedisSubscription(
             client=self._client,
             pubsub=self._client.pubsub(),
-            topic=self._topic,
+            topic=self._redis_topic,
         )
 
 
@@ -62,7 +66,7 @@ class _RedisSubscription(RedisSubscriptionBase):
         assert self._pubsub is not None
         self._pubsub.unsubscribe(self._topic)
 
-    def _get_message(self) -> dict | None:
+    def _get_message(self) -> dict[str, Any] | None:
         assert self._pubsub is not None
         return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
diff --git a/api/libs/broadcast_channel/redis/sharded_channel.py b/api/libs/broadcast_channel/redis/sharded_channel.py
index 20c43b8bbb..919d8d622e 100644
--- a/api/libs/broadcast_channel/redis/sharded_channel.py
+++ b/api/libs/broadcast_channel/redis/sharded_channel.py
@@ -1,5 +1,8 @@
 from __future__ import annotations
 
+from typing import Any
+
+from extensions.redis_names import serialize_redis_name
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from redis import Redis, RedisCluster
 
@@ -28,12 +31,13 @@ class ShardedTopic:
     def __init__(self, redis_client: Redis | RedisCluster, topic: str):
         self._client = redis_client
         self._topic = topic
+        self._redis_topic = serialize_redis_name(topic)
 
     def as_producer(self) -> Producer:
         return self
 
     def publish(self, payload: bytes) -> None:
-        self._client.spublish(self._topic, payload)  # type: ignore[attr-defined,union-attr]
+        self._client.spublish(self._redis_topic, payload)  # type: ignore[attr-defined,union-attr]
 
     def as_subscriber(self) -> Subscriber:
         return self
@@ -42,7 +46,7 @@ class ShardedTopic:
         return _RedisShardedSubscription(
             client=self._client,
             pubsub=self._client.pubsub(),
-            topic=self._topic,
+            topic=self._redis_topic,
         )
 
 
@@ -60,7 +64,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
         assert self._pubsub is not None
         self._pubsub.sunsubscribe(self._topic)  # type: ignore[attr-defined]
 
-    def _get_message(self) -> dict | None:
+    def _get_message(self) -> dict[str, Any] | None:
         assert self._pubsub is not None
         # NOTE(QuantumGhost): this is an issue in
         # upstream code. If Sharded PubSub is used with Cluster, the
diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py
index 983f785027..55ff6cd4f9 100644
--- a/api/libs/broadcast_channel/redis/streams_channel.py
+++ b/api/libs/broadcast_channel/redis/streams_channel.py
@@ -6,6 +6,7 @@ import threading
 from collections.abc import Iterator
 from typing import Self
 
+from extensions.redis_names import serialize_redis_name
 from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
 from libs.broadcast_channel.exc import SubscriptionClosedError
 from redis import Redis, RedisCluster
@@ -35,7 +36,7 @@ class StreamsTopic:
     def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600):
         self._client = redis_client
         self._topic = topic
-        self._key = f"stream:{topic}"
+        self._key = serialize_redis_name(f"stream:{topic}")
         self._retention_seconds = retention_seconds
         self.max_length = 5000
diff --git a/api/libs/db_migration_lock.py b/api/libs/db_migration_lock.py
index 1d3a81e0a2..b5fe38342a 100644
--- a/api/libs/db_migration_lock.py
+++ b/api/libs/db_migration_lock.py
@@ -14,9 +14,15 @@ from __future__ import annotations
 
 import logging
 import threading
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
+import redis
+from redis.cluster import RedisCluster
 from redis.exceptions import LockNotOwnedError, RedisError
+from redis.lock import Lock
+
+if TYPE_CHECKING:
+    from extensions.ext_redis import RedisClientWrapper
 
 logger = logging.getLogger(__name__)
 
@@ -38,21 +44,21 @@ class DbMigrationAutoRenewLock:
     primary error/exit code.
     """
 
-    _redis_client: Any
+    _redis_client: redis.Redis | RedisCluster | RedisClientWrapper
     _name: str
     _ttl_seconds: float
     _renew_interval_seconds: float
     _log_context: str | None
     _logger: logging.Logger
-    _lock: Any
+    _lock: Lock | None
     _stop_event: threading.Event | None
     _thread: threading.Thread | None
     _acquired: bool
 
     def __init__(
         self,
-        redis_client: Any,
+        redis_client: redis.Redis | RedisCluster | RedisClientWrapper,
         name: str,
         ttl_seconds: float = 60,
         renew_interval_seconds: float | None = None,
@@ -97,7 +103,10 @@ class DbMigrationAutoRenewLock:
             timeout=self._ttl_seconds,
             thread_local=False,
         )
-        acquired = bool(self._lock.acquire(*args, **kwargs))
+        lock = self._lock
+        if lock is None:
+            raise RuntimeError("Redis lock initialization failed.")
+        acquired = bool(lock.acquire(*args, **kwargs))
         self._acquired = acquired
         if acquired:
             self._start_heartbeat()
@@ -127,7 +136,7 @@ class DbMigrationAutoRenewLock:
         )
         self._thread.start()
 
-    def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
+    def _heartbeat_loop(self, lock: Lock, stop_event: threading.Event) -> None:
         while not stop_event.wait(self._renew_interval_seconds):
             try:
                 lock.reacquire()
diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py
index 0828cf80bf..1519f07bb1 100644
--- a/api/libs/email_i18n.py
+++ b/api/libs/email_i18n.py
@@ -37,6 +37,7 @@ class EmailType(StrEnum):
     ENTERPRISE_CUSTOM = auto()
     QUEUE_MONITOR_ALERT = auto()
     DOCUMENT_CLEAN_NOTIFY = auto()
+    WORKFLOW_COMMENT_MENTION = auto()
     EMAIL_REGISTER = auto()
     EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
     RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
@@ -453,6 +454,18 @@ def create_default_email_config() -> EmailI18nConfig:
                 branded_template_path="clean_document_job_mail_template_zh-CN.html",
             ),
         },
+        EmailType.WORKFLOW_COMMENT_MENTION: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="You were mentioned in a workflow comment",
+                template_path="workflow_comment_mention_template_en-US.html",
+                branded_template_path="without-brand/workflow_comment_mention_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="你在工作流评论中被提及",
+                template_path="workflow_comment_mention_template_zh-CN.html",
+                branded_template_path="without-brand/workflow_comment_mention_template_zh-CN.html",
+            ),
+        },
         EmailType.TRIGGER_EVENTS_LIMIT_SANDBOX: {
             EmailLanguage.EN_US: EmailTemplate(
                 subject="You’ve reached your Sandbox Trigger Events limit",
diff --git a/api/libs/exception.py b/api/libs/exception.py
index 73379dfded..1e4bbb44f6 100644
--- a/api/libs/exception.py
+++ b/api/libs/exception.py
@@ -1,9 +1,11 @@
+from typing import Any
+
 from werkzeug.exceptions import HTTPException
 
 
 class BaseHTTPException(HTTPException):
     error_code: str = "unknown"
-    data: dict | None = None
+    data: dict[str, Any] | None = None
 
     def __init__(self, description=None, response=None):
         super().__init__(description, response)
diff --git a/api/libs/external_api.py b/api/libs/external_api.py
index e8592407c3..f907d17750 100644
--- a/api/libs/external_api.py
+++ b/api/libs/external_api.py
@@ -17,7 +17,6 @@ def http_status_message(code):
 
 
 def register_external_error_handlers(api: Api):
-    @api.errorhandler(HTTPException)
     def handle_http_exception(e: HTTPException):
         got_request_exception.send(current_app, exception=e)
 
@@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api):
             headers["Set-Cookie"] = build_force_logout_cookie_headers()
         return data, status_code, headers
 
-    _ = handle_http_exception
-
-    @api.errorhandler(ValueError)
     def handle_value_error(e: ValueError):
         got_request_exception.send(current_app, exception=e)
         status_code = 400
         data = {"code": "invalid_param", "message": str(e), "status": status_code}
         return data, status_code
 
-    _ = handle_value_error
-
-    @api.errorhandler(AppInvokeQuotaExceededError)
     def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
         got_request_exception.send(current_app, exception=e)
         status_code = 429
         data = {"code": "too_many_requests", "message": str(e), "status": status_code}
         return data, status_code
 
-    _ = handle_quota_exceeded
-
-    @api.errorhandler(Exception)
     def handle_general_exception(e: Exception):
         got_request_exception.send(current_app, exception=e)
@@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api):
 
         return data, status_code
 
-    _ = handle_general_exception
+    api.errorhandler(HTTPException)(handle_http_exception)
+    api.errorhandler(ValueError)(handle_value_error)
+    api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded)
+    api.errorhandler(Exception)(handle_general_exception)
 
 
 class ExternalApi(Api):
diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py
index 52fc787c79..838af2bf32 100644
--- a/api/libs/flask_utils.py
+++ b/api/libs/flask_utils.py
@@ -1,5 +1,5 @@
 import contextvars
-from collections.abc import Iterator
+from collections.abc import Generator  # Changed from Iterator
 from contextlib import contextmanager
 from typing import TYPE_CHECKING
 
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
 def preserve_flask_contexts(
     flask_app: Flask,
     context_vars: contextvars.Context,
-) -> Iterator[None]:
+) -> Generator[None, None, None]:  # Changed from Iterator[None]
     """
     A context manager that handles:
     1. flask-login's UserProxy copy
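Annotating the generator function as `Generator[None, None, None]` rather than `Iterator[None]` spells out the send and return types a generator actually has, which `contextlib.contextmanager` exercises at runtime via `throw()` and `close()`; presumably the stricter annotation is what the project's type checker prefers. A minimal parallel:

```python
from collections.abc import Generator
from contextlib import contextmanager


@contextmanager
def managed() -> Generator[None, None, None]:
    try:
        yield  # the body of the with-block runs here
    finally:
        pass  # cleanup runs even if the body raises


with managed():
    print("inside")
```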
diff --git a/api/libs/helper.py b/api/libs/helper.py
index e7decd43b3..ac69a11084 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -16,8 +16,6 @@ from zoneinfo import available_timezones
 
 from flask import Response, stream_with_context
 from flask_restx import fields
-from graphon.file import helpers as file_helpers
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from pydantic import BaseModel, TypeAdapter
 from pydantic.functional_validators import AfterValidator
 from typing_extensions import TypedDict
@@ -25,6 +23,8 @@ from typing_extensions import TypedDict
 from configs import dify_config
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
 from extensions.ext_redis import redis_client
+from graphon.file import helpers as file_helpers
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 
 if TYPE_CHECKING:
     from models import Account
@@ -120,10 +120,22 @@ class AppIconUrlField(fields.Raw):
             obj = obj["app"]
 
         if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
-            return file_helpers.get_signed_file_url(obj.icon)
+            return build_icon_url(obj.icon_type, obj.icon)
         return None
 
 
+def build_icon_url(icon_type: Any, icon: str | None) -> str | None:
+    if icon is None or icon_type is None:
+        return None
+
+    from models.model import IconType
+
+    icon_type_value = icon_type.value if isinstance(icon_type, IconType) else str(icon_type)
+    if icon_type_value.lower() != IconType.IMAGE:
+        return None
+    return file_helpers.get_signed_file_url(icon)
+
+
 class AvatarUrlField(fields.Raw):
     def output(self, key, obj, **kwargs):
         if obj is None:
@@ -410,7 +422,7 @@ class TokenManager:
         token_type: str,
         account: "Account | None" = None,
         email: str | None = None,
-        additional_data: dict | None = None,
+        additional_data: dict[str, Any] | None = None,
     ) -> str:
         if account is None and email is None:
             raise ValueError("Account or email must be provided")
diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py
index 9b53918f24..934aacb45b 100644
--- a/api/libs/oauth_data_source.py
+++ b/api/libs/oauth_data_source.py
@@ -6,8 +6,8 @@ from flask_login import current_user
 from pydantic import TypeAdapter
 from sqlalchemy import select
 
+from core.db.session_factory import session_factory
 from core.helper.http_client_pooling import get_pooled_http_client
-from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from models.source import DataSourceOauthBinding
 
@@ -95,27 +95,28 @@ class NotionOAuth(OAuthDataSource):
             pages=pages,
         )
         # save data source binding
-        data_source_binding = db.session.scalar(
-            select(DataSourceOauthBinding).where(
-                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == "notion",
-                DataSourceOauthBinding.access_token == access_token,
+        with session_factory.create_session() as session:
+            data_source_binding = session.scalar(
+                select(DataSourceOauthBinding).where(
+                    DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                    DataSourceOauthBinding.provider == "notion",
+                    DataSourceOauthBinding.access_token == access_token,
+                )
             )
-        )
-        if data_source_binding:
-            data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
-            data_source_binding.disabled = False
-            data_source_binding.updated_at = naive_utc_now()
-            db.session.commit()
-        else:
-            new_data_source_binding = DataSourceOauthBinding(
-                tenant_id=current_user.current_tenant_id,
-                access_token=access_token,
-                source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
-                provider="notion",
-            )
-            db.session.add(new_data_source_binding)
-            db.session.commit()
+            if data_source_binding:
+                data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
+                data_source_binding.disabled = False
+                data_source_binding.updated_at = naive_utc_now()
+                session.commit()
+            else:
+                new_data_source_binding = DataSourceOauthBinding(
+                    tenant_id=current_user.current_tenant_id,
+                    access_token=access_token,
+                    source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
+                    provider="notion",
+                )
+                session.add(new_data_source_binding)
+                session.commit()
 
     def save_internal_access_token(self, access_token: str) -> None:
         workspace_name = self.notion_workspace_name(access_token)
@@ -130,55 +131,57 @@ class NotionOAuth(OAuthDataSource):
             pages=pages,
         )
         # save data source binding
-        data_source_binding = db.session.scalar(
-            select(DataSourceOauthBinding).where(
-                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == "notion",
-                DataSourceOauthBinding.access_token == access_token,
+        with session_factory.create_session() as session:
+            data_source_binding = session.scalar(
+                select(DataSourceOauthBinding).where(
+                    DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                    DataSourceOauthBinding.provider == "notion",
+                    DataSourceOauthBinding.access_token == access_token,
+                )
             )
-        )
-        if data_source_binding:
-            data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
-            data_source_binding.disabled = False
-            data_source_binding.updated_at = naive_utc_now()
-            db.session.commit()
-        else:
-            new_data_source_binding = DataSourceOauthBinding(
-                tenant_id=current_user.current_tenant_id,
-                access_token=access_token,
-                source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
-                provider="notion",
-            )
-            db.session.add(new_data_source_binding)
-            db.session.commit()
+            if data_source_binding:
+                data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
+                data_source_binding.disabled = False
+                data_source_binding.updated_at = naive_utc_now()
+                session.commit()
+            else:
+                new_data_source_binding = DataSourceOauthBinding(
+                    tenant_id=current_user.current_tenant_id,
+                    access_token=access_token,
+                    source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
+                    provider="notion",
+                )
+                session.add(new_data_source_binding)
+                session.commit()
 
     def sync_data_source(self, binding_id: str) -> None:
         # save data source binding
-        data_source_binding = db.session.scalar(
-            select(DataSourceOauthBinding).where(
-                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == "notion",
-                DataSourceOauthBinding.id == binding_id,
-                DataSourceOauthBinding.disabled == False,
+        with session_factory.create_session() as session:
+            data_source_binding = session.scalar(
+                select(DataSourceOauthBinding).where(
+                    DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                    DataSourceOauthBinding.provider == "notion",
+                    DataSourceOauthBinding.id == binding_id,
+                    DataSourceOauthBinding.disabled == False,
+                )
             )
-        )
-        if data_source_binding:
-            # get all authorized pages
-            pages = self.get_authorized_pages(data_source_binding.access_token)
-            source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
-            new_source_info = self._build_source_info(
-                workspace_name=source_info["workspace_name"],
-                workspace_icon=source_info["workspace_icon"],
-                workspace_id=source_info["workspace_id"],
-                pages=pages,
-            )
-            data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
-            data_source_binding.disabled = False
-            data_source_binding.updated_at = naive_utc_now()
-            db.session.commit()
-        else:
-            raise ValueError("Data source binding not found")
+            if data_source_binding:
+                # get all authorized pages
+                pages = self.get_authorized_pages(data_source_binding.access_token)
+                source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
+                new_source_info = self._build_source_info(
+                    workspace_name=source_info["workspace_name"],
+                    workspace_icon=source_info["workspace_icon"],
+                    workspace_id=source_info["workspace_id"],
+                    pages=pages,
+                )
+                data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
+                data_source_binding.disabled = False
+                data_source_binding.updated_at = naive_utc_now()
+                session.commit()
+            else:
+                raise ValueError("Data source binding not found")
 
     def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
         pages: list[NotionPageSummary] = []
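All three `NotionOAuth` methods now share the same upsert shape. A condensed sketch of that shape, assuming the same session factory and model imports as the diff:

```python
from sqlalchemy import select

from core.db.session_factory import session_factory
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding


def upsert_notion_binding(tenant_id: str, access_token: str, source_info) -> None:
    with session_factory.create_session() as session:
        binding = session.scalar(
            select(DataSourceOauthBinding).where(
                DataSourceOauthBinding.tenant_id == tenant_id,
                DataSourceOauthBinding.provider == "notion",
                DataSourceOauthBinding.access_token == access_token,
            )
        )
        if binding:
            # Update the existing row in place.
            binding.source_info = source_info
            binding.disabled = False
            binding.updated_at = naive_utc_now()
        else:
            session.add(
                DataSourceOauthBinding(
                    tenant_id=tenant_id,
                    access_token=access_token,
                    source_info=source_info,
                    provider="notion",
                )
            )
        # Commit before the session context closes.
        session.commit()
```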
+ """ + if not report_json or not report_json.strip(): + return _EMPTY_SUMMARY.copy() + + try: + data = json.loads(report_json) + except json.JSONDecodeError: + return _EMPTY_SUMMARY.copy() + + summary = data.get("summary") + if not isinstance(summary, dict) or not _REQUIRED_KEYS.issubset(summary): + return _EMPTY_SUMMARY.copy() + + return { + "n_modules": summary["n_modules"], + "n_typable": summary["n_typable"], + "n_typed": summary["n_typed"], + "n_any": summary["n_any"], + "n_untyped": summary["n_untyped"], + "coverage": summary["coverage"], + "strict_coverage": summary["strict_coverage"], + } + + +def format_summary_markdown(summary: CoverageSummary) -> str: + """Format a single coverage summary as a Markdown table.""" + + return ( + "| Metric | Value |\n" + "| --- | ---: |\n" + f"| Modules | {summary['n_modules']} |\n" + f"| Typable symbols | {summary['n_typable']:,} |\n" + f"| Typed symbols | {summary['n_typed']:,} |\n" + f"| Untyped symbols | {summary['n_untyped']:,} |\n" + f"| Any symbols | {summary['n_any']:,} |\n" + f"| **Type coverage** | **{summary['coverage']:.2f}%** |\n" + f"| Strict coverage | {summary['strict_coverage']:.2f}% |" + ) + + +def format_comparison_markdown( + base: CoverageSummary, + pr: CoverageSummary, +) -> str: + """Format a comparison between base and PR coverage as Markdown.""" + + coverage_delta = pr["coverage"] - base["coverage"] + strict_delta = pr["strict_coverage"] - base["strict_coverage"] + typed_delta = pr["n_typed"] - base["n_typed"] + untyped_delta = pr["n_untyped"] - base["n_untyped"] + + def _fmt_delta(value: float, fmt: str = ".2f") -> str: + sign = "+" if value > 0 else "" + return f"{sign}{value:{fmt}}" + + lines = [ + "| Metric | Base | PR | Delta |", + "| --- | ---: | ---: | ---: |", + (f"| **Type coverage** | {base['coverage']:.2f}% | {pr['coverage']:.2f}% | {_fmt_delta(coverage_delta)}% |"), + ( + f"| Strict coverage | {base['strict_coverage']:.2f}% " + f"| {pr['strict_coverage']:.2f}% " + f"| {_fmt_delta(strict_delta)}% |" + ), + (f"| Typed symbols | {base['n_typed']:,} | {pr['n_typed']:,} | {_fmt_delta(typed_delta, ',')} |"), + (f"| Untyped symbols | {base['n_untyped']:,} | {pr['n_untyped']:,} | {_fmt_delta(untyped_delta, ',')} |"), + ( + f"| Modules | {base['n_modules']} " + f"| {pr['n_modules']} " + f"| {_fmt_delta(pr['n_modules'] - base['n_modules'], ',')} |" + ), + ] + return "\n".join(lines) + + +def main() -> int: + """Read pyrefly report JSON from stdin and print a Markdown summary. + + Accepts an optional ``--base `` argument. When provided, the output + includes a base-vs-PR comparison table. 
+ """ + + args = sys.argv[1:] + + base_file: str | None = None + if "--base" in args: + idx = args.index("--base") + if idx + 1 >= len(args): + sys.stderr.write("error: --base requires a file path\n") + return 1 + base_file = args[idx + 1] + + pr_report = sys.stdin.read() + pr_summary = parse_summary(pr_report) + + if base_file is not None: + base_text = Path(base_file).read_text() if Path(base_file).exists() else "" + base_summary = parse_summary(base_text) + sys.stdout.write(format_comparison_markdown(base_summary, pr_summary) + "\n") + else: + sys.stdout.write(format_summary_markdown(pr_summary) + "\n") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py index c047c54d06..0338641d11 100644 --- a/api/libs/sendgrid.py +++ b/api/libs/sendgrid.py @@ -1,4 +1,5 @@ import logging +from typing import Any import sendgrid from python_http_client.exceptions import ForbiddenError, UnauthorizedError @@ -12,7 +13,7 @@ class SendGridClient: self.sendgrid_api_key = sendgrid_api_key self._from = _from - def send(self, mail: dict): + def send(self, mail: dict[str, Any]): logger.debug("Sending email with SendGrid") _to = "" try: diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 6f82f1440a..53906d1769 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -2,6 +2,7 @@ import logging import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Any from configs import dify_config @@ -20,7 +21,7 @@ class SMTPClient: self.use_tls = use_tls self.opportunistic_tls = opportunistic_tls - def send(self, mail: dict): + def send(self, mail: dict[str, Any]): smtp: smtplib.SMTP | None = None local_host = dify_config.SMTP_LOCAL_HOSTNAME try: diff --git a/api/libs/token.py b/api/libs/token.py index a34db70764..5b043465ac 100644 --- a/api/libs/token.py +++ b/api/libs/token.py @@ -47,23 +47,17 @@ def _cookie_domain() -> str | None: def _real_cookie_name(cookie_name: str) -> str: if is_secure() and _cookie_domain() is None: return "__Host-" + cookie_name - else: - return cookie_name + return cookie_name def _try_extract_from_header(request: Request) -> str | None: auth_header = request.headers.get("Authorization") - if auth_header: - if " " not in auth_header: - return None - else: - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - return None - else: - return auth_token - return None + if not auth_header or " " not in auth_header: + return None + auth_scheme, auth_token = auth_header.split(None, 1) + if auth_scheme.lower() != "bearer": + return None + return auth_token def extract_refresh_token(request: Request) -> str | None: @@ -90,14 +84,9 @@ def extract_webapp_access_token(request: Request) -> str | None: def extract_webapp_passport(app_code: str, request: Request) -> str | None: - def _try_extract_passport_token_from_cookie(request: Request) -> str | None: - return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) - - def _try_extract_passport_token_from_header(request: Request) -> str | None: - return request.headers.get(HEADER_NAME_PASSPORT) - - ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request) - return ret + return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code)) or request.headers.get( + HEADER_NAME_PASSPORT + ) def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: 
str = "Lax"): @@ -209,22 +198,18 @@ def check_csrf_token(request: Request, user_id: str): if not csrf_token: _unauthorized() - verified = {} try: verified = PassportService().verify(csrf_token) - except: + except Exception: _unauthorized() + raise # unreachable, but helps the type checker see verified is always bound if verified.get("sub") != user_id: _unauthorized() exp: int | None = verified.get("exp") - if not exp: + if not exp or exp < int(datetime.now(UTC).timestamp()): _unauthorized() - else: - time_now = int(datetime.now().timestamp()) - if exp < time_now: - _unauthorized() def generate_csrf_token(user_id: str) -> str: diff --git a/api/libs/typing.py b/api/libs/typing.py deleted file mode 100644 index f84e9911e0..0000000000 --- a/api/libs/typing.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import TypeGuard - - -def is_str_dict(v: object) -> TypeGuard[dict[str, object]]: - return isinstance(v, dict) - - -def is_str(v: object) -> TypeGuard[str]: - return isinstance(v, str) diff --git a/api/libs/url_utils.py b/api/libs/url_utils.py new file mode 100644 index 0000000000..adcac3add0 --- /dev/null +++ b/api/libs/url_utils.py @@ -0,0 +1,3 @@ +def normalize_api_base_url(base_url: str) -> str: + """Normalize a base URL to always end with /v1, avoiding double /v1 suffixes.""" + return base_url.rstrip("/").removesuffix("/v1").rstrip("/") + "/v1" diff --git a/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py b/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py new file mode 100644 index 0000000000..0e188ec080 --- /dev/null +++ b/api/migrations/versions/2026_04_14_1500-8574b23a38fd_add_qdrant_endpoint_to_tidb_auth_bindings.py @@ -0,0 +1,26 @@ +"""add qdrant_endpoint to tidb_auth_bindings + +Revision ID: 8574b23a38fd +Revises: 6b5f9f8b1a2c +Create Date: 2026-04-14 15:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8574b23a38fd" +down_revision = "6b5f9f8b1a2c" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op: + batch_op.add_column(sa.Column("qdrant_endpoint", sa.String(length=512), nullable=True)) + + +def downgrade(): + with op.batch_alter_table("tidb_auth_bindings", schema=None) as batch_op: + batch_op.drop_column("qdrant_endpoint") diff --git a/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py b/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py new file mode 100644 index 0000000000..0548c932b5 --- /dev/null +++ b/api/migrations/versions/2026_04_15_1726-227822d22895_add_workflow_comments_table.py @@ -0,0 +1,90 @@ +"""Add workflow comments table + +Revision ID: 227822d22895 +Revises: 8574b23a38fd +Create Date: 2025-08-22 17:26:15.255980 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '227822d22895' +down_revision = '8574b23a38fd' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('workflow_comments', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('position_x', sa.Float(), nullable=False), + sa.Column('position_y', sa.Float(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('resolved_at', sa.DateTime(), nullable=True), + sa.Column('resolved_by', models.types.StringUUID(), nullable=True), + sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey') + ) + with op.batch_alter_table('workflow_comments', schema=None) as batch_op: + batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False) + batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False) + + op.create_table('workflow_comment_replies', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('comment_id', models.types.StringUUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey') + ) + with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op: + batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False) + batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False) + + op.create_table('workflow_comment_mentions', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('comment_id', models.types.StringUUID(), nullable=False), + sa.Column('reply_id', models.types.StringUUID(), nullable=True), + sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False), + sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'), + sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey') + ) + with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op: + batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False) + batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False) + batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
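Both child tables above declare `ondelete='CASCADE'` on their foreign keys back to `workflow_comments`, so removing a comment at the database level also removes its replies and mentions without ORM bookkeeping. A minimal sketch, assuming an active SQLAlchemy `session` and a known `comment_id`:

```python
import sqlalchemy as sa

# The FK ondelete="CASCADE" clauses make the database delete dependent rows in
# workflow_comment_replies and workflow_comment_mentions automatically; a Core
# delete goes straight to SQL, so no child objects need to be loaded first.
session.execute(sa.delete(WorkflowComment).where(WorkflowComment.id == comment_id))
session.commit()
```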
### + with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op: + batch_op.drop_index('comment_mentions_user_idx') + batch_op.drop_index('comment_mentions_reply_idx') + batch_op.drop_index('comment_mentions_comment_idx') + + op.drop_table('workflow_comment_mentions') + with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op: + batch_op.drop_index('comment_replies_created_at_idx') + batch_op.drop_index('comment_replies_comment_idx') + + op.drop_table('workflow_comment_replies') + with op.batch_alter_table('workflow_comments', schema=None) as batch_op: + batch_op.drop_index('workflow_comments_created_at_idx') + batch_op.drop_index('workflow_comments_app_idx') + + op.drop_table('workflow_comments') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index fcae07f948..85be9ca3bd 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -9,6 +9,11 @@ from .account import ( TenantStatus, ) from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from .comment import ( + WorkflowComment, + WorkflowCommentMention, + WorkflowCommentReply, +) from .dataset import ( AppDatasetJoin, Dataset, @@ -208,6 +213,9 @@ __all__ = [ "WorkflowAppLog", "WorkflowAppLogCreatedFrom", "WorkflowArchiveLog", + "WorkflowComment", + "WorkflowCommentMention", + "WorkflowCommentReply", "WorkflowNodeExecutionModel", "WorkflowNodeExecutionOffload", "WorkflowNodeExecutionTriggeredFrom", diff --git a/api/models/base.py b/api/models/base.py index b7023b9c8b..5acdf184f4 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -24,6 +24,8 @@ class TypeBase(MappedAsDataclass, DeclarativeBase): class DefaultFieldsMixin: + """Mixin for models that inherit from Base (non-dataclass).""" + id: Mapped[str] = mapped_column( StringUUID, primary_key=True, @@ -53,6 +55,42 @@ class DefaultFieldsMixin: return f"<{self.__class__.__name__}(id={self.id})>" +class DefaultFieldsDCMixin(MappedAsDataclass): + """Mixin for models that inherit from TypeBase (MappedAsDataclass).""" + + __abstract__ = True + + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + insert_default=lambda: str(uuidv7()), + default_factory=lambda: str(uuidv7()), + init=False, + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + insert_default=naive_utc_now, + default_factory=naive_utc_now, + init=False, + server_default=func.current_timestamp(), + ) + + updated_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + insert_default=naive_utc_now, + default_factory=naive_utc_now, + init=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(id={self.id})>" + + def gen_uuidv4_string() -> str: """gen_uuidv4_string generate a UUIDv4 string. diff --git a/api/models/comment.py b/api/models/comment.py new file mode 100644 index 0000000000..5d4a08e783 --- /dev/null +++ b/api/models/comment.py @@ -0,0 +1,219 @@ +"""Workflow comment models.""" + +from datetime import datetime +from typing import Optional + +import sqlalchemy as sa +from sqlalchemy import Index, func +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .account import Account +from .base import Base, gen_uuidv7_string +from .engine import db +from .types import StringUUID + + +class WorkflowComment(Base): + """Workflow comment model for canvas commenting functionality. 
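`DefaultFieldsDCMixin` mirrors `DefaultFieldsMixin` for the dataclass-mapped hierarchy: `id`, `created_at`, and `updated_at` are declared with `init=False`, so they never appear in the generated `__init__`. A minimal sketch of a consumer, with a hypothetical model name:

```python
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column

from models.base import DefaultFieldsDCMixin, TypeBase

# Hypothetical model; any TypeBase subclass picks up id/created_at/updated_at
# from the mixin, none of which are constructor arguments.
class ExampleRecord(DefaultFieldsDCMixin, TypeBase):
    __tablename__ = "example_records"

    name: Mapped[str] = mapped_column(String(255), nullable=False)

record = ExampleRecord(name="demo")  # id and timestamps are filled by defaults
```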
+ + Comments are associated with apps rather than specific workflow versions, + since an app has only one draft workflow at a time and comments should persist + across workflow version changes. + + Attributes: + id: Comment ID + tenant_id: Workspace ID + app_id: App ID (primary association, comments belong to apps) + position_x: X coordinate on canvas + position_y: Y coordinate on canvas + content: Comment content + created_by: Creator account ID + created_at: Creation time + updated_at: Last update time + resolved: Whether comment is resolved + resolved_at: Resolution time + resolved_by: Resolver account ID + """ + + __tablename__ = "workflow_comments" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"), + Index("workflow_comments_app_idx", "tenant_id", "app_id"), + Index("workflow_comments_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + position_x: Mapped[float] = mapped_column(sa.Float) + position_y: Mapped[float] = mapped_column(sa.Float) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime) + resolved_by: Mapped[str | None] = mapped_column(StringUUID) + + # Relationships + replies: Mapped[list["WorkflowCommentReply"]] = relationship( + "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan" + ) + mentions: Mapped[list["WorkflowCommentMention"]] = relationship( + "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan" + ) + + @property + def created_by_account(self): + """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache + return db.session.get(Account, self.created_by) + + def cache_created_by_account(self, account: Account | None) -> None: + """Cache creator account to avoid extra queries.""" + self._created_by_account_cache = account + + @property + def resolved_by_account(self): + """Get resolver account.""" + if hasattr(self, "_resolved_by_account_cache"): + return self._resolved_by_account_cache + if self.resolved_by: + return db.session.get(Account, self.resolved_by) + return None + + def cache_resolved_by_account(self, account: Account | None) -> None: + """Cache resolver account to avoid extra queries.""" + self._resolved_by_account_cache = account + + @property + def reply_count(self): + """Get reply count.""" + return len(self.replies) + + @property + def mention_count(self): + """Get mention count.""" + return len(self.mentions) + + @property + def participants(self): + """Get all participants (creator + repliers + mentioned users).""" + participant_ids: set[str] = set() + participants: list[Account] = [] + + # Use account properties to reuse preloaded caches and avoid hidden N+1. 
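The `cache_*_account` setters exist so that list endpoints can resolve every author with one query and prime each comment before serialization, instead of hitting `db.session.get(Account, ...)` per row. A minimal sketch, assuming `session` and a loaded `comments` list are in scope:

```python
from sqlalchemy import select

from models.account import Account

# Collect every account id referenced by the loaded comments, fetch them once,
# then prime the per-instance caches so the properties skip their own lookups.
account_ids = {c.created_by for c in comments} | {c.resolved_by for c in comments if c.resolved_by}
accounts = {a.id: a for a in session.scalars(select(Account).where(Account.id.in_(account_ids)))}
for c in comments:
    c.cache_created_by_account(accounts.get(c.created_by))
    c.cache_resolved_by_account(accounts.get(c.resolved_by) if c.resolved_by else None)
```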
+ if self.created_by not in participant_ids: + participant_ids.add(self.created_by) + created_by_account = self.created_by_account + if created_by_account: + participants.append(created_by_account) + + for reply in self.replies: + if reply.created_by in participant_ids: + continue + participant_ids.add(reply.created_by) + reply_account = reply.created_by_account + if reply_account: + participants.append(reply_account) + + for mention in self.mentions: + if mention.mentioned_user_id in participant_ids: + continue + participant_ids.add(mention.mentioned_user_id) + mentioned_account = mention.mentioned_user_account + if mentioned_account: + participants.append(mentioned_account) + + return participants + + +class WorkflowCommentReply(Base): + """Workflow comment reply model. + + Attributes: + id: Reply ID + comment_id: Parent comment ID + content: Reply content + created_by: Creator account ID + created_at: Creation time + """ + + __tablename__ = "workflow_comment_replies" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"), + Index("comment_replies_comment_idx", "comment_id"), + Index("comment_replies_created_at_idx", "created_at"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + comment_id: Mapped[str] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + content: Mapped[str] = mapped_column(sa.Text, nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies") + + @property + def created_by_account(self): + """Get creator account.""" + if hasattr(self, "_created_by_account_cache"): + return self._created_by_account_cache + return db.session.get(Account, self.created_by) + + def cache_created_by_account(self, account: Account | None) -> None: + """Cache creator account to avoid extra queries.""" + self._created_by_account_cache = account + + +class WorkflowCommentMention(Base): + """Workflow comment mention model. + + Mentions are only for internal accounts since end users + cannot access workflow canvas and commenting features. 
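Since `participants`, `reply_count`, and `mention_count` all walk `self.replies` and `self.mentions`, callers listing many comments will want those collections eager-loaded rather than lazily fetched per comment. A minimal sketch, assuming `session` and `app_id` are in scope:

```python
from sqlalchemy import select
from sqlalchemy.orm import selectinload

stmt = (
    select(WorkflowComment)
    .where(WorkflowComment.app_id == app_id)
    .options(
        selectinload(WorkflowComment.replies),   # one extra query for all replies
        selectinload(WorkflowComment.mentions),  # one extra query for all mentions
    )
    .order_by(WorkflowComment.created_at)
)
comments = session.scalars(stmt).all()
```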
+ + Attributes: + id: Mention ID + comment_id: Parent comment ID + mentioned_user_id: Mentioned account ID + """ + + __tablename__ = "workflow_comment_mentions" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"), + Index("comment_mentions_comment_idx", "comment_id"), + Index("comment_mentions_reply_idx", "reply_id"), + Index("comment_mentions_user_idx", "mentioned_user_id"), + ) + + id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + comment_id: Mapped[str] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False + ) + reply_id: Mapped[str | None] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True + ) + mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # Relationships + comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions") + reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply") + + @property + def mentioned_user_account(self): + """Get mentioned account.""" + if hasattr(self, "_mentioned_user_account_cache"): + return self._mentioned_user_account_cache + return db.session.get(Account, self.mentioned_user_id) + + def cache_mentioned_user_account(self, account: Account | None) -> None: + """Cache mentioned account to avoid extra queries.""" + self._mentioned_user_account_cache = account diff --git a/api/models/dataset.py b/api/models/dataset.py index 97604848af..a00e9f7640 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -108,6 +108,56 @@ class ExternalKnowledgeApiDict(TypedDict): created_at: str +class DocumentDict(TypedDict): + id: str + tenant_id: str + dataset_id: str + position: int + data_source_type: str + data_source_info: str | None + dataset_process_rule_id: str | None + batch: str + name: str + created_from: str + created_by: str + created_api_request_id: str | None + created_at: datetime + processing_started_at: datetime | None + file_id: str | None + word_count: int | None + parsing_completed_at: datetime | None + cleaning_completed_at: datetime | None + splitting_completed_at: datetime | None + tokens: int | None + indexing_latency: float | None + completed_at: datetime | None + is_paused: bool | None + paused_by: str | None + paused_at: datetime | None + error: str | None + stopped_at: datetime | None + indexing_status: str + enabled: bool + disabled_at: datetime | None + disabled_by: str | None + archived: bool + archived_reason: str | None + archived_by: str | None + archived_at: datetime | None + updated_at: datetime + doc_type: str | None + doc_metadata: Any + doc_form: IndexStructureType + doc_language: str | None + display_status: str | None + data_source_info_dict: dict[str, Any] + average_segment_length: int + dataset_process_rule: ProcessRuleDict | None + dataset: None + segment_count: int | None + hit_count: int | None + + class DatasetPermissionEnum(enum.StrEnum): ONLY_ME = "only_me" ALL_TEAM = "all_team_members" @@ -303,13 +353,17 @@ class Dataset(Base): if self.provider != "external": return None external_knowledge_binding = db.session.scalar( - select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id) + select(ExternalKnowledgeBindings).where( + ExternalKnowledgeBindings.dataset_id == self.id, + ExternalKnowledgeBindings.tenant_id == self.tenant_id, + ) ) if not external_knowledge_binding: return None external_knowledge_api = db.session.scalar( 
select(ExternalKnowledgeApis).where( - ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id + ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id, + ExternalKnowledgeApis.tenant_id == self.tenant_id, ) ) if external_knowledge_api is None or external_knowledge_api.settings is None: @@ -675,8 +729,8 @@ class Document(Base): ) return built_in_fields - def to_dict(self) -> dict[str, Any]: - return { + def to_dict(self) -> DocumentDict: + result: DocumentDict = { "id": self.id, "tenant_id": self.tenant_id, "dataset_id": self.dataset_id, @@ -721,10 +775,11 @@ class Document(Base): "data_source_info_dict": self.data_source_info_dict, "average_segment_length": self.average_segment_length, "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, - "dataset": None, # Dataset class doesn't have a to_dict method + "dataset": None, "segment_count": self.segment_count, "hit_count": self.hit_count, } + return result @classmethod def from_dict(cls, data: dict[str, Any]): @@ -981,7 +1036,7 @@ class DocumentSegment(Base): return attachment_list -class ChildChunk(Base): +class ChildChunk(TypeBase): __tablename__ = "child_chunks" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="child_chunk_pkey"), @@ -991,29 +1046,42 @@ class ChildChunk(Base): ) # initial fields - id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) - tenant_id = mapped_column(StringUUID, nullable=False) - dataset_id = mapped_column(StringUUID, nullable=False) - document_id = mapped_column(StringUUID, nullable=False) - segment_id = mapped_column(StringUUID, nullable=False) + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - content = mapped_column(LongText, nullable=False) + content: Mapped[str] = mapped_column(LongText, nullable=False) word_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) # indexing fields - index_node_id = mapped_column(String(255), nullable=True) - index_node_hash = mapped_column(String(255), nullable=True) - type: Mapped[SegmentType] = mapped_column( - EnumText(SegmentType, length=255), nullable=False, server_default=sa.text("'automatic'") + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False ) - created_by = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) - updated_by = mapped_column(StringUUID, nullable=True) + updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, init=False) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=sa.func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, + nullable=False, + server_default=sa.func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) - indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - completed_at: Mapped[datetime | None] = mapped_column(DateTime, 
nullable=True) - error = mapped_column(LongText, nullable=True) + indexing_at: Mapped[datetime | None] = mapped_column( + DateTime, nullable=True, insert_default=None, server_default=None, init=False + ) + completed_at: Mapped[datetime | None] = mapped_column( + DateTime, nullable=True, insert_default=None, server_default=None, init=False + ) + index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + type: Mapped[SegmentType] = mapped_column( + EnumText(SegmentType, length=255), + nullable=False, + server_default=sa.text("'automatic'"), + default=SegmentType.AUTOMATIC, + ) + error: Mapped[str | None] = mapped_column(LongText, nullable=True, init=False) @property def dataset(self): @@ -1250,6 +1318,7 @@ class TidbAuthBinding(TypeBase): ) account: Mapped[str] = mapped_column(String(255), nullable=False) password: Mapped[str] = mapped_column(String(255), nullable=False) + qdrant_endpoint: Mapped[str | None] = mapped_column(String(512), nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -1496,7 +1565,7 @@ class PipelineBuiltInTemplate(TypeBase): name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) - icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) copyright: Mapped[str] = mapped_column(sa.String(255), nullable=False) privacy_policy: Mapped[str] = mapped_column(sa.String(255), nullable=False) @@ -1529,7 +1598,7 @@ class PipelineCustomizedTemplate(TypeBase): name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column(LongText, nullable=False) chunk_structure: Mapped[str] = mapped_column(sa.String(255), nullable=False) - icon: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + icon: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) yaml_content: Mapped[str] = mapped_column(LongText, nullable=False) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1602,7 +1671,7 @@ class DocumentPipelineExecutionLog(TypeBase): datasource_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) datasource_info: Mapped[str] = mapped_column(LongText, nullable=False) datasource_node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) - input_data: Mapped[dict] = mapped_column(sa.JSON, nullable=False) + input_data: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1633,7 +1702,7 @@ class PipelineRecommendedPlugin(TypeBase): ) -class SegmentAttachmentBinding(Base): +class SegmentAttachmentBinding(TypeBase): __tablename__ = "segment_attachment_bindings" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"), @@ -1646,16 +1715,20 @@ class SegmentAttachmentBinding(Base): ), sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"), ) - id: 
Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False + ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) -class DocumentSegmentSummary(Base): +class DocumentSegmentSummary(TypeBase): __tablename__ = "document_segment_summaries" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="document_segment_summaries_pkey"), @@ -1665,25 +1738,40 @@ class DocumentSegmentSummary(Base): sa.Index("document_segment_summaries_status_idx", "status"), ) - id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column( + StringUUID, + nullable=False, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) # corresponds to DocumentSegment.id or parent chunk id chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - summary_content: Mapped[str] = mapped_column(LongText, nullable=True) - summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) - summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) - tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - status: Mapped[str] = mapped_column( - EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") + summary_content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + summary_index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + summary_index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) + tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True, default=None) + status: Mapped[SummaryStatus] = mapped_column( + EnumText(SummaryStatus, length=32), + nullable=False, + server_default=sa.text("'generating'"), + default=SummaryStatus.GENERATING, + ) + error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) + enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True) + disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) + disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) - error: Mapped[str] = mapped_column(LongText, nullable=True) - enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) - disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) - disabled_by = mapped_column(StringUUID, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, 
server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) def __repr__(self): diff --git a/api/models/human_input.py b/api/models/human_input.py index 79c5d62f6a..7447d3efcb 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,11 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from core.workflow.human_input_compat import DeliveryMethodType +from core.workflow.human_input_adapter import DeliveryMethodType +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 0ea2259a19..25c330b062 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -14,9 +14,6 @@ from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] -from graphon.enums import WorkflowExecutionStatus -from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker @@ -24,7 +21,11 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file from extensions.storage.storage_type import StorageType +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from libs.helper import generate_string # type: ignore[import-not-found] +from libs.url_utils import normalize_api_base_url from libs.uuid_utils import uuidv7 from models.utils.file_input_compat import build_file_from_input_mapping @@ -90,6 +91,19 @@ class EnabledConfig(TypedDict): enabled: bool +class SuggestedQuestionsAfterAnswerModelConfig(TypedDict): + provider: str + name: str + mode: NotRequired[str] + completion_params: NotRequired[dict[str, Any]] + + +class SuggestedQuestionsAfterAnswerConfig(TypedDict): + enabled: bool + model: NotRequired[SuggestedQuestionsAfterAnswerModelConfig] + prompt: NotRequired[str] + + class EmbeddingModelInfo(TypedDict): embedding_provider_name: str embedding_model_name: str @@ -219,7 +233,7 @@ class ModelConfig(TypedDict): class AppModelConfigDict(TypedDict): opening_statement: str | None suggested_questions: list[str] - suggested_questions_after_answer: EnabledConfig + suggested_questions_after_answer: SuggestedQuestionsAfterAnswerConfig speech_to_text: EnabledConfig text_to_speech: EnabledConfig retriever_resource: EnabledConfig @@ -446,7 +460,8 @@ class App(Base): @property def api_base_url(self) -> str: - return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" + base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/") + return normalize_api_base_url(base) @property def tenant(self) -> Tenant | None: @@ 
-678,8 +693,13 @@ class AppModelConfig(TypeBase): return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled}) @property - def suggested_questions_after_answer_dict(self) -> EnabledConfig: - return self._get_enabled_config(self.suggested_questions_after_answer) + def suggested_questions_after_answer_dict(self) -> SuggestedQuestionsAfterAnswerConfig: + return cast( + SuggestedQuestionsAfterAnswerConfig, + json.loads(self.suggested_questions_after_answer) + if self.suggested_questions_after_answer + else {"enabled": False}, + ) @property def speech_to_text_dict(self) -> EnabledConfig: @@ -838,7 +858,7 @@ class AppModelConfig(TypeBase): return self -class RecommendedApp(Base): # bug +class RecommendedApp(TypeBase): __tablename__ = "recommended_apps" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"), @@ -846,20 +866,37 @@ class RecommendedApp(Base): # bug sa.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4())) - app_id = mapped_column(StringUUID, nullable=False) - description = mapped_column(sa.JSON, nullable=False) + id: Mapped[str] = mapped_column( + StringUUID, + primary_key=True, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + description: Mapped[Any] = mapped_column(sa.JSON, nullable=False) copyright: Mapped[str] = mapped_column(String(255), nullable=False) privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False) - custom_disclaimer: Mapped[str] = mapped_column(LongText, default="") category: Mapped[str] = mapped_column(String(255), nullable=False) + custom_disclaimer: Mapped[str] = mapped_column(LongText, default="") position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) - language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'")) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + language: Mapped[str] = mapped_column( + String(255), + nullable=False, + server_default=sa.text("'en-US'"), + default="en-US", + ) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) @property @@ -990,7 +1027,7 @@ class OAuthProviderApp(TypeBase): app_icon: Mapped[str] = mapped_column(String(255), nullable=False) client_id: Mapped[str] = mapped_column(String(255), nullable=False) client_secret: Mapped[str] = mapped_column(String(255), nullable=False) - app_label: Mapped[dict] = mapped_column(sa.JSON, nullable=False, default_factory=dict) + app_label: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, default_factory=dict) redirect_uris: Mapped[list] = mapped_column(sa.JSON, nullable=False, default_factory=list) scope: Mapped[str] = mapped_column( String(255), @@ -1061,7 +1098,7 @@ class Conversation(Base): messages = db.relationship("Message", backref="conversation", 
lazy="select", passive_deletes="all") message_annotations = db.relationship( - "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" + lambda: MessageAnnotation, backref="conversation", lazy="select", passive_deletes="all" ) is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) @@ -1820,7 +1857,7 @@ class MessageFile(TypeBase): ) -class MessageAnnotation(Base): +class MessageAnnotation(TypeBase): __tablename__ = "message_annotations" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"), @@ -1829,17 +1866,28 @@ class MessageAnnotation(Base): sa.Index("message_annotation_message_idx", "message_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + id: Mapped[str] = mapped_column( + StringUUID, + insert_default=lambda: str(uuid4()), + default_factory=lambda: str(uuid4()), + init=False, + ) app_id: Mapped[str] = mapped_column(StringUUID) - conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id")) - message_id: Mapped[str | None] = mapped_column(StringUUID) question: Mapped[str] = mapped_column(LongText, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False) - hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) + hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), init=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None) + message_id: Mapped[str | None] = mapped_column(StringUUID, default=None) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) @property @@ -2134,7 +2182,7 @@ class ApiToken(Base): # bug: this uses setattr so idk the field. return result -class UploadFile(Base): +class UploadFile(TypeBase): __tablename__ = "upload_files" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="upload_file_pkey"), @@ -2142,9 +2190,12 @@ class UploadFile(Base): ) # NOTE: The `id` field is generated within the application to minimize extra roundtrips - # (especially when generating `source_url`). - # The `server_default` serves as a fallback mechanism. - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) + # (especially when generating `source_url`) and keep model metadata portable across databases. 
+ id: Mapped[str] = mapped_column( + StringUUID, + init=False, + default_factory=lambda: str(uuid4()), + ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False) key: Mapped[str] = mapped_column(String(255), nullable=False) @@ -2152,16 +2203,6 @@ class UploadFile(Base): size: Mapped[int] = mapped_column(sa.Integer, nullable=False) extension: Mapped[str] = mapped_column(String(255), nullable=False) mime_type: Mapped[str] = mapped_column(String(255), nullable=True) - - # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. - # Its value is derived from the `CreatorUserRole` enumeration. - created_by_role: Mapped[CreatorUserRole] = mapped_column( - EnumText(CreatorUserRole, length=255), - nullable=False, - server_default=sa.text("'account'"), - default=CreatorUserRole.ACCOUNT, - ) - # The `created_by` field stores the ID of the entity that created this upload file. # # If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`. @@ -2180,10 +2221,18 @@ class UploadFile(Base): # `used` may indicate whether the file has been utilized by another service. used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) + # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. + # Its value is derived from the `CreatorUserRole` enumeration. + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), + nullable=False, + server_default=sa.text("'account'"), + default=CreatorUserRole.ACCOUNT, + ) # `used_by` may indicate the ID of the user who utilized this file. - used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True) - hash: Mapped[str | None] = mapped_column(String(255), nullable=True) + used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) + used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True, default=None) + hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) source_url: Mapped[str] = mapped_column(LongText, default="") def __init__( @@ -2470,7 +2519,7 @@ class TraceAppConfig(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) tracing_provider: Mapped[str | None] = mapped_column(String(255), nullable=True) - tracing_config: Mapped[dict | None] = mapped_column(sa.JSON, nullable=True) + tracing_config: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) diff --git a/api/models/oauth.py b/api/models/oauth.py index 1db2552469..bd04d890d3 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Any import sqlalchemy as sa from sqlalchemy import func @@ -22,7 +23,7 @@ class DatasourceOauthParamConfig(TypeBase): ) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) provider: Mapped[str] = mapped_column(sa.String(255), nullable=False) - system_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + system_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) class DatasourceProvider(TypeBase): @@ -40,7 +41,7 @@ class DatasourceProvider(TypeBase): 
provider: Mapped[str] = mapped_column(sa.String(128), nullable=False) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) auth_type: Mapped[str] = mapped_column(sa.String(255), nullable=False) - encrypted_credentials: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + encrypted_credentials: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) avatar_url: Mapped[str] = mapped_column(LongText, nullable=True, default="default") is_default: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) expires_at: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default="-1", default=-1) @@ -70,7 +71,7 @@ class DatasourceOauthTenantParamConfig(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: Mapped[str] = mapped_column(sa.String(255), nullable=False) plugin_id: Mapped[str] = mapped_column(sa.String(255), nullable=False) - client_params: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict) + client_params: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False, default_factory=dict) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) created_at: Mapped[datetime] = mapped_column( diff --git a/api/models/provider.py b/api/models/provider.py index 8270961b31..8dc3ce4ff6 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,14 +6,14 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column +from core.db.session_factory import session_factory +from graphon.model_runtime.entities.model_entities import ModelType from libs.uuid_utils import uuidv7 from .base import TypeBase -from .engine import db from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType from .types import EnumText, LongText, StringUUID @@ -82,7 +82,8 @@ class Provider(TypeBase): @cached_property def credential(self): if self.credential_id: - return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id)) + with session_factory.create_session() as session: + return session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id)) @property def credential_name(self): @@ -145,9 +146,10 @@ class ProviderModel(TypeBase): @cached_property def credential(self): if self.credential_id: - return db.session.scalar( - select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id) - ) + with session_factory.create_session() as session: + return session.scalar( + select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id) + ) @property def credential_name(self): diff --git a/api/models/source.py b/api/models/source.py index a8addbe342..8fce7df205 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,5 +1,6 @@ import json from datetime import datetime +from typing import Any, TypedDict from uuid import uuid4 import sqlalchemy as sa @@ -24,7 +25,7 @@ class DataSourceOauthBinding(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) access_token: Mapped[str] = mapped_column(String(255), nullable=False) provider: Mapped[str] = mapped_column(String(255), nullable=False) - source_info: Mapped[dict] = mapped_column(AdjustedJSON, nullable=False) + 
source_info: Mapped[dict[str, Any]] = mapped_column(AdjustedJSON, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase): disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False) +class DataSourceApiKeyAuthBindingDict(TypedDict): + id: str + tenant_id: str + category: str + provider: str + credentials: Any + created_at: float + updated_at: float + disabled: bool + + class DataSourceApiKeyAuthBinding(TypeBase): __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( @@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase): ) disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False) - def to_dict(self): - return { + def to_dict(self) -> DataSourceApiKeyAuthBindingDict: + result: DataSourceApiKeyAuthBindingDict = { "id": self.id, "tenant_id": self.tenant_id, "category": self.category, @@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase): "updated_at": self.updated_at.timestamp(), "disabled": self.disabled, } + return result diff --git a/api/models/types.py b/api/models/types.py index c1d9c3845a..4f35c31a27 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -103,10 +103,14 @@ class AdjustedJSON(TypeDecorator[dict | list | None]): else: return dialect.type_descriptor(sa.JSON()) - def process_bind_param(self, value: dict | list | None, dialect: Dialect) -> dict | list | None: + def process_bind_param( + self, value: dict[str, Any] | list[Any] | None, dialect: Dialect + ) -> dict[str, Any] | list[Any] | None: return value - def process_result_value(self, value: dict | list | None, dialect: Dialect) -> dict | list | None: + def process_result_value( + self, value: dict[str, Any] | list[Any] | None, dialect: Dialect + ) -> dict[str, Any] | list[Any] | None: return value diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index 8b767779ce..77dcbd13d4 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -4,9 +4,9 @@ from collections.abc import Callable, Mapping from functools import lru_cache from typing import Any -from graphon.file import File, FileTransferMethod - from core.workflow.file_reference import parse_file_reference +from graphon.file import File, FileTransferMethod, FileType +from graphon.file.constants import FILE_MODEL_IDENTITY, maybe_file_object @lru_cache(maxsize=1) @@ -44,6 +44,124 @@ def resolve_file_mapping_tenant_id( return tenant_resolver() +def build_file_from_mapping_without_lookup(*, file_mapping: Mapping[str, Any]) -> File: + """Build a graph `File` directly from serialized metadata.""" + + def _coerce_file_type(value: Any) -> FileType: + if isinstance(value, FileType): + return value + if isinstance(value, str): + return FileType.value_of(value) + raise ValueError("file type is required in file mapping") + + mapping = dict(file_mapping) + transfer_method_value = mapping.get("transfer_method") + if isinstance(transfer_method_value, FileTransferMethod): + transfer_method = transfer_method_value + elif isinstance(transfer_method_value, str): + transfer_method = FileTransferMethod.value_of(transfer_method_value) + else: + raise ValueError("transfer_method is required in file mapping") + + file_id = mapping.get("file_id") + if not isinstance(file_id, str) or not file_id: + legacy_id = mapping.get("id") + 
file_id = legacy_id if isinstance(legacy_id, str) and legacy_id else None + + related_id = resolve_file_record_id(mapping) + if related_id is None: + raw_related_id = mapping.get("related_id") + related_id = raw_related_id if isinstance(raw_related_id, str) and raw_related_id else None + + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + remote_url = url if isinstance(url, str) and url else None + + reference = mapping.get("reference") + if not isinstance(reference, str) or not reference: + reference = None + + filename = mapping.get("filename") + if not isinstance(filename, str): + filename = None + + extension = mapping.get("extension") + if not isinstance(extension, str): + extension = None + + mime_type = mapping.get("mime_type") + if not isinstance(mime_type, str): + mime_type = None + + size = mapping.get("size", -1) + if not isinstance(size, int): + size = -1 + + storage_key = mapping.get("storage_key") + if not isinstance(storage_key, str): + storage_key = None + + tenant_id = mapping.get("tenant_id") + if not isinstance(tenant_id, str): + tenant_id = None + + dify_model_identity = mapping.get("dify_model_identity") + if not isinstance(dify_model_identity, str): + dify_model_identity = FILE_MODEL_IDENTITY + + tool_file_id = mapping.get("tool_file_id") + if not isinstance(tool_file_id, str): + tool_file_id = None + + upload_file_id = mapping.get("upload_file_id") + if not isinstance(upload_file_id, str): + upload_file_id = None + + datasource_file_id = mapping.get("datasource_file_id") + if not isinstance(datasource_file_id, str): + datasource_file_id = None + + return File( + file_id=file_id, + tenant_id=tenant_id, + file_type=_coerce_file_type(mapping.get("file_type", mapping.get("type"))), + transfer_method=transfer_method, + remote_url=remote_url, + reference=reference, + related_id=related_id, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + storage_key=storage_key, + dify_model_identity=dify_model_identity, + url=remote_url, + tool_file_id=tool_file_id, + upload_file_id=upload_file_id, + datasource_file_id=datasource_file_id, + ) + + +def rebuild_serialized_graph_files_without_lookup(value: Any) -> Any: + """Recursively rebuild serialized graph file payloads into `File` objects. + + `graphon` 0.2.2 no longer accepts legacy serialized file mappings via + `model_validate_json()`. Dify keeps this recovery path at the model boundary + so historical JSON blobs remain readable without reintroducing global graph + patches or test-local coercion. 
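The recovery helper is recursive over lists and dicts and only rewrites mappings that `maybe_file_object` recognizes, so it is safe to run over a whole deserialized payload. A minimal sketch, assuming `stored_json` is a historical blob loaded from storage:

```python
import json

from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup

# Nested file mappings anywhere in the payload are rebuilt into File objects;
# all other values pass through unchanged.
payload = json.loads(stored_json)
restored = rebuild_serialized_graph_files_without_lookup(payload)
```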
+ """ + if isinstance(value, list): + return [rebuild_serialized_graph_files_without_lookup(item) for item in value] + + if isinstance(value, dict): + if maybe_file_object(value): + return build_file_from_mapping_without_lookup(file_mapping=value) + return {key: rebuild_serialized_graph_files_without_lookup(item) for key, item in value.items()} + + return value + + def build_file_from_stored_mapping( *, file_mapping: Mapping[str, Any], @@ -77,12 +195,7 @@ def build_file_from_stored_mapping( pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: - remote_url = mapping.get("remote_url") - if not isinstance(remote_url, str) or not remote_url: - url = mapping.get("url") - if isinstance(url, str) and url: - mapping["remote_url"] = url - return File.model_validate(mapping) + return build_file_from_mapping_without_lookup(file_mapping=mapping) return file_factory.build_from_mapping( mapping=mapping, diff --git a/api/models/workflow.py b/api/models/workflow.py index bb4d6a7ec9..7936c06a5a 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -8,19 +8,6 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import ( - BuiltinNodeTypes, - NodeType, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from graphon.file import File -from graphon.file.constants import maybe_file_object -from graphon.variables import utils as variable_utils -from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -37,35 +24,54 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.human_input_adapter import adapt_node_config_for_graph from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from ._workflow_exc import NodeNotFoundError, WorkflowDataError if TYPE_CHECKING: - from .model import AppMode, UploadFile + from .model import AppMode -from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase - from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter from factories import variable_factory +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase from libs import 
helper from .account import Account -from .base import Base, DefaultFieldsMixin, TypeBase +from .base import Base, DefaultFieldsDCMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom + +# UploadFile uses TypeBase while workflow execution offload models use Base, so relationships +# must target the class object directly instead of relying on string lookup across registries. +from .model import UploadFile from .types import EnumText, LongText, StringUUID -from .utils.file_input_compat import build_file_from_stored_mapping +from .utils.file_input_compat import ( + build_file_from_mapping_without_lookup, + build_file_from_stored_mapping, +) logger = logging.getLogger(__name__) @@ -291,7 +297,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) + return NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -490,7 +496,7 @@ class Workflow(Base): # bug :return: hash """ - entity = {"graph": self.graph_dict, "features": self.features_dict} + entity = {"graph": self.graph_dict} return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @@ -671,6 +677,29 @@ class Workflow(Base): # bug return str(d) +class WorkflowRunDict(TypedDict): + id: str + tenant_id: str + app_id: str + workflow_id: str + type: WorkflowType + triggered_from: WorkflowRunTriggeredFrom + version: str + graph: Mapping[str, Any] + inputs: Mapping[str, Any] + status: WorkflowExecutionStatus + outputs: Mapping[str, Any] + error: str | None + elapsed_time: float + total_tokens: int + total_steps: int + created_by_role: CreatorUserRole + created_by: str + created_at: datetime + finished_at: datetime | None + exceptions_count: int + + class WorkflowRun(Base): """ Workflow Run @@ -742,8 +771,8 @@ class WorkflowRun(Base): exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) pause: Mapped[Optional["WorkflowPause"]] = orm.relationship( - "WorkflowPause", - primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", + lambda: WorkflowPause, + primaryjoin=lambda: WorkflowRun.id == orm.foreign(WorkflowPause.workflow_run_id), uselist=False, # require explicit preloading. 
lazy="raise", @@ -790,29 +819,29 @@ class WorkflowRun(Base): def workflow(self): return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id)) - def to_dict(self): - return { - "id": self.id, - "tenant_id": self.tenant_id, - "app_id": self.app_id, - "workflow_id": self.workflow_id, - "type": self.type, - "triggered_from": self.triggered_from, - "version": self.version, - "graph": self.graph_dict, - "inputs": self.inputs_dict, - "status": self.status, - "outputs": self.outputs_dict, - "error": self.error, - "elapsed_time": self.elapsed_time, - "total_tokens": self.total_tokens, - "total_steps": self.total_steps, - "created_by_role": self.created_by_role, - "created_by": self.created_by, - "created_at": self.created_at, - "finished_at": self.finished_at, - "exceptions_count": self.exceptions_count, - } + def to_dict(self) -> WorkflowRunDict: + return WorkflowRunDict( + id=self.id, + tenant_id=self.tenant_id, + app_id=self.app_id, + workflow_id=self.workflow_id, + type=self.type, + triggered_from=self.triggered_from, + version=self.version, + graph=self.graph_dict, + inputs=self.inputs_dict, + status=self.status, + outputs=self.outputs_dict, + error=self.error, + elapsed_time=self.elapsed_time, + total_tokens=self.total_tokens, + total_steps=self.total_steps, + created_by_role=self.created_by_role, + created_by=self.created_by, + created_at=self.created_at, + finished_at=self.finished_at, + exceptions_count=self.exceptions_count, + ) @classmethod def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": @@ -1071,8 +1100,6 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo @staticmethod def _load_full_content(session: orm.Session, file_id: str, storage: Storage): - from .model import UploadFile - stmt = sa.select(UploadFile).where(UploadFile.id == file_id) file = session.scalars(stmt).first() assert file is not None, f"UploadFile with id {file_id} should exist but not" @@ -1166,10 +1193,11 @@ class WorkflowNodeExecutionOffload(Base): ) file: Mapped[Optional["UploadFile"]] = orm.relationship( + UploadFile, foreign_keys=[file_id], lazy="raise", uselist=False, - primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id", + primaryjoin=lambda: orm.foreign(WorkflowNodeExecutionOffload.file_id) == UploadFile.id, ) @@ -1196,6 +1224,18 @@ class WorkflowAppLogCreatedFrom(StrEnum): raise ValueError(f"invalid workflow app log created from value {value}") +class WorkflowAppLogDict(TypedDict): + id: str + tenant_id: str + app_id: str + workflow_id: str + workflow_run_id: str + created_from: WorkflowAppLogCreatedFrom + created_by_role: CreatorUserRole + created_by: str + created_at: datetime + + class WorkflowAppLog(TypeBase): """ Workflow App execution log, excluding workflow debugging records. 
@@ -1273,8 +1313,8 @@ class WorkflowAppLog(TypeBase): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None - def to_dict(self): - return { + def to_dict(self) -> WorkflowAppLogDict: + result: WorkflowAppLogDict = { "id": self.id, "tenant_id": self.tenant_id, "app_id": self.app_id, @@ -1285,6 +1325,7 @@ class WorkflowAppLog(TypeBase): "created_by": self.created_by, "created_at": self.created_at, } + return result class WorkflowArchiveLog(TypeBase): @@ -1527,12 +1568,14 @@ class WorkflowDraftVariable(Base): ), ) - # Relationship to WorkflowDraftVariableFile + # WorkflowDraftVariableFile uses TypeBase while WorkflowDraftVariable uses Base, so the relationship + # must resolve the class object lazily instead of relying on string lookup across registries. variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship( + lambda: WorkflowDraftVariableFile, foreign_keys=[file_id], lazy="raise", uselist=False, - primaryjoin="WorkflowDraftVariableFile.id == WorkflowDraftVariable.file_id", + primaryjoin=lambda: orm.foreign(WorkflowDraftVariable.file_id) == WorkflowDraftVariableFile.id, ) # Cache for deserialized value @@ -1653,7 +1696,7 @@ class WorkflowDraftVariable(Base): return cast(Any, value) normalized_file = dict(value) normalized_file.pop("tenant_id", None) - return File.model_validate(normalized_file) + return build_file_from_mapping_without_lookup(file_mapping=normalized_file) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] @@ -1663,7 +1706,7 @@ class WorkflowDraftVariable(Base): for item in value_list: normalized_file = dict(cast(dict[str, Any], item)) normalized_file.pop("tenant_id", None) - file_list.append(File.model_validate(normalized_file)) + file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file)) return cast(Any, file_list) else: return cast(Any, value) @@ -1851,7 +1894,7 @@ class WorkflowDraftVariable(Base): return self.last_edited_at is not None -class WorkflowDraftVariableFile(Base): +class WorkflowDraftVariableFile(TypeBase): """Stores metadata about files associated with large workflow draft variables. This model acts as an intermediary between WorkflowDraftVariable and UploadFile, @@ -1865,18 +1908,7 @@ class WorkflowDraftVariableFile(Base): __tablename__ = "workflow_draft_variable_files" # Primary key - id: Mapped[str] = mapped_column( - StringUUID, - primary_key=True, - default=lambda: str(uuidv7()), - ) - - created_at: Mapped[datetime] = mapped_column( - DateTime, - nullable=False, - default=naive_utc_now, - server_default=func.current_timestamp(), - ) + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default_factory=lambda: str(uuidv7()), init=False) tenant_id: Mapped[str] = mapped_column( StringUUID, @@ -1928,12 +1960,21 @@ class WorkflowDraftVariableFile(Base): nullable=False, ) - # Relationship to UploadFile + # Rows are created with `upload_file_id`; callers should load this relationship explicitly when needed. 
upload_file: Mapped["UploadFile"] = orm.relationship( + UploadFile, foreign_keys=[upload_file_id], lazy="raise", + init=False, uselist=False, - primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id", + primaryjoin=lambda: orm.foreign(WorkflowDraftVariableFile.upload_file_id) == UploadFile.id, + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime, + nullable=False, + default_factory=naive_utc_now, + server_default=func.current_timestamp(), ) @@ -1941,7 +1982,7 @@ def is_system_variable_editable(name: str) -> bool: return name in _EDITABLE_SYSTEM_VARIABLE -class WorkflowPause(DefaultFieldsMixin, Base): +class WorkflowPause(DefaultFieldsDCMixin, TypeBase): """ WorkflowPause records the paused state and related metadata for a specific workflow run. @@ -1980,6 +2021,11 @@ class WorkflowPause(DefaultFieldsMixin, Base): nullable=False, ) + # state_object_key stores the object key referencing the serialized runtime state + # of the `GraphEngine`. This object captures the complete execution context of the + # workflow at the moment it was paused, enabling accurate resumption. + state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) + # `resumed_at` records the timestamp when the suspended workflow was resumed. # It is set to `NULL` if the workflow has not been resumed. # @@ -1988,25 +2034,23 @@ class WorkflowPause(DefaultFieldsMixin, Base): resumed_at: Mapped[datetime | None] = mapped_column( sa.DateTime, nullable=True, + default=None, ) - # state_object_key stores the object key referencing the serialized runtime state - # of the `GraphEngine`. This object captures the complete execution context of the - # workflow at the moment it was paused, enabling accurate resumption. - state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) - - # Relationship to WorkflowRun + # Relationship to WorkflowRun (uses lambda to resolve across Base/TypeBase registries) workflow_run: Mapped["WorkflowRun"] = orm.relationship( + lambda: WorkflowRun, foreign_keys=[workflow_run_id], # require explicit preloading. 
lazy="raise", uselist=False, - primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id", + primaryjoin=lambda: WorkflowPause.workflow_run_id == WorkflowRun.id, back_populates="pause", + init=False, ) -class WorkflowPauseReason(DefaultFieldsMixin, Base): +class WorkflowPauseReason(DefaultFieldsDCMixin, TypeBase): __tablename__ = "workflow_pause_reasons" # `pause_id` represents the identifier of the pause, @@ -2049,16 +2093,20 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base): lazy="raise", uselist=False, primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id", + init=False, ) @classmethod - def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason": + def from_entity(cls, *, pause_id: str, pause_reason: PauseReason) -> "WorkflowPauseReason": if isinstance(pause_reason, HumanInputRequired): return cls( - type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id + pause_id=pause_id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=pause_reason.form_id, + node_id=pause_reason.node_id, ) elif isinstance(pause_reason, SchedulingPause): - return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="") + return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message) else: raise AssertionError(f"Unknown pause reason type: {pause_reason}") diff --git a/api/providers/README.md b/api/providers/README.md new file mode 100644 index 0000000000..5d5e6db9af --- /dev/null +++ b/api/providers/README.md @@ -0,0 +1,15 @@ +# Providers + +This directory holds **optional workspace packages** that plug into Dify’s API core. Providers are responsible for implementing the interfaces and registering themselves to the API core. Provider mechanism allows building the software with selected set of providers so as to enhance the security and flexibility of distributions. + +## Developing Providers + +- [VDB Providers](vdb/README.md) + +## Tests + +Provider tests often live next to the package, e.g. `providers///tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`). + +## Excluding Providers + +In order to build with selected providers, use `--no-group vdb-all` and `--no-group trace-all` to disable default ones, then use `--group vdb-` and `--group trace-` to enable specific providers. diff --git a/api/providers/trace/README.md b/api/providers/trace/README.md new file mode 100644 index 0000000000..a7ffa5ed26 --- /dev/null +++ b/api/providers/trace/README.md @@ -0,0 +1,78 @@ +# Trace providers + +This directory holds **optional workspace packages** that send Dify **ops tracing** data (workflows, messages, tools, moderation, etc.) to an external observability backend (Langfuse, LangSmith, OpenTelemetry-style exporters, and others). + +Unlike VDB providers, trace plugins are **not** discovered via entry points. The API core imports your package **explicitly** from `core/ops/ops_trace_manager.py` after you register the provider id and mapping. 
+
+## Architecture
+
+| Layer | Location | Role |
+|--------|----------|------|
+| Contracts | `api/core/ops/base_trace_instance.py`, `api/core/ops/entities/trace_entity.py`, `api/core/ops/entities/config_entity.py` | `BaseTraceInstance`, `BaseTracingConfig`, and typed `*TraceInfo` payloads |
+| Registry | `api/core/ops/ops_trace_manager.py` | `TracingProviderEnum`, `OpsTraceProviderConfigMap` — maps provider **string** → config class, encrypted keys, and trace class |
+| Your package | `api/providers/trace/trace-<name>/` | Pydantic config + subclass of `BaseTraceInstance` |
+
+At runtime, `OpsTraceManager` decrypts stored credentials, builds your config model, caches a trace instance, and calls `trace(trace_info)` with a concrete `BaseTraceInfo` subtype.
+
+## What you implement
+
+### 1. Config model (`BaseTracingConfig`)
+
+Subclass `BaseTracingConfig` from `core.ops.entities.config_entity`. Use Pydantic validators; reuse helpers from `core.ops.utils` (for example `validate_url`, `validate_url_with_path`, `validate_project_name`) where appropriate.
+
+Fields fall into two groups used by the manager:
+
+- **`secret_keys`** — names of fields that are **encrypted at rest** (API keys, tokens, passwords).
+- **`other_keys`** — non-secret connection settings (hosts, project names, endpoints).
+
+List these key names in your `OpsTraceProviderConfigMap` entry so encrypt/decrypt and merge logic stay correct.
+
+### 2. Trace instance (`BaseTraceInstance`)
+
+Subclass `BaseTraceInstance` and implement:
+
+```python
+def trace(self, trace_info: BaseTraceInfo) -> None:
+    ...
+```
+
+Dispatch on the concrete type with `isinstance` (see `trace_langfuse` or `trace_langsmith` for full patterns). Payload types are defined in `core/ops/entities/trace_entity.py`, including:
+
+- `WorkflowTraceInfo`, `WorkflowNodeTraceInfo`, `DraftNodeExecutionTrace`
+- `MessageTraceInfo`, `ToolTraceInfo`, `ModerationTraceInfo`, `SuggestedQuestionTraceInfo`
+- `DatasetRetrievalTraceInfo`, `GenerateNameTraceInfo`, `PromptGenerationTraceInfo`
+
+You may ignore categories your backend does not support; existing providers often no-op unhandled types.
+
+Optional: use `get_service_account_with_tenant(app_id)` from the base class when you need tenant-scoped account context.
+
+### 3. Register in the API core
+
+Upstream changes are required so Dify knows your provider exists:
+
+1. **`TracingProviderEnum`** (`api/core/ops/entities/config_entity.py`) — add a new member whose **value** is the stable string stored in app tracing config (e.g. `"mybackend"`).
+2. **`OpsTraceProviderConfigMap.__getitem__`** (`api/core/ops/ops_trace_manager.py`) — add a `match` case for that enum member returning:
+   - `config_class`: your Pydantic config type
+   - `secret_keys` / `other_keys`: lists of field names as above
+   - `trace_instance`: your `BaseTraceInstance` subclass
+   Lazy-import your package inside the case so missing optional installs raise a clear `ImportError`.
+
+If the `match` case is missing, the provider string will not resolve and tracing will be disabled for that app.
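+
+Taken together, the pieces are small. Below is a minimal sketch for a hypothetical `mybackend` provider; every package, module, and class name in it is invented for illustration, and real constructor and export wiring should mirror a reference implementation such as `trace-langfuse/`:
+
+```python
+# providers/trace/trace-mybackend/src/dify_trace_mybackend/config.py (sketch)
+from pydantic import field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class MyBackendConfig(BaseTracingConfig):
+    api_key: str                                     # list under `secret_keys` (encrypted at rest)
+    endpoint: str = "https://api.mybackend.example"  # list under `other_keys`
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v: str):
+        return validate_url_with_path(v, "https://api.mybackend.example")
+
+
+# providers/trace/trace-mybackend/src/dify_trace_mybackend/mybackend_trace.py (sketch)
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.trace_entity import BaseTraceInfo, MessageTraceInfo, WorkflowTraceInfo
+
+
+class MyBackendDataTrace(BaseTraceInstance):
+    def trace(self, trace_info: BaseTraceInfo) -> None:
+        # Dispatch on the concrete payload type; categories the backend
+        # cannot represent simply fall through as no-ops.
+        if isinstance(trace_info, WorkflowTraceInfo):
+            self._export_workflow(trace_info)
+        elif isinstance(trace_info, MessageTraceInfo):
+            self._export_message(trace_info)
+
+    def _export_workflow(self, trace_info: WorkflowTraceInfo) -> None:
+        ...  # convert the workflow payload and ship it to the backend
+
+    def _export_message(self, trace_info: MessageTraceInfo) -> None:
+        ...  # convert the message payload and ship it to the backend
+```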
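+
+The registry entry then ties the new enum member to these classes. The dict-shaped return value below is an assumption about `OpsTraceProviderConfigMap.__getitem__`; copy the exact shape of the neighboring `match` cases:
+
+```python
+# Fragment inside OpsTraceProviderConfigMap.__getitem__ in ops_trace_manager.py (sketch)
+case TracingProviderEnum.MYBACKEND:  # new enum member with value "mybackend"
+    # Lazy import so a missing optional install raises a clear ImportError.
+    from dify_trace_mybackend.config import MyBackendConfig
+    from dify_trace_mybackend.mybackend_trace import MyBackendDataTrace
+
+    return {
+        "config_class": MyBackendConfig,
+        "secret_keys": ["api_key"],
+        "other_keys": ["endpoint"],
+        "trace_instance": MyBackendDataTrace,
+    }
+```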
+
+## Package layout
+
+Each provider is a normal uv workspace member, for example:
+
+- `api/providers/trace/trace-<name>/pyproject.toml` — project name `dify-trace-<name>`, dependencies on vendor SDKs
+- `api/providers/trace/trace-<name>/src/dify_trace_<name>/` — `config.py`, `<name>_trace.py`, optional `entities/`, and an empty **`py.typed`** file (PEP 561) so the API type checker treats the package as typed; list `py.typed` under `[tool.setuptools.package-data]` for that import name in `pyproject.toml`.
+
+Reference implementations: `trace-langfuse/`, `trace-langsmith/`, `trace-opik/`.
+
+## Wiring into the `api` workspace
+
+In `api/pyproject.toml`:
+
+1. **`[tool.uv.sources]`** — `dify-trace-<name> = { workspace = true }`
+2. **`[dependency-groups]`** — add `trace-<name> = ["dify-trace-<name>"]` and include `dify-trace-<name>` in `trace-all` if it should ship with the default bundle
+
+After changing metadata, run **`uv sync`** from `api/`.
diff --git a/api/providers/trace/trace-aliyun/pyproject.toml b/api/providers/trace/trace-aliyun/pyproject.toml new file mode 100644 index 0000000000..bcef7e9fb1 --- /dev/null +++ b/api/providers/trace/trace-aliyun/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-trace-aliyun" +version = "0.0.1" +dependencies = [ + # versions inherited from parent + "opentelemetry-api", + "opentelemetry-exporter-otlp-proto-grpc", + "opentelemetry-sdk", + "opentelemetry-semantic-conventions", +] +description = "Dify ops tracing provider (Aliyun)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/__init__.py diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py similarity index 98% rename from api/core/ops/aliyun_trace/aliyun_trace.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py index 70aaf2a07b..54d2f8167f 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/aliyun_trace.py @@ -1,12 +1,23 @@ import logging from collections.abc import Sequence -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.data_exporter.traceclient import ( TraceClient, build_endpoint, convert_datetime_to_nanoseconds, @@ -14,8 +25,8 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( convert_to_trace_id, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata +from dify_trace_aliyun.entities.semconv import ( DIFY_APP_ID, GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, @@ -34,7 +45,7 @@ from
core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -46,20 +57,9 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) -from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import AliyunConfig -from core.ops.entities.trace_entity import ( - BaseTraceInfo, - DatasetRetrievalTraceInfo, - GenerateNameTraceInfo, - MessageTraceInfo, - ModerationTraceInfo, - SuggestedQuestionTraceInfo, - ToolTraceInfo, - WorkflowTraceInfo, -) -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py new file mode 100644 index 0000000000..e0133e6cc9 --- /dev/null +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/config.py @@ -0,0 +1,32 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url_with_path + + +class AliyunConfig(BaseTracingConfig): + """ + Model class for Aliyun tracing config. + """ + + app_name: str = "dify_app" + license_key: str + endpoint: str + + @field_validator("app_name") + @classmethod + def app_name_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "dify_app") + + @field_validator("license_key") + @classmethod + def license_key_validator(cls, v, info: ValidationInfo): + if not v or v.strip() == "": + raise ValueError("License key cannot be empty") + return v + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # aliyun uses two URL formats, which may include a URL path + return validate_url_with_path(v, "https://tracing-analysis-dc-hz.aliyuncs.com") diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/data_exporter/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/__init__.py diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py similarity index 98% rename from api/core/ops/aliyun_trace/data_exporter/traceclient.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py index 67d5163b0f..00aab6bf89 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py @@ -26,8 +26,8 @@ from opentelemetry.semconv.attributes import service_attributes from opentelemetry.trace import Link, SpanContext, TraceFlags from configs import dify_config -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData -from core.ops.aliyun_trace.entities.semconv import ACS_ARMS_SERVICE_FEATURE +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.semconv import 
ACS_ARMS_SERVICE_FEATURE INVALID_SPAN_ID: Final[int] = 0x0000000000000000 INVALID_TRACE_ID: Final[int] = 0x00000000000000000000000000000000 diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/__init__.py diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/aliyun_trace_entity.py diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py similarity index 100% rename from api/core/ops/aliyun_trace/entities/semconv.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/entities/semconv.py diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed similarity index 100% rename from api/core/ops/arize_phoenix_trace/__init__.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/py.typed diff --git a/api/core/ops/aliyun_trace/utils.py b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py similarity index 97% rename from api/core/ops/aliyun_trace/utils.py rename to api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py index aa35ac74c2..5678c66adb 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/providers/trace/trace-aliyun/src/dify_trace_aliyun/utils.py @@ -2,11 +2,10 @@ import json from collections.abc import Mapping from typing import Any, TypedDict -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode -from core.ops.aliyun_trace.entities.semconv import ( +from core.rag.models.document import Document +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -15,8 +14,9 @@ from core.ops.aliyun_trace.entities.semconv import ( OUTPUT_VALUE, GenAISpanKind, ) -from core.rag.models.document import Document from extensions.ext_database import db +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser # Constants @@ -48,7 +48,7 @@ def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status: def create_links_from_trace_id(trace_id: str | None) -> list[Link]: - from core.ops.aliyun_trace.data_exporter.traceclient import create_link + from dify_trace_aliyun.data_exporter.traceclient import create_link links = [] if trace_id: diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py similarity index 86% rename from api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py index acb43d4036..ac09060e9d 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/data_exporter/test_traceclient.py @@ -5,10 +5,7 
@@ from unittest.mock import MagicMock, patch import httpx import pytest -from opentelemetry.sdk.trace import ReadableSpan -from opentelemetry.trace import SpanKind, Status, StatusCode - -from core.ops.aliyun_trace.data_exporter.traceclient import ( +from dify_trace_aliyun.data_exporter.traceclient import ( INVALID_SPAN_ID, SpanBuilder, TraceClient, @@ -20,7 +17,9 @@ from core.ops.aliyun_trace.data_exporter.traceclient import ( create_link, generate_span_id, ) -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.trace import SpanKind, Status, StatusCode @pytest.fixture @@ -41,8 +40,8 @@ def trace_client_factory(): class TestTraceClient: - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.socket.gethostname") def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): mock_gethostname.return_value = "test-host" client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -56,7 +55,7 @@ class TestTraceClient: client.shutdown() assert client.done is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -64,8 +63,8 @@ class TestTraceClient: client.export(spans) mock_exporter.export.assert_called_once_with(spans) - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 405 @@ -74,8 +73,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is True - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): mock_response = MagicMock() mock_response.status_code = 500 @@ -84,8 +83,8 @@ class TestTraceClient: client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.api_check() is False - @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.httpx.head") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): mock_head.side_effect = httpx.RequestError("Connection error") @@ -93,12 +92,12 
@@ class TestTraceClient: with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"): client.api_check() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_get_project_url(self, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_add_span(self, mock_exporter_class, trace_client_factory): client = trace_client_factory( service_name="test-service", @@ -134,8 +133,8 @@ class TestTraceClient: assert len(client.queue) == 2 mock_notify.assert_called_once() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.logger") def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) @@ -159,7 +158,7 @@ class TestTraceClient: assert len(client.queue) == 1 mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_export_batch_error(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value mock_exporter.export.side_effect = Exception("Export failed") @@ -168,11 +167,11 @@ class TestTraceClient: mock_span = MagicMock(spec=ReadableSpan) client.queue.append(mock_span) - with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger: + with patch("dify_trace_aliyun.data_exporter.traceclient.logger") as mock_logger: client._export_batch() mock_logger.warning.assert_called() - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_worker_loop(self, mock_exporter_class, trace_client_factory): # We need to test the wait timeout in _worker # But _worker runs in a thread. Let's mock condition.wait. 
@@ -189,7 +188,7 @@ class TestTraceClient: # mock_wait might have been called assert mock_wait.called or client.done - @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("dify_trace_aliyun.data_exporter.traceclient.OTLPSpanExporter") def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): mock_exporter = mock_exporter_class.return_value client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") @@ -226,8 +225,10 @@ class TestSpanBuilder: span = builder.build_span(span_data) assert isinstance(span, ReadableSpan) assert span.name == "test-span" + assert span.context is not None assert span.context.trace_id == 123 assert span.context.span_id == 456 + assert span.parent is not None assert span.parent.span_id == 789 assert span.resource == resource assert span.attributes == {"attr1": "val1"} @@ -268,7 +269,7 @@ def test_generate_span_id(): assert span_id != INVALID_SPAN_ID # Test retry loop - with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand: + with patch("dify_trace_aliyun.data_exporter.traceclient.random.getrandbits") as mock_rand: mock_rand.side_effect = [INVALID_SPAN_ID, 999] span_id = generate_span_id() assert span_id == 999 @@ -290,7 +291,7 @@ def test_convert_to_trace_id(): def test_convert_string_to_id(): assert convert_string_to_id("test") > 0 # Test with None string - with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen: + with patch("dify_trace_aliyun.data_exporter.traceclient.generate_span_id") as mock_gen: mock_gen.return_value = 12345 assert convert_string_to_id(None) == 12345 diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py similarity index 90% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py index 2fcb927e0c..a6808fec0a 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -1,11 +1,10 @@ import pytest +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata from opentelemetry import trace as trace_api from opentelemetry.sdk.trace import Event from opentelemetry.trace import SpanKind, Status, StatusCode from pydantic import ValidationError -from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata - class TestTraceMetadata: def test_trace_metadata_init(self): @@ -65,12 +64,13 @@ class TestSpanData: def test_span_data_missing_required_fields(self): with pytest.raises(ValidationError): - SpanData( - trace_id=123, - # span_id missing - name="test_span", - start_time=1000, - end_time=2000, + SpanData.model_validate( + { + "trace_id": 123, + "name": "test_span", + "start_time": 1000, + "end_time": 2000, + } ) def test_span_data_arbitrary_types_allowed(self): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py similarity index 97% rename from api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py index 
3961555b9a..9cab40748f 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/entities/test_semconv.py @@ -1,4 +1,4 @@ -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( ACS_ARMS_SERVICE_FEATURE, GEN_AI_COMPLETION, GEN_AI_FRAMEWORK, diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py similarity index 90% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py index 62d631a754..fa00829653 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace.py @@ -2,16 +2,15 @@ from __future__ import annotations from datetime import UTC, datetime from types import SimpleNamespace +from typing import cast from unittest.mock import MagicMock +import dify_trace_aliyun.aliyun_trace as aliyun_trace_module import pytest -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey -from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags - -import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module -from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.aliyun_trace import AliyunDataTrace +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_aliyun.entities.aliyun_trace_entity import SpanData, TraceMetadata +from dify_trace_aliyun.entities.semconv import ( GEN_AI_COMPLETION, GEN_AI_INPUT_MESSAGE, GEN_AI_OUTPUT_MESSAGE, @@ -26,7 +25,8 @@ from core.ops.aliyun_trace.entities.semconv import ( TOOL_PARAMETERS, GenAISpanKind, ) -from core.ops.entities.config_entity import AliyunConfig +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -36,6 +36,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: @@ -44,7 +46,7 @@ class RecordingTraceClient: self.endpoint = endpoint self.added_spans: list[object] = [] - def add_span(self, span) -> None: + def add_span(self, span: object) -> None: self.added_spans.append(span) def api_check(self) -> bool: @@ -63,11 +65,35 @@ def _make_link(trace_id: int = 1, span_id: int = 2) -> Link: trace_id=trace_id, span_id=span_id, is_remote=False, - trace_flags=TraceFlags.SAMPLED, + trace_flags=TraceFlags(TraceFlags.SAMPLED), ) return Link(context) +def _make_trace_metadata( + trace_id: int = 1, + workflow_span_id: int = 2, + session_id: str = "s", + user_id: str = "u", + links: list[Link] | None = None, +) -> TraceMetadata: + return TraceMetadata( + trace_id=trace_id, + workflow_span_id=workflow_span_id, + session_id=session_id, + user_id=user_id, + links=[] if links is None else links, + ) + + +def _recording_trace_client(trace_instance: AliyunDataTrace) -> RecordingTraceClient: + return cast(RecordingTraceClient, trace_instance.trace_client) + + +def _recorded_span_data(trace_instance: 
AliyunDataTrace) -> list[SpanData]: + return cast(list[SpanData], _recording_trace_client(trace_instance).added_spans) + + def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: defaults = { "workflow_id": "workflow-id", @@ -263,20 +289,20 @@ def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataT trace_instance.workflow_trace(trace_info) add_workflow_span.assert_called_once() - passed_trace_metadata = add_workflow_span.call_args.args[1] + passed_trace_metadata = cast(TraceMetadata, add_workflow_span.call_args.args[1]) assert passed_trace_metadata.trace_id == 111 assert passed_trace_metadata.workflow_span_id == 222 assert passed_trace_metadata.session_id == "c" assert passed_trace_metadata.user_id == "u" assert passed_trace_metadata.links == [] - assert trace_instance.trace_client.added_spans == ["span-1", "span-2"] + assert _recording_trace_client(trace_instance).added_spans == ["span-1", "span-2"] def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_message_trace_info(message_data=None) trace_instance.message_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -302,8 +328,9 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT ) trace_instance.message_trace(trace_info) - assert len(trace_instance.trace_client.added_spans) == 2 - message_span, llm_span = trace_instance.trace_client.added_spans + spans = _recorded_span_data(trace_instance) + assert len(spans) == 2 + message_span, llm_span = spans assert message_span.name == "message" assert message_span.trace_id == 10 @@ -324,7 +351,7 @@ def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataT def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_dataset_retrieval_trace_info(message_data=None) trace_instance.dataset_retrieval_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -338,8 +365,9 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}]) trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query")) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "dataset_retrieval" assert span.attributes[RETRIEVAL_QUERY] == "query" assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]' @@ -348,7 +376,7 @@ def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, m def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): trace_info = _make_tool_trace_info(message_data=None) trace_instance.tool_trace(trace_info) - assert trace_instance.trace_client.added_spans == [] + assert _recording_trace_client(trace_instance).added_spans == [] def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): @@ -371,8 +399,9 @@ def 
test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: p ) ) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "my-tool" assert span.status == status assert span.attributes[TOOL_NAME] == "my-tool" @@ -409,7 +438,7 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) @@ -422,7 +451,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( ): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) @@ -433,7 +462,7 @@ def test_build_workflow_node_span_routes_knowledge_retrieval_type( def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) @@ -444,7 +473,7 @@ def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTra def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) @@ -457,7 +486,7 @@ def test_build_workflow_node_span_handles_errors( ): node_execution = MagicMock(spec=WorkflowNodeExecution) trace_info = _make_workflow_trace_info() - trace_metadata = MagicMock() + trace_metadata = _make_trace_metadata() monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) node_execution.node_type = BuiltinNodeTypes.CODE @@ -472,7 +501,7 @@ def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "title" @@ -494,7 +523,7 @@ def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()]) + trace_metadata = _make_trace_metadata(links=[_make_link()]) node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title 
= "my-tool" @@ -527,7 +556,7 @@ def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypa aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else [] ) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "retrieval" @@ -556,7 +585,7 @@ def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: p monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in") monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out") - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() node_execution = MagicMock(spec=WorkflowNodeExecution) node_execution.id = "node-id" node_execution.title = "llm" @@ -594,7 +623,7 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. status = Status(StatusCode.OK) monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) - trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + trace_metadata = _make_trace_metadata() # CASE 1: With message_id trace_info = _make_workflow_trace_info( @@ -602,9 +631,11 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. ) trace_instance.add_workflow_span(trace_info, trace_metadata) - assert len(trace_instance.trace_client.added_spans) == 2 - message_span = trace_instance.trace_client.added_spans[0] - workflow_span = trace_instance.trace_client.added_spans[1] + client = _recording_trace_client(trace_instance) + spans = _recorded_span_data(trace_instance) + assert len(spans) == 2 + message_span = spans[0] + workflow_span = spans[1] assert message_span.name == "message" assert message_span.span_kind == SpanKind.SERVER @@ -614,13 +645,14 @@ def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest. 
assert workflow_span.span_kind == SpanKind.INTERNAL assert workflow_span.parent_span_id == 20 - trace_instance.trace_client.added_spans.clear() + client.added_spans.clear() # CASE 2: Without message_id trace_info_no_msg = _make_workflow_trace_info(message_id=None) trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "workflow" assert span.span_kind == SpanKind.SERVER assert span.parent_span_id is None @@ -641,7 +673,8 @@ def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: trace_info = _make_suggested_question_trace_info(suggested_question=["how?"]) trace_instance.suggested_question_trace(trace_info) - assert len(trace_instance.trace_client.added_spans) == 1 - span = trace_instance.trace_client.added_spans[0] + spans = _recorded_span_data(trace_instance) + assert len(spans) == 1 + span = spans[0] assert span.name == "suggested_question" assert span.attributes[GEN_AI_COMPLETION] == '["how?"]' diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py similarity index 92% rename from api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py rename to api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py index 2d2be12f05..1b97746dea 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/aliyun_trace/test_aliyun_trace_utils.py @@ -1,11 +1,9 @@ import json +from collections.abc import Mapping +from typing import Any, cast from unittest.mock import MagicMock -from graphon.entities import WorkflowNodeExecution -from graphon.enums import WorkflowNodeExecutionStatus -from opentelemetry.trace import Link, StatusCode - -from core.ops.aliyun_trace.entities.semconv import ( +from dify_trace_aliyun.entities.semconv import ( GEN_AI_FRAMEWORK, GEN_AI_SESSION_ID, GEN_AI_SPAN_KIND, @@ -13,7 +11,7 @@ from core.ops.aliyun_trace.entities.semconv import ( INPUT_VALUE, OUTPUT_VALUE, ) -from core.ops.aliyun_trace.utils import ( +from dify_trace_aliyun.utils import ( create_common_span_attributes, create_links_from_trace_id, create_status_from_error, @@ -25,7 +23,11 @@ from core.ops.aliyun_trace.utils import ( get_workflow_node_status, serialize_json_data, ) +from opentelemetry.trace import Link, StatusCode + from core.rag.models.document import Document +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from models import EndUser @@ -48,7 +50,7 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = end_user_data - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -63,7 +65,7 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): mock_session = MagicMock() mock_session.get.return_value = None - from core.ops.aliyun_trace.utils import db + from dify_trace_aliyun.utils import db monkeypatch.setattr(db, "session", mock_session) @@ -112,9 +114,9 @@ def test_get_workflow_node_status(): def test_create_links_from_trace_id(monkeypatch): # Mock create_link mock_link = MagicMock(spec=Link) - 
import core.ops.aliyun_trace.data_exporter.traceclient + import dify_trace_aliyun.data_exporter.traceclient - monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + monkeypatch.setattr(dify_trace_aliyun.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) # Trace ID None assert create_links_from_trace_id(None) == [] @@ -170,7 +172,7 @@ def test_create_common_span_attributes(): def test_format_retrieval_documents(): # Not a list - assert format_retrieval_documents("not a list") == [] + assert format_retrieval_documents(cast(list[object], "not a list")) == [] # Valid list docs = [ @@ -211,7 +213,7 @@ def test_format_retrieval_documents(): def test_format_input_messages(): # Not a dict - assert format_input_messages(None) == serialize_json_data([]) + assert format_input_messages(cast(Mapping[str, Any], None)) == serialize_json_data([]) # No prompts assert format_input_messages({}) == serialize_json_data([]) @@ -244,7 +246,7 @@ def test_format_input_messages(): def test_format_output_messages(): # Not a dict - assert format_output_messages(None) == serialize_json_data([]) + assert format_output_messages(cast(Mapping[str, Any], None)) == serialize_json_data([]) # No text assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([]) diff --git a/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..8068ee1328 --- /dev/null +++ b/api/providers/trace/trace-aliyun/tests/unit_tests/test_config_entity.py @@ -0,0 +1,85 @@ +import pytest +from dify_trace_aliyun.config import AliyunConfig +from pydantic import ValidationError + + +class TestAliyunConfig: + """Test cases for AliyunConfig""" + + def test_valid_config(self): + """Test valid Aliyun configuration""" + config = AliyunConfig( + app_name="test_app", + license_key="test_license_key", + endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", + ) + assert config.app_name == "test_app" + assert config.license_key == "test_license_key" + assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") + assert config.app_name == "dify_app" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + AliyunConfig.model_validate({}) + + with pytest.raises(ValidationError): + AliyunConfig.model_validate({"license_key": "test_license"}) + + with pytest.raises(ValidationError): + AliyunConfig.model_validate({"endpoint": "https://tracing-analysis-dc-hz.aliyuncs.com"}) + + def test_app_name_validation_empty(self): + """Test app_name validation with empty value""" + config = AliyunConfig( + license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" + ) + assert config.app_name == "dify_app" + + def test_endpoint_validation_empty(self): + """Test endpoint validation with empty value""" + config = AliyunConfig(license_key="test_license", endpoint="") + assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" + + def test_endpoint_validation_with_path(self): + """Test endpoint validation preserves path for Aliyun endpoints""" + config = AliyunConfig( + license_key="test_license", 
+        assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces"
+
+    def test_endpoint_validation_invalid_scheme(self):
+        """Test endpoint validation rejects invalid schemes"""
+        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+            AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com")
+
+    def test_endpoint_validation_no_scheme(self):
+        """Test endpoint validation rejects URLs without scheme"""
+        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+            AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com")
+
+    def test_license_key_required(self):
+        """Test that license_key is required and cannot be empty"""
+        with pytest.raises(ValidationError):
+            AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com")
+
+    def test_valid_endpoint_format_examples(self):
+        """Test valid endpoint format examples from comments"""
+        valid_endpoints = [
+            # cms2.0 public endpoint
+            "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry",
+            # cms2.0 intranet endpoint
+            "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry",
+            # xtrace public endpoint
+            "http://tracing-cn-heyuan.arms.aliyuncs.com",
+            # xtrace intranet endpoint
+            "http://tracing-cn-heyuan-internal.arms.aliyuncs.com",
+        ]
+
+        for endpoint in valid_endpoints:
+            config = AliyunConfig(license_key="test_license", endpoint=endpoint)
+            assert config.endpoint == endpoint
diff --git a/api/providers/trace/trace-arize-phoenix/pyproject.toml b/api/providers/trace/trace-arize-phoenix/pyproject.toml
new file mode 100644
index 0000000000..9e756944c9
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-arize-phoenix"
+version = "0.0.1"
+dependencies = [
+    "arize-phoenix-otel~=0.15.0",
+]
+description = "Dify ops tracing provider (Arize / Phoenix)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/langfuse_trace/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py
similarity index 100%
rename from api/core/ops/langfuse_trace/__init__.py
rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/__init__.py
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
similarity index 99%
rename from api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
index 66933cea28..96df49ed0e 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py
@@ -6,7 +6,6 @@ from datetime import datetime, timedelta
 from typing import Any, Union, cast
 from urllib.parse import urlparse
 
-from graphon.enums import WorkflowNodeExecutionStatus
 from openinference.semconv.trace import (
     MessageAttributes,
     OpenInferenceMimeTypeValues,
@@ -26,7 +25,6 @@ from opentelemetry.util.types import AttributeValue
 from sqlalchemy.orm import sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -40,7 +38,9 @@ from core.ops.entities.trace_entity import (
 )
 from core.ops.utils import JSON_DICT_ADAPTER
 from core.repositories import DifyCoreRepositoryFactory
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
 from extensions.ext_database import db
+from graphon.enums import WorkflowNodeExecutionStatus
 from models.model import EndUser, MessageFile
 from models.workflow import WorkflowNodeExecutionTriggeredFrom
@@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
             raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
 
-    def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
+    def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
         """Construct LLM attributes with passed prompts for Arize/Phoenix."""
         attributes: dict[str, str] = {}
@@ -797,7 +797,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
             set_attribute(path, value)
 
-        def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
+        def set_tool_call_attributes(
+            message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
+        ) -> None:
             """Extract and assign tool call details safely."""
             if not tool_call:
                 return
diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py
new file mode 100644
index 0000000000..6eac5b30d2
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/config.py
@@ -0,0 +1,45 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class ArizeConfig(BaseTracingConfig):
+    """
+    Model class for Arize tracing config.
+ """ + + api_key: str | None = None + space_id: str | None = None + project: str | None = None + endpoint: str = "https://otlp.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return cls.validate_endpoint_url(v, "https://otlp.arize.com") + + +class PhoenixConfig(BaseTracingConfig): + """ + Model class for Phoenix tracing config. + """ + + api_key: str | None = None + project: str | None = None + endpoint: str = "https://app.phoenix.arize.com" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "default") + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://app.phoenix.arize.com") diff --git a/api/core/ops/langfuse_trace/entities/__init__.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed similarity index 100% rename from api/core/ops/langfuse_trace/entities/__init__.py rename to api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/py.typed diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index 4ce9e22fd7..e9ecc2e083 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -1,12 +1,9 @@ from datetime import UTC, datetime, timedelta +from typing import cast from unittest.mock import MagicMock, patch import pytest -from opentelemetry.sdk.trace import Tracer -from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes -from opentelemetry.trace import StatusCode - -from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( +from dify_trace_arize_phoenix.arize_phoenix_trace import ( ArizePhoenixDataTrace, datetime_to_nanos, error_to_string, @@ -15,7 +12,11 @@ from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( setup_tracer, wrap_span_metadata, ) -from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from opentelemetry.sdk.trace import Tracer +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.trace import StatusCode + from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -80,7 +81,7 @@ def test_datetime_to_nanos(): expected = int(dt.timestamp() * 1_000_000_000) assert datetime_to_nanos(dt) == expected - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.datetime") as mock_dt: mock_now = MagicMock() mock_now.timestamp.return_value = 1704110400.0 mock_dt.now.return_value = mock_now @@ -129,7 +130,7 @@ def test_set_span_status(): return "SilentErrorRepr" span.reset_mock() - set_span_status(span, SilentError()) + set_span_status(span, 
cast(Exception | str | None, SilentError())) assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr" @@ -142,8 +143,8 @@ def test_wrap_span_metadata(): assert res == {"a": 1, "b": 2, "created_from": "Dify"} -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_arize(mock_provider, mock_exporter): config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") setup_tracer(config) @@ -151,8 +152,8 @@ def test_setup_tracer_arize(mock_provider, mock_exporter): assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_phoenix(mock_provider, mock_exporter): config = PhoenixConfig(endpoint="http://p.com", project="p") setup_tracer(config) @@ -162,7 +163,7 @@ def test_setup_tracer_phoenix(mock_provider, mock_exporter): def test_setup_tracer_exception(): config = ArizeConfig(endpoint="http://a.com", project="p") - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): with pytest.raises(Exception, match="boom"): setup_tracer(config) @@ -172,7 +173,7 @@ def test_setup_tracer_exception(): @pytest.fixture def trace_instance(): - with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup: mock_tracer = MagicMock(spec=Tracer) mock_processor = MagicMock() mock_setup.return_value = (mock_tracer, mock_processor) @@ -228,9 +229,9 @@ def test_trace_exception(trace_instance): trace_instance.trace(_make_workflow_info()) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -262,7 +263,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac assert trace_instance.tracer.start_span.call_count >= 2 -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def test_workflow_trace_no_app_id(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_workflow_info() @@ -271,7 +272,7 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance): trace_instance.workflow_trace(info) -@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") def 
     mock_db.engine = MagicMock()
     info = _make_message_info()
@@ -291,7 +292,7 @@ def test_message_trace_success(mock_db, trace_instance):
     assert trace_instance.tracer.start_span.call_count >= 1
 
 
-@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db")
+@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db")
 def test_message_trace_with_error(mock_db, trace_instance):
     mock_db.engine = MagicMock()
     info = _make_message_info()
diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
similarity index 94%
rename from api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py
rename to api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
index 6b5cb5b09a..a01c63ae61 100644
--- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py
+++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py
@@ -1,7 +1,7 @@
-from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes
+from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
 from openinference.semconv.trace import OpenInferenceSpanKindValues
 
-from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
+from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes
 
 
 class TestGetNodeSpanKind:
diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py
new file mode 100644
index 0000000000..11e951c3b1
--- /dev/null
+++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_config_entity.py
@@ -0,0 +1,88 @@
+import pytest
+from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
+from pydantic import ValidationError
+
+
+class TestArizeConfig:
+    """Test cases for ArizeConfig"""
+
+    def test_valid_config(self):
+        """Test valid Arize configuration"""
+        config = ArizeConfig(
+            api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com"
+        )
+        assert config.api_key == "test_key"
+        assert config.space_id == "test_space"
+        assert config.project == "test_project"
+        assert config.endpoint == "https://custom.arize.com"
+
+    def test_default_values(self):
+        """Test default values are set correctly"""
+        config = ArizeConfig()
+        assert config.api_key is None
+        assert config.space_id is None
+        assert config.project is None
+        assert config.endpoint == "https://otlp.arize.com"
+
+    def test_project_validation_empty(self):
+        """Test project validation with empty value"""
+        config = ArizeConfig(project="")
+        assert config.project == "default"
+
+    def test_project_validation_none(self):
+        """Test project validation with None value"""
+        config = ArizeConfig(project=None)
+        assert config.project == "default"
+
+    def test_endpoint_validation_empty(self):
+        """Test endpoint validation with empty value"""
+        config = ArizeConfig(endpoint="")
+        assert config.endpoint == "https://otlp.arize.com"
+
+    def test_endpoint_validation_with_path(self):
+        """Test endpoint validation normalizes URL by removing path"""
+        config = ArizeConfig(endpoint="https://custom.arize.com/api/v1")
+        assert config.endpoint == "https://custom.arize.com"
+
+    def test_endpoint_validation_invalid_scheme(self):
+        """Test endpoint validation rejects invalid schemes"""
+        with pytest.raises(ValidationError, match="URL scheme must be one of"):
+            ArizeConfig(endpoint="ftp://invalid.com")
+
+    def test_endpoint_validation_no_scheme(self):
+        """Test endpoint validation rejects URLs without scheme"""
+        with pytest.raises(ValidationError, match="URL scheme must be one of"):
+            ArizeConfig(endpoint="invalid.com")
+
+
+class TestPhoenixConfig:
+    """Test cases for PhoenixConfig"""
+
+    def test_valid_config(self):
+        """Test valid Phoenix configuration"""
+        config = PhoenixConfig(api_key="test_key", project="test_project", endpoint="https://custom.phoenix.com")
+        assert config.api_key == "test_key"
+        assert config.project == "test_project"
+        assert config.endpoint == "https://custom.phoenix.com"
+
+    def test_default_values(self):
+        """Test default values are set correctly"""
+        config = PhoenixConfig()
+        assert config.api_key is None
+        assert config.project is None
+        assert config.endpoint == "https://app.phoenix.arize.com"
+
+    def test_project_validation_empty(self):
+        """Test project validation with empty value"""
+        config = PhoenixConfig(project="")
+        assert config.project == "default"
+
+    def test_endpoint_validation_with_path(self):
+        """Test endpoint validation with path"""
+        config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration")
+        assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration"
+
+    def test_endpoint_validation_without_path(self):
+        """Test endpoint validation without path"""
+        config = PhoenixConfig(endpoint="https://app.phoenix.arize.com")
+        assert config.endpoint == "https://app.phoenix.arize.com"
diff --git a/api/providers/trace/trace-langfuse/pyproject.toml b/api/providers/trace/trace-langfuse/pyproject.toml
new file mode 100644
index 0000000000..27d2273a69
--- /dev/null
+++ b/api/providers/trace/trace-langfuse/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-langfuse"
+version = "0.0.1"
+dependencies = [
+    "langfuse>=4.2.0,<5.0.0",
+]
+description = "Dify ops tracing provider (Langfuse)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/langsmith_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py
similarity index 100%
rename from api/core/ops/langsmith_trace/__init__.py
rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/__init__.py
diff --git a/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py
new file mode 100644
index 0000000000..90d1a2846b
--- /dev/null
+++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/config.py
@@ -0,0 +1,19 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url_with_path
+
+
+class LangfuseConfig(BaseTracingConfig):
+    """
+    Model class for Langfuse tracing config.
+ """ + + public_key: str + secret_key: str + host: str = "https://api.langfuse.com" + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://api.langfuse.com") diff --git a/api/core/ops/langsmith_trace/entities/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/__init__.py diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py similarity index 100% rename from api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/entities/langfuse_trace_entity.py diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py similarity index 93% rename from api/core/ops/langfuse_trace/langfuse_trace.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py index 9be2ce1bdf..68881378a7 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/langfuse_trace.py @@ -3,7 +3,6 @@ import os import uuid from datetime import UTC, datetime, timedelta -from graphon.enums import BuiltinNodeTypes from langfuse import Langfuse from langfuse.api import ( CreateGenerationBody, @@ -17,7 +16,6 @@ from langfuse.api.commons.types.usage import Usage from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -29,7 +27,10 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( +from core.ops.utils import filter_none_values +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( GenerationUsage, LangfuseGeneration, LangfuseSpan, @@ -37,9 +38,8 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( LevelEnum, UnitEnum, ) -from core.ops.utils import filter_none_values -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") + @staticmethod + def _get_completion_start_time( + start_time: datetime | None, time_to_first_token: float | int | None + ) -> datetime | None: + """Convert a relative TTFT value in seconds into Langfuse's absolute completion start time.""" + if start_time is None or time_to_first_token is None: + return None + + try: + ttft_seconds = float(time_to_first_token) + except (TypeError, ValueError): + return None + + if ttft_seconds < 0: + return None + + return start_time + timedelta(seconds=ttft_seconds) + def trace(self, trace_info: BaseTraceInfo): if isinstance(trace_info, WorkflowTraceInfo): self.workflow_trace(trace_info) @@ -189,10 
+207,18 @@ class LangFuseDataTrace(BaseTraceInstance): total_token = metadata.get("total_tokens", 0) prompt_tokens = 0 completion_tokens = 0 + completion_start_time = None try: - usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) + usage_data = process_data.get("usage") + if not isinstance(usage_data, dict): + usage_data = outputs.get("usage") + if not isinstance(usage_data, dict): + usage_data = {} prompt_tokens = usage_data.get("prompt_tokens", 0) completion_tokens = usage_data.get("completion_tokens", 0) + completion_start_time = self._get_completion_start_time( + created_at, usage_data.get("time_to_first_token") + ) except Exception: logger.error("Failed to extract usage", exc_info=True) @@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance): trace_id=trace_id, model=process_data.get("model_name"), start_time=created_at, + completion_start_time=completion_start_time, end_time=finished_at, input=inputs, output=outputs, @@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance): unit=UnitEnum.TOKENS, totalCost=message_data.total_price, ) + completion_start_time = self._get_completion_start_time( + trace_info.start_time, + trace_info.gen_ai_server_time_to_first_token, + ) langfuse_generation_data = LangfuseGeneration( name="llm", trace_id=trace_id, start_time=trace_info.start_time, + completion_start_time=completion_start_time, end_time=trace_info.end_time, model=message_data.model_id, input=trace_info.inputs, diff --git a/api/core/ops/mlflow_trace/__init__.py b/api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed similarity index 100% rename from api/core/ops/mlflow_trace/__init__.py rename to api/providers/trace/trace-langfuse/src/dify_trace_langfuse/py.typed diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py rename to api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py index 374371fb42..952f10c34f 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/langfuse_trace/test_langfuse_trace.py @@ -5,9 +5,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langfuse.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace -from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -18,14 +25,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( - LangfuseGeneration, - LangfuseSpan, - LangfuseTrace, - LevelEnum, - UnitEnum, -) -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -43,7 +43,7 @@ def langfuse_config(): def trace_instance(langfuse_config, monkeypatch): # Mock Langfuse client to avoid network calls mock_client = MagicMock() - 
monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", lambda **kwargs: mock_client) instance = LangFuseDataTrace(langfuse_config) return instance @@ -51,7 +51,7 @@ def trace_instance(langfuse_config, monkeypatch): def test_init(langfuse_config, monkeypatch): mock_langfuse = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.Langfuse", mock_langfuse) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangFuseDataTrace(langfuse_config) @@ -140,8 +140,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): # Mock DB and Repositories mock_session = MagicMock() - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -178,7 +178,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -241,13 +241,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() @@ -280,8 +280,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): workflow_app_log_id="log-1", error="", ) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -365,7 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = 
MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() @@ -681,9 +681,9 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langfuse.langfuse_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..0c3c3fc81e --- /dev/null +++ b/api/providers/trace/trace-langfuse/tests/unit_tests/test_config_entity.py @@ -0,0 +1,42 @@ +import pytest +from dify_trace_langfuse.config import LangfuseConfig +from pydantic import ValidationError + + +class TestLangfuseConfig: + """Test cases for LangfuseConfig""" + + def test_valid_config(self): + """Test valid Langfuse configuration""" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == "https://custom.langfuse.com" + + def test_valid_config_with_path(self): + host = "https://custom.langfuse.com/api/v1" + config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) + assert config.public_key == "public_key" + assert config.secret_key == "secret_key" + assert config.host == host + + def test_default_values(self): + """Test default values are set correctly""" + config = LangfuseConfig(public_key="public", secret_key="secret") + assert config.host == "https://api.langfuse.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangfuseConfig.model_validate({}) + + with pytest.raises(ValidationError): + LangfuseConfig.model_validate({"public_key": "public"}) + + with pytest.raises(ValidationError): + LangfuseConfig.model_validate({"secret_key": "secret"}) + + def test_host_validation_empty(self): + """Test host validation with empty value""" + config = LangfuseConfig(public_key="public", secret_key="secret", host="") + assert config.host == "https://api.langfuse.com" diff --git a/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py b/api/providers/trace/trace-langfuse/tests/unit_tests/test_langfuse_trace.py new file mode 100644 index 0000000000..82d69b6180 --- /dev/null +++ 
@@ -0,0 +1,138 @@
+"""Tests for Langfuse TTFT reporting support."""
+
+from datetime import datetime, timedelta
+from types import SimpleNamespace
+from typing import cast
+from unittest.mock import MagicMock, patch
+
+from dify_trace_langfuse.config import LangfuseConfig
+from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
+
+from core.ops.entities.trace_entity import MessageTraceInfo, WorkflowTraceInfo
+from graphon.enums import BuiltinNodeTypes
+
+
+def _create_trace_instance() -> LangFuseDataTrace:
+    with patch("dify_trace_langfuse.langfuse_trace.Langfuse", autospec=True):
+        return LangFuseDataTrace(
+            LangfuseConfig(
+                public_key="public-key",
+                secret_key="secret-key",
+                host="https://cloud.langfuse.com",
+            )
+        )
+
+
+class TestLangFuseDataTraceCompletionStartTime:
+    def test_message_trace_reports_completion_start_time(self):
+        trace = _create_trace_instance()
+        start_time = datetime(2026, 3, 11, 13, 0, 0)
+        trace_info = MessageTraceInfo(
+            trace_id="trace-123",
+            message_id="message-123",
+            message_data=SimpleNamespace(
+                id="message-123",
+                from_account_id="account-1",
+                from_end_user_id=None,
+                conversation_id="conversation-1",
+                model_id="gpt-4o-mini",
+                answer="hi there",
+                status="normal",
+                error="",
+                total_price=0.12,
+                provider_response_latency=3.5,
+            ),
+            conversation_model="chat",
+            message_tokens=10,
+            answer_tokens=20,
+            total_tokens=30,
+            error="",
+            inputs="hello",
+            outputs="hi there",
+            file_list=[],
+            start_time=start_time,
+            end_time=start_time + timedelta(seconds=3.5),
+            metadata={},
+            message_file_data=None,
+            conversation_mode="chat",
+            gen_ai_server_time_to_first_token=1.2,
+            llm_streaming_time_to_generate=2.3,
+            is_streaming_request=True,
+        )
+
+        with patch.object(trace, "add_trace"), patch.object(trace, "add_generation") as add_generation:
+            trace.message_trace(trace_info)
+
+        generation = add_generation.call_args.args[0]
+        assert generation.completion_start_time == start_time + timedelta(seconds=1.2)
+
+    def test_workflow_trace_reports_completion_start_time_from_llm_usage(self):
+        trace = _create_trace_instance()
+        start_time = datetime(2026, 3, 11, 13, 0, 0)
+        node_execution = SimpleNamespace(
+            id="node-exec-1",
+            title="Chat LLM",
+            node_type=BuiltinNodeTypes.LLM,
+            status="succeeded",
+            process_data={
+                "model_mode": "chat",
+                "model_name": "gpt-4o-mini",
+                "usage": {
+                    "prompt_tokens": 10,
+                    "completion_tokens": 20,
+                    "time_to_first_token": 1.2,
+                },
+            },
+            inputs={"question": "hello"},
+            outputs={"text": "hi there"},
+            created_at=start_time,
+            elapsed_time=3.5,
+            metadata={},
+        )
+        trace_info = WorkflowTraceInfo(
+            trace_id="trace-123",
+            workflow_data={},
+            conversation_id=None,
+            workflow_app_log_id=None,
+            workflow_id="workflow-1",
+            tenant_id="tenant-1",
+            workflow_run_id="workflow-run-1",
+            workflow_run_elapsed_time=3.5,
+            workflow_run_status="succeeded",
+            workflow_run_inputs={"question": "hello"},
+            workflow_run_outputs={"answer": "hi there"},
+            workflow_run_version="1",
+            error="",
+            total_tokens=30,
+            file_list=[],
+            query="hello",
+            metadata={"app_id": "app-1", "user_id": "user-1"},
+            start_time=start_time,
+            end_time=start_time + timedelta(seconds=3.5),
+        )
+        repository = MagicMock()
+        repository.get_by_workflow_execution.return_value = [node_execution]
+
+        with (
+            patch.object(trace, "add_trace"),
+            patch.object(trace, "add_span"),
+            patch.object(trace, "add_generation") as add_generation,
+            patch.object(trace, "get_service_account_with_tenant", return_value=MagicMock()),
+            patch("dify_trace_langfuse.langfuse_trace.db", MagicMock()),
+            patch(
+                "dify_trace_langfuse.langfuse_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+                return_value=repository,
+            ),
+        ):
+            trace.workflow_trace(trace_info)
+
+        generation = add_generation.call_args.kwargs["langfuse_generation_data"]
+        assert generation.completion_start_time == start_time + timedelta(seconds=1.2)
+
+    def test_ignores_invalid_ttft_values(self):
+        trace = _create_trace_instance()
+        start_time = datetime(2026, 3, 11, 13, 0, 0)
+
+        assert trace._get_completion_start_time(start_time, None) is None
+        assert trace._get_completion_start_time(start_time, -1) is None
+        assert trace._get_completion_start_time(start_time, cast(float | int | None, "invalid")) is None
diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml
new file mode 100644
index 0000000000..8131952b28
--- /dev/null
+++ b/api/providers/trace/trace-langsmith/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-langsmith"
+version = "0.0.1"
+dependencies = [
+    "langsmith~=0.7.30",
+]
+description = "Dify ops tracing provider (LangSmith)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/ops/opik_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py
similarity index 100%
rename from api/core/ops/opik_trace/__init__.py
rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/__init__.py
diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py
new file mode 100644
index 0000000000..498b8c5e7e
--- /dev/null
+++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/config.py
@@ -0,0 +1,20 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+from core.ops.utils import validate_url
+
+
+class LangSmithConfig(BaseTracingConfig):
+    """
+    Model class for Langsmith tracing config.
+ """ + + api_key: str + project: str + endpoint: str = "https://api.smith.langchain.com" + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # LangSmith only allows HTTPS + return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",)) diff --git a/api/core/ops/tencent_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py similarity index 100% rename from api/core/ops/tencent_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/__init__.py diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py similarity index 100% rename from api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py similarity index 99% rename from api/core/ops/langsmith_trace/langsmith_trace.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py index 490c64af84..145bd70dbc 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py @@ -4,13 +4,11 @@ import uuid from datetime import datetime, timedelta from typing import cast -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -22,14 +20,16 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( +from core.ops.utils import filter_none_values, generate_dotted_order +from core.repositories import DifyCoreRepositoryFactory +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( LangSmithRunModel, LangSmithRunType, LangSmithRunUpdateModel, ) -from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/core/ops/weave_trace/__init__.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed similarity index 100% rename from api/core/ops/weave_trace/__init__.py rename to api/providers/trace/trace-langsmith/src/dify_trace_langsmith/py.typed diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py similarity index 91% rename from api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py rename to api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index bfe916f018..45e5894e4a 100644 --- 
a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -3,9 +3,14 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_langsmith.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from dify_trace_langsmith.langsmith_trace import LangSmithDataTrace -from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -16,12 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( - LangSmithRunModel, - LangSmithRunType, - LangSmithRunUpdateModel, -) -from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -38,7 +38,7 @@ def langsmith_config(): def trace_instance(langsmith_config, monkeypatch): # Mock LangSmith client mock_client = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", lambda **kwargs: mock_client) instance = LangSmithDataTrace(langsmith_config) return instance @@ -46,7 +46,7 @@ def trace_instance(langsmith_config, monkeypatch): def test_init(langsmith_config, monkeypatch): mock_client_class = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.Client", mock_client_class) monkeypatch.setenv("FILES_URL", "http://test.url") instance = LangSmithDataTrace(langsmith_config) @@ -138,8 +138,8 @@ def test_workflow_trace(trace_instance, monkeypatch): # Mock dependencies mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) # Mock node executions node_llm = MagicMock() @@ -188,7 +188,7 @@ def test_workflow_trace(trace_instance, monkeypatch): mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) @@ -252,13 +252,13 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): ) mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + 
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() @@ -283,8 +283,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): trace_info.error = "" mock_session = MagicMock() - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -319,7 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() @@ -567,9 +567,9 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_run = MagicMock() diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..bd226c9f1a --- /dev/null +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/test_config_entity.py @@ -0,0 +1,35 @@ +import pytest +from dify_trace_langsmith.config import LangSmithConfig +from pydantic import ValidationError + + +class TestLangSmithConfig: + """Test cases for LangSmithConfig""" + + def test_valid_config(self): + """Test valid LangSmith configuration""" + config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.endpoint == "https://custom.smith.com" + + def test_default_values(self): + """Test 
default values are set correctly""" + config = LangSmithConfig(api_key="key", project="project") + assert config.endpoint == "https://api.smith.langchain.com" + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + LangSmithConfig.model_validate({}) + + with pytest.raises(ValidationError): + LangSmithConfig.model_validate({"api_key": "key"}) + + with pytest.raises(ValidationError): + LangSmithConfig.model_validate({"project": "project"}) + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") diff --git a/api/providers/trace/trace-mlflow/pyproject.toml b/api/providers/trace/trace-mlflow/pyproject.toml new file mode 100644 index 0000000000..fad6002944 --- /dev/null +++ b/api/providers/trace/trace-mlflow/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-mlflow" +version = "0.0.1" +dependencies = [ + "mlflow-skinny>=3.11.1", +] +description = "Dify ops tracing provider (MLflow / Databricks)." + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py similarity index 100% rename from api/core/ops/weave_trace/entities/__init__.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/__init__.py diff --git a/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py new file mode 100644 index 0000000000..84914165e3 --- /dev/null +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/config.py @@ -0,0 +1,46 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_integer_id, validate_url_with_path + + +class MLflowConfig(BaseTracingConfig): + """ + Model class for MLflow tracing config. + """ + + tracking_uri: str = "http://localhost:5000" + experiment_id: str = "0" # Default experiment id in MLflow is 0 + username: str | None = None + password: str | None = None + + @field_validator("tracking_uri") + @classmethod + def tracking_uri_validator(cls, v, info: ValidationInfo): + if isinstance(v, str) and v.startswith("databricks"): + raise ValueError( + "Please use Databricks tracing config below to record traces to Databricks-managed MLflow instances." + ) + return validate_url_with_path(v, "http://localhost:5000") + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) + + +class DatabricksConfig(BaseTracingConfig): + """ + Model class for Databricks (Databricks-managed MLflow) tracing config. 
+ """ + + experiment_id: str + host: str + client_id: str | None = None + client_secret: str | None = None + personal_access_token: str | None = None + + @field_validator("experiment_id") + @classmethod + def experiment_id_validator(cls, v, info: ValidationInfo): + return validate_integer_id(v) diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py similarity index 98% rename from api/core/ops/mlflow_trace/mlflow_trace.py rename to api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py index 3d8c1dd038..4e4c45a532 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py @@ -4,7 +4,6 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow -from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace @@ -12,7 +11,6 @@ from mlflow.tracing.provider import detach_span_from_context, set_span_in_contex from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -25,7 +23,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.utils import JSON_DICT_ADAPTER +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance): return inputs, attributes - def _parse_knowledge_retrieval_outputs(self, outputs: dict): + def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]): """Parse KR outputs and attributes from KR workflow node""" retrieved = outputs.get("result", []) @@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance): end_time_ns=datetime_to_nanoseconds(trace_info.end_time), ) - def _get_message_user_id(self, metadata: dict) -> str | None: + def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None: if (end_user_id := metadata.get("from_end_user_id")) and ( end_user_data := db.session.get(EndUser, end_user_id) ): @@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance): } return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload] - def _set_trace_metadata(self, span: Span, metadata: dict): + def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]): token = None try: # NB: Set span in context such that we can use update_current_trace() API @@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance): return messages return prompts # Fallback to original format - def _parse_single_message(self, item: dict): + def _parse_single_message(self, item: dict[str, Any]): """Postprocess single message format to be standard chat message""" role = item.get("role", "user") msg = {"role": role, "content": item.get("text", "")} diff --git a/api/core/rag/datasource/vdb/alibabacloud_mysql/__init__.py b/api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed similarity index 100% rename from api/core/rag/datasource/vdb/alibabacloud_mysql/__init__.py rename to 
api/providers/trace/trace-mlflow/src/dify_trace_mlflow/py.typed diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py similarity index 96% rename from api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py rename to api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py index f4c485a9fc..46c9750a5d 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/providers/trace/trace-mlflow/tests/unit_tests/mlflow_trace/test_mlflow_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" +"""Comprehensive tests for dify_trace_mlflow.mlflow_trace module.""" from __future__ import annotations @@ -9,9 +9,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes +from dify_trace_mlflow.config import DatabricksConfig, MLflowConfig +from dify_trace_mlflow.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds +from graphon.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -179,7 +179,7 @@ def _make_node(**overrides): @pytest.fixture def mock_mlflow(): - with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + with patch("dify_trace_mlflow.mlflow_trace.mlflow") as mock: yield mock @@ -187,10 +187,10 @@ def mock_mlflow(): def mock_tracing(): """Patch all MLflow tracing functions used by the module.""" with ( - patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, - patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, - patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, - patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + patch("dify_trace_mlflow.mlflow_trace.start_span_no_context") as mock_start, + patch("dify_trace_mlflow.mlflow_trace.update_current_trace") as mock_update, + patch("dify_trace_mlflow.mlflow_trace.set_span_in_context") as mock_set, + patch("dify_trace_mlflow.mlflow_trace.detach_span_from_context") as mock_detach, ): yield { "start": mock_start, @@ -202,7 +202,7 @@ def mock_tracing(): @pytest.fixture def mock_db(): - with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + with patch("dify_trace_mlflow.mlflow_trace.db") as mock: yield mock @@ -599,7 +599,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_instance.message_trace(_make_message_trace_info()) mock_tracing["start"].assert_called_once() @@ -609,7 +608,6 @@ class TestMessageTrace: span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" - mock_db.session.query.return_value.where.return_value.first.return_value = None trace_info = _make_message_trace_info(error="something broke") trace_instance.message_trace(trace_info) @@ -620,7 +618,6 @@ class TestMessageTrace: span = 
@@ -620,7 +618,6 @@
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None
         monkeypatch.setenv("FILES_URL", "http://files.test")
 
         file_data = SimpleNamespace(url="path/to/file.png")
@@ -638,7 +635,6 @@
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None
 
         trace_info = _make_message_trace_info(file_list=None, message_file_data=None)
         trace_instance.message_trace(trace_info)
@@ -651,7 +647,6 @@
 
         end_user = MagicMock()
         end_user.session_id = "session-xyz"
-        mock_db.session.query.return_value.where.return_value.first.return_value = end_user
 
         trace_info = _make_message_trace_info(
             metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"},
@@ -664,7 +659,6 @@
         span = MagicMock()
         mock_tracing["start"].return_value = span
         mock_tracing["set"].return_value = "token"
-        mock_db.session.query.return_value.where.return_value.first.return_value = None
 
         trace_info = _make_message_trace_info(
             metadata={"from_account_id": "acc-1"},
diff --git a/api/providers/trace/trace-opik/pyproject.toml b/api/providers/trace/trace-opik/pyproject.toml
new file mode 100644
index 0000000000..874997168e
--- /dev/null
+++ b/api/providers/trace/trace-opik/pyproject.toml
@@ -0,0 +1,10 @@
+[project]
+name = "dify-trace-opik"
+version = "0.0.1"
+dependencies = [
+    "opik~=1.11.2",
+]
+description = "Dify ops tracing provider (Opik)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/rag/datasource/vdb/analyticdb/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py
similarity index 100%
rename from api/core/rag/datasource/vdb/analyticdb/__init__.py
rename to api/providers/trace/trace-opik/src/dify_trace_opik/__init__.py
+ """ + + api_key: str | None = None + project: str | None = None + workspace: str | None = None + url: str = "https://www.comet.com/opik/api/" + + @field_validator("project") + @classmethod + def project_validator(cls, v, info: ValidationInfo): + return cls.validate_project_field(v, "Default Project") + + @field_validator("url") + @classmethod + def url_validator(cls, v, info: ValidationInfo): + return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/") diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py similarity index 98% rename from api/core/ops/opik_trace/opik_trace.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py index 2215bdeb33..2d124ac989 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/providers/trace/trace-opik/src/dify_trace_opik/opik_trace.py @@ -3,15 +3,13 @@ import logging import os import uuid from datetime import datetime, timedelta -from typing import cast +from typing import Any, cast -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -24,7 +22,9 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory +from dify_trace_opik.config import OpikConfig from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance): self.add_span(span_data) - def add_trace(self, opik_trace_data: dict) -> Trace: + def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace: try: trace = self.opik_client.trace(**opik_trace_data) logger.debug("Opik Trace created successfully") @@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance): except Exception as e: raise ValueError(f"Opik Failed to create trace: {str(e)}") - def add_span(self, opik_span_data: dict): + def add_span(self, opik_span_data: dict[str, Any]): try: self.opik_client.span(**opik_span_data) logger.debug("Opik Span created successfully") diff --git a/api/core/rag/datasource/vdb/baidu/__init__.py b/api/providers/trace/trace-opik/src/dify_trace_opik/py.typed similarity index 100% rename from api/core/rag/datasource/vdb/baidu/__init__.py rename to api/providers/trace/trace-opik/src/dify_trace_opik/py.typed diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py similarity index 93% rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py index 1cb32f2ee0..eefed3c78c 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py @@ -5,9 +5,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_opik.config import OpikConfig +from 
diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
similarity index 93%
rename from api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
rename to api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
index 1cb32f2ee0..eefed3c78c 100644
--- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py
+++ b/api/providers/trace/trace-opik/tests/unit_tests/opik_trace/test_opik_trace.py
@@ -5,9 +5,9 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock
 
 import pytest
-from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
+from dify_trace_opik.config import OpikConfig
+from dify_trace_opik.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
-from core.ops.entities.config_entity import OpikConfig
 from core.ops.entities.trace_entity import (
     DatasetRetrievalTraceInfo,
     GenerateNameTraceInfo,
@@ -18,7 +18,7 @@
     TraceTaskName,
     WorkflowTraceInfo,
 )
-from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata
+from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
 from models import EndUser
 from models.enums import MessageStatus
 
@@ -37,7 +37,7 @@ def opik_config():
 @pytest.fixture
 def trace_instance(opik_config, monkeypatch):
     mock_client = MagicMock()
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client)
+    monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", lambda **kwargs: mock_client)
     instance = OpikDataTrace(opik_config)
     return instance
 
@@ -67,7 +67,7 @@ def test_prepare_opik_uuid():
 
 def test_init(opik_config, monkeypatch):
     mock_opik = MagicMock()
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik)
+    monkeypatch.setattr("dify_trace_opik.opik_trace.Opik", mock_opik)
     monkeypatch.setenv("FILES_URL", "http://test.url")
 
     instance = OpikDataTrace(opik_config)
@@ -166,8 +166,8 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch):
     )
 
     mock_session = MagicMock()
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+    monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: mock_session)
+    monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
 
     node_llm = MagicMock()
     node_llm.id = LLM_NODE_ID
@@ -203,7 +203,7 @@
     mock_factory = MagicMock()
     mock_factory.create_workflow_node_execution_repository.return_value = repo
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
+    monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
 
     monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
 
@@ -250,13 +250,13 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch):
         error="",
     )
 
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine"))
+    monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
+    monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine"))
     repo = MagicMock()
     repo.get_by_workflow_execution.return_value = []
     mock_factory = MagicMock()
     mock_factory.create_workflow_node_execution_repository.return_value = repo
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory)
+    monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory)
 
     monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
     trace_instance.add_trace = MagicMock()
@@ -286,8 +286,8 @@ def test_workflow_trace_missing_app_id(trace_instance, monkeypatch):
         workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e",
         error="",
     )
-    monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): trace_instance.workflow_trace(trace_info) @@ -373,7 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) + monkeypatch.setattr("dify_trace_opik.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() @@ -658,9 +658,9 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_opik.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_opik.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_opik.opik_trace.db", MagicMock(engine="engine")) monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) trace_instance.add_trace = MagicMock() diff --git a/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..5a54b70bba --- /dev/null +++ b/api/providers/trace/trace-opik/tests/unit_tests/test_config_entity.py @@ -0,0 +1,48 @@ +import pytest +from dify_trace_opik.config import OpikConfig +from pydantic import ValidationError + + +class TestOpikConfig: + """Test cases for OpikConfig""" + + def test_valid_config(self): + """Test valid Opik configuration""" + config = OpikConfig( + api_key="test_key", + project="test_project", + workspace="test_workspace", + url="https://custom.comet.com/opik/api/", + ) + assert config.api_key == "test_key" + assert config.project == "test_project" + assert config.workspace == "test_workspace" + assert config.url == "https://custom.comet.com/opik/api/" + + def test_default_values(self): + """Test default values are set correctly""" + config = OpikConfig() + assert config.api_key is None + assert config.project is None + assert config.workspace is None + assert config.url == "https://www.comet.com/opik/api/" + + def test_project_validation_empty(self): + """Test project validation with empty value""" + config = OpikConfig(project="") + assert config.project == "Default Project" + + def test_url_validation_empty(self): + """Test URL validation with empty value""" + config = OpikConfig(url="") + assert config.url == "https://www.comet.com/opik/api/" + + def test_url_validation_missing_suffix(self): + """Test URL validation requires /api/ suffix""" + with pytest.raises(ValidationError, match="URL should end with /api/"): + OpikConfig(url="https://custom.comet.com/opik/") + + def 
+    def test_url_validation_invalid_scheme(self):
+        """Test URL validation rejects invalid schemes"""
+        with pytest.raises(ValidationError, match="URL must start with https:// or http://"):
+            OpikConfig(url="ftp://custom.comet.com/opik/api/")
diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
similarity index 87%
rename from api/tests/unit_tests/core/ops/test_opik_trace.py
rename to api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
index ad9d0846be..2e0796c291 100644
--- a/api/tests/unit_tests/core/ops/test_opik_trace.py
+++ b/api/providers/trace/trace-opik/tests/unit_tests/test_opik_trace.py
@@ -12,10 +12,12 @@ from __future__ import annotations
 
 import uuid
 from datetime import datetime
+from typing import cast
 from unittest.mock import MagicMock, patch
 
+from dify_trace_opik.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
+
 from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo
-from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid
 
 # A stable UUID4 used as the workflow_run_id throughout all tests.
 _WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6"
@@ -56,8 +58,8 @@ def _make_workflow_trace_info(
 
 def _make_opik_trace_instance() -> OpikDataTrace:
     """Construct an OpikDataTrace with the Opik SDK client mocked out."""
-    with patch("core.ops.opik_trace.opik_trace.Opik"):
-        from core.ops.entities.config_entity import OpikConfig
+    with patch("dify_trace_opik.opik_trace.Opik"):
+        from dify_trace_opik.config import OpikConfig
 
         config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/")
         instance = OpikDataTrace(config)
@@ -68,6 +70,14 @@
     return instance
 
 
+def _add_trace_mock(instance: OpikDataTrace) -> MagicMock:
+    return cast(MagicMock, instance.add_trace)
+
+
+def _add_span_mock(instance: OpikDataTrace) -> MagicMock:
+    return cast(MagicMock, instance.add_span)
+
+
 # ---------------------------------------------------------------------------
 # _seed_to_uuid4
 # ---------------------------------------------------------------------------
@@ -133,10 +143,10 @@
         fake_repo.get_by_workflow_execution.return_value = node_executions or []
 
         with (
-            patch("core.ops.opik_trace.opik_trace.db") as mock_db,
-            patch("core.ops.opik_trace.opik_trace.sessionmaker"),
+            patch("dify_trace_opik.opik_trace.db") as mock_db,
+            patch("dify_trace_opik.opik_trace.sessionmaker"),
             patch(
-                "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+                "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
                 return_value=fake_repo,
             ),
        ):
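The `_add_trace_mock` / `_add_span_mock` helpers above exist purely for the type checker: the fixtures assign `MagicMock()` over the real `add_trace` / `add_span` methods, but the declared attribute types are still the methods' own signatures, so one `cast` per accessor replaces a `# type: ignore` at every assertion site. The same idea in miniature:

```python
from typing import cast
from unittest.mock import MagicMock


class Client:
    def send(self, payload: dict) -> None:
        raise NotImplementedError


client = Client()
client.send = MagicMock()  # runtime swap; the static type stays the method


def send_mock(c: Client) -> MagicMock:
    # Centralize the narrowing so call sites stay clean and typed.
    return cast(MagicMock, c.send)


send_mock(client).assert_not_called()
client.send({"ok": True})
send_mock(client).assert_called_once()
```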
@@ -154,21 +164,21 @@
     def test_root_span_is_created(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)
-        assert instance.add_span.called
+        assert _add_span_mock(instance).called
 
     def test_root_span_id_matches_expected(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)
 
         expected = self._expected_root_span_id(trace_info)
-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert root_span_kwargs["id"] == expected
 
     def test_root_span_has_no_parent(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)
 
-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert root_span_kwargs["parent_span_id"] is None
 
     def test_trace_name_is_workflow_trace(self):
@@ -176,21 +186,21 @@
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)
 
-        trace_kwargs = instance.add_trace.call_args_list[0][0][0]
+        trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
         assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
 
     def test_root_span_name_is_workflow_trace(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)
 
-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE
 
     def test_root_span_has_workflow_tag(self):
         trace_info = _make_workflow_trace_info(message_id=None)
         instance = self._run(trace_info)
 
-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert "workflow" in root_span_kwargs["tags"]
 
     def test_node_execution_spans_are_parented_to_root(self):
@@ -213,8 +223,9 @@
         instance = self._run(trace_info, node_executions=[node_exec])
 
         # call_args_list[0] = root span, [1] = node execution span
-        assert instance.add_span.call_count == 2
-        node_span_kwargs = instance.add_span.call_args_list[1][0][0]
+        add_span = _add_span_mock(instance)
+        assert add_span.call_count == 2
+        node_span_kwargs = add_span.call_args_list[1][0][0]
         assert node_span_kwargs["parent_span_id"] == expected_root_span_id
 
     def test_node_span_not_parented_to_workflow_app_log_id(self):
@@ -239,7 +250,7 @@
         instance = self._run(trace_info, node_executions=[node_exec])
 
         old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id)
-        node_span_kwargs = instance.add_span.call_args_list[1][0][0]
+        node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
         assert node_span_kwargs["parent_span_id"] != old_parent_id
 
     def test_root_span_id_differs_from_trace_id(self):
@@ -265,10 +276,10 @@ class TestWorkflowTraceWithMessageId:
         fake_repo.get_by_workflow_execution.return_value = node_executions or []
 
         with (
-            patch("core.ops.opik_trace.opik_trace.db") as mock_db,
-            patch("core.ops.opik_trace.opik_trace.sessionmaker"),
+            patch("dify_trace_opik.opik_trace.db") as mock_db,
+            patch("dify_trace_opik.opik_trace.sessionmaker"),
             patch(
-                "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
+                "dify_trace_opik.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
                 return_value=fake_repo,
            ),
        ):
@@ -282,7 +293,7 @@
         trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID)
         instance = self._run(trace_info)
 
-        trace_kwargs = instance.add_trace.call_args_list[0][0][0]
+        trace_kwargs = _add_trace_mock(instance).call_args_list[0][0][0]
         assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE
 
     def test_root_span_uses_workflow_run_id_directly(self):
@@ -291,7 +302,7 @@
         instance = self._run(trace_info)
 
         expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
-        root_span_kwargs = instance.add_span.call_args_list[0][0][0]
+        root_span_kwargs = _add_span_mock(instance).call_args_list[0][0][0]
         assert root_span_kwargs["id"] == expected_root_span_id
 
     def test_root_span_id_differs_from_no_message_id_case(self):
@@ -325,5 +336,5 @@
 
         instance = self._run(trace_info, node_executions=[node_exec])
 
-        node_span_kwargs = instance.add_span.call_args_list[1][0][0]
+        node_span_kwargs = _add_span_mock(instance).call_args_list[1][0][0]
         assert node_span_kwargs["parent_span_id"] == expected_root_span_id
diff --git a/api/providers/trace/trace-tencent/pyproject.toml b/api/providers/trace/trace-tencent/pyproject.toml
new file mode 100644
index 0000000000..eab06fc708
--- /dev/null
+++ b/api/providers/trace/trace-tencent/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-trace-tencent"
+version = "0.0.1"
+dependencies = [
+    # versions inherited from parent
+    "opentelemetry-api",
+    "opentelemetry-exporter-otlp-proto-grpc",
+    "opentelemetry-sdk",
+    "opentelemetry-semantic-conventions",
+]
+description = "Dify ops tracing provider (Tencent APM)."
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/core/rag/datasource/vdb/chroma/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py
similarity index 100%
rename from api/core/rag/datasource/vdb/chroma/__init__.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/__init__.py
diff --git a/api/core/ops/tencent_trace/client.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py
similarity index 100%
rename from api/core/ops/tencent_trace/client.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/client.py
diff --git a/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py
new file mode 100644
index 0000000000..398e6c55a8
--- /dev/null
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/config.py
@@ -0,0 +1,30 @@
+from pydantic import ValidationInfo, field_validator
+
+from core.ops.entities.config_entity import BaseTracingConfig
+
+
+class TencentConfig(BaseTracingConfig):
+    """
+    Tencent APM tracing config
+    """
+
+    token: str
+    endpoint: str
+    service_name: str
+
+    @field_validator("token")
+    @classmethod
+    def token_validator(cls, v, info: ValidationInfo):
+        if not v or v.strip() == "":
+            raise ValueError("Token cannot be empty")
+        return v
+
+    @field_validator("endpoint")
+    @classmethod
+    def endpoint_validator(cls, v, info: ValidationInfo):
+        return cls.validate_endpoint_url(v, "https://apm.tencentcloudapi.com")
+
+    @field_validator("service_name")
+    @classmethod
+    def service_name_validator(cls, v, info: ValidationInfo):
+        return cls.validate_project_field(v, "dify_app")
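Unlike the Opik config, all three Tencent fields are required; only a blank `service_name` is silently replaced. A sketch of the intended behavior — note that the exact normalization done by `validate_endpoint_url` and `validate_project_field` lives in `BaseTracingConfig` and is assumed here:

```python
from pydantic import ValidationError

from dify_trace_tencent.config import TencentConfig

config = TencentConfig(
    token="secret-token",
    endpoint="https://apm.tencentcloudapi.com",
    service_name="",  # assumed: validate_project_field substitutes "dify_app"
)
assert config.service_name == "dify_app"

try:
    TencentConfig(token="   ", endpoint="https://apm.tencentcloudapi.com", service_name="svc")
except ValidationError as exc:
    assert "Token cannot be empty" in str(exc)
```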
diff --git a/api/core/ops/tencent_trace/entities/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/__init__.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/__init__.py
diff --git a/api/core/ops/tencent_trace/entities/semconv.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/semconv.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/semconv.py
diff --git a/api/core/ops/tencent_trace/entities/tencent_trace_entity.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py
similarity index 100%
rename from api/core/ops/tencent_trace/entities/tencent_trace_entity.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/entities/tencent_trace_entity.py
diff --git a/api/core/rag/datasource/vdb/couchbase/__init__.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed
similarity index 100%
rename from api/core/rag/datasource/vdb/couchbase/__init__.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/py.typed
diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
similarity index 98%
rename from api/core/ops/tencent_trace/span_builder.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
index f79095d966..763a85ffd7 100644
--- a/api/core/ops/tencent_trace/span_builder.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/span_builder.py
@@ -6,8 +6,6 @@ import json
 import logging
 from datetime import datetime
 
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from opentelemetry.trace import Status, StatusCode
 
 from core.ops.entities.trace_entity import (
@@ -16,7 +14,8 @@
     ToolTraceInfo,
     WorkflowTraceInfo,
 )
-from core.ops.tencent_trace.entities.semconv import (
+from core.rag.models.document import Document
+from dify_trace_tencent.entities.semconv import (
     GEN_AI_COMPLETION,
     GEN_AI_FRAMEWORK,
     GEN_AI_IS_ENTRY,
@@ -40,9 +39,10 @@
     TOOL_PARAMETERS,
     GenAISpanKind,
 )
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-from core.ops.tencent_trace.utils import TencentTraceUtils
-from core.rag.models.document import Document
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
+from dify_trace_tencent.utils import TencentTraceUtils
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
similarity index 94%
rename from api/core/ops/tencent_trace/tencent_trace.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
index 84f54d8a5a..a8c480e4a5 100644
--- a/api/core/ops/tencent_trace/tencent_trace.py
+++ b/api/providers/trace/trace-tencent/src/dify_trace_tencent/tencent_trace.py
@@ -1,18 +1,12 @@
-"""
-Tencent APM tracing implementation with separated concerns
-"""
+"""Tencent APM tracing with idempotent client cleanup."""
 
+import inspect
 import logging
 
-from graphon.entities.workflow_node_execution import (
-    WorkflowNodeExecution,
-)
-from graphon.nodes import BuiltinNodeTypes
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
 from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import TencentConfig
 from core.ops.entities.trace_entity import (
     BaseTraceInfo,
     DatasetRetrievalTraceInfo,
@@ -23,12 +17,17 @@
     ToolTraceInfo,
     WorkflowTraceInfo,
 )
-from core.ops.tencent_trace.client import TencentTraceClient
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
-from core.ops.tencent_trace.span_builder import TencentSpanBuilder
-from core.ops.tencent_trace.utils import TencentTraceUtils
 from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from dify_trace_tencent.client import TencentTraceClient
+from dify_trace_tencent.config import TencentConfig
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
+from dify_trace_tencent.span_builder import TencentSpanBuilder
+from dify_trace_tencent.utils import TencentTraceUtils
 from extensions.ext_database import db
+from graphon.entities.workflow_node_execution import (
+    WorkflowNodeExecution,
+)
+from graphon.nodes import BuiltinNodeTypes
 from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
 
 logger = logging.getLogger(__name__)
@@ -38,10 +37,18 @@ class TencentDataTrace(BaseTraceInstance):
     """
     Tencent APM trace implementation with single responsibility principle.
     Acts as a coordinator that delegates specific tasks to specialized classes.
+
+    The instance owns a long-lived ``TencentTraceClient``. Cleanup may happen
+    explicitly in tests or implicitly during garbage collection, so shutdown
+    must be safe to call multiple times.
     """
 
+    trace_client: TencentTraceClient
+    _closed: bool
+
     def __init__(self, tencent_config: TencentConfig):
         super().__init__(tencent_config)
+        self._closed = False
         self.trace_client = TencentTraceClient(
             service_name=tencent_config.service_name,
             endpoint=tencent_config.endpoint,
@@ -513,10 +520,25 @@
         except Exception:
             logger.debug("[Tencent APM] Failed to record message trace duration")
 
-    def __del__(self):
-        """Ensure proper cleanup on garbage collection."""
+    def close(self) -> None:
+        """Synchronously and idempotently shut down the underlying trace client."""
+        if getattr(self, "_closed", False):
+            return
+
+        self._closed = True
+        trace_client = getattr(self, "trace_client", None)
+        if trace_client is None:
+            return
+
         try:
-            if hasattr(self, "trace_client"):
-                self.trace_client.shutdown()
+            shutdown_result = trace_client.shutdown()
+            if inspect.isawaitable(shutdown_result):
+                close_awaitable = getattr(shutdown_result, "close", None)
+                if callable(close_awaitable):
+                    close_awaitable()
         except Exception:
             logger.exception("[Tencent APM] Failed to shutdown trace client during cleanup")
+
+    def __del__(self):
+        """Ensure best-effort cleanup on garbage collection without retrying shutdown."""
+        self.close()
diff --git a/api/core/ops/tencent_trace/utils.py b/api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
similarity index 100%
rename from api/core/ops/tencent_trace/utils.py
rename to api/providers/trace/trace-tencent/src/dify_trace_tencent/utils.py
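The hunk above is the behavioral heart of this file move: shutdown logic leaves `__del__` and moves into an idempotent `close()`, and the `isawaitable` guard exists so that a test double whose `shutdown` is an `AsyncMock` has its returned coroutine closed instead of leaking a "coroutine was never awaited" `RuntimeWarning` at garbage collection. The pattern distilled outside this class:

```python
import inspect
import logging

logger = logging.getLogger(__name__)


class OwnsClient:
    def __init__(self, client) -> None:
        self._closed = False
        self.client = client

    def close(self) -> None:
        if self._closed:  # idempotent: a second call is a no-op
            return
        self._closed = True
        try:
            result = self.client.shutdown()
            if inspect.isawaitable(result):
                # The sync path has nothing to await; close the coroutine so
                # GC does not report it as "never awaited".
                close = getattr(result, "close", None)
                if callable(close):
                    close()
        except Exception:
            logger.exception("failed to shut down client during cleanup")

    def __del__(self) -> None:
        self.close()  # best effort; safe even after an explicit close()
```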
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
similarity index 85%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_client.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
index 870c18e53e..3cd918f408 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_client.py
@@ -5,15 +5,15 @@ from __future__ import annotations
 import sys
 import types
 from types import SimpleNamespace
+from typing import Any, TypedDict, cast
 from unittest.mock import MagicMock
 
 import pytest
+from dify_trace_tencent import client as client_module
+from dify_trace_tencent.client import TencentTraceClient, _get_opentelemetry_sdk_version
+from dify_trace_tencent.entities.tencent_trace_entity import SpanData
 from opentelemetry.sdk.trace import Event
-from opentelemetry.trace import Status, StatusCode
-
-from core.ops.tencent_trace import client as client_module
-from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version
-from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData
+from opentelemetry.trace import SpanContext, Status, StatusCode, TraceFlags
 
 metric_reader_instances: list[DummyMetricReader] = []
 meter_provider_instances: list[DummyMeterProvider] = []
@@ -81,6 +81,16 @@ class DummyJsonMetricExporterNoTemporality:
         self.kwargs = kwargs
 
 
+class PatchedCoreComponents(TypedDict):
+    span_exporter: MagicMock
+    span_processor: MagicMock
+    tracer: MagicMock
+    span: MagicMock
+    tracer_provider: MagicMock
+    logger: MagicMock
+    trace_api: Any
+
+
 def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None:
     """Drop fake metric modules into sys.modules so the client imports resolve."""
 
@@ -119,7 +129,7 @@ def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None:
 
 
 @pytest.fixture(autouse=True)
-def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]:
+def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> PatchedCoreComponents:
     span_exporter = MagicMock(name="span_exporter")
     monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter))
 
@@ -169,6 +179,15 @@ def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]
     }
 
 
+def _make_span_context(trace_id: int = 1, span_id: int = 2) -> SpanContext:
+    return SpanContext(
+        trace_id=trace_id,
+        span_id=span_id,
+        is_remote=False,
+        trace_flags=TraceFlags(TraceFlags.SAMPLED),
+    )
+
+
 def _build_client() -> TencentTraceClient:
     return TencentTraceClient(
         service_name="service",
@@ -209,7 +228,7 @@ def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[st
 
 
 def test_resolve_grpc_target_handles_errors() -> None:
-    assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317)
+    assert TencentTraceClient._resolve_grpc_target(cast(str, 123)) == ("localhost:4317", True, "localhost", 4317)
 
 
 @pytest.mark.parametrize(
@@ -249,7 +268,7 @@ def test_record_methods_skip_when_histogram_missing() -> None:
     client.record_trace_duration(0.5)
 
 
-def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None:
+def test_record_llm_duration_handles_exceptions(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     client.hist_llm_duration = MagicMock(name="hist_llm_duration")
     client.hist_llm_duration.record.side_effect = RuntimeError("boom")
@@ -259,10 +278,11 @@
     logger.debug.assert_called()
 
 
-def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_sets_attributes(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    ctx = _make_span_context(span_id=2)
+    span.get_span_context.return_value = ctx
 
     data = SpanData(
         trace_id=1,
@@ -281,14 +301,15 @@
     span.add_event.assert_called_once()
     span.set_status.assert_called_once()
     span.end.assert_called_once_with(end_time=20)
-    assert client.span_contexts[2] == "ctx"
+    assert client.span_contexts[2] == ctx
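Swapping the old `"ctx"` string sentinel for a real `SpanContext` matters because the client feeds stored contexts back into the OpenTelemetry API, for example as a parent via `NonRecordingSpan`, which expects a genuine context object. A minimal sketch of building a valid sampled context for tests:

```python
from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags

ctx = SpanContext(
    trace_id=1,  # any non-zero 128-bit integer
    span_id=2,  # any non-zero 64-bit integer
    is_remote=False,
    trace_flags=TraceFlags(TraceFlags.SAMPLED),
)

# NonRecordingSpan wraps a context so it can serve as a parent span
# without recording anything itself.
parent = NonRecordingSpan(ctx)
assert parent.get_span_context() is ctx
```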
-def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_uses_parent_context(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
-    client.span_contexts[10] = "existing"
+    existing_context = _make_span_context(span_id=10)
+    client.span_contexts[10] = existing_context
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "child"
+    span.get_span_context.return_value = _make_span_context(span_id=11)
 
     data = SpanData(
         trace_id=1,
@@ -303,14 +324,14 @@
     client._create_and_export_span(data)
 
     trace_api = patch_core_components["trace_api"]
-    trace_api.NonRecordingSpan.assert_called_once_with("existing")
+    trace_api.NonRecordingSpan.assert_called_once_with(existing_context)
     trace_api.set_span_in_context.assert_called_once()
 
 
-def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_exception_logs_error(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    span.get_span_context.return_value = _make_span_context(span_id=2)
 
     client.tracer.start_span.side_effect = RuntimeError("boom")
     client._create_and_export_span(
@@ -386,7 +407,7 @@ def test_get_project_url() -> None:
     assert client.get_project_url() == "https://console.cloud.tencent.com/apm"
 
 
-def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None:
+def test_shutdown_flushes_all_components(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span_processor = patch_core_components["span_processor"]
     tracer_provider = patch_core_components["tracer_provider"]
@@ -402,10 +423,11 @@
     metric_reader.shutdown.assert_called_once()
 
 
-def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None:
+def test_shutdown_logs_when_meter_provider_fails(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     meter_provider = meter_provider_instances[-1]
     meter_provider.shutdown.side_effect = RuntimeError("boom")
+    assert client.metric_reader is not None
     client.metric_reader.shutdown.side_effect = RuntimeError("boom")
 
     client.shutdown()
@@ -434,7 +456,7 @@
     assert client.metric_reader is None
 
 
-def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None:
+def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom")))
 
@@ -455,10 +477,10 @@
     logger.exception.assert_called_once()
 
 
-def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None:
+def test_create_and_export_span_converts_attribute_types(patch_core_components: PatchedCoreComponents) -> None:
     client = _build_client()
     span = patch_core_components["span"]
-    span.get_span_context.return_value = "ctx"
+    span.get_span_context.return_value = _make_span_context(span_id=2)
 
     data = SpanData.model_construct(
         trace_id=1,
@@ -486,7 +508,7 @@ def test_record_llm_duration_converts_attributes() -> None:
     hist_mock = MagicMock(name="hist_llm_duration")
     client.hist_llm_duration = hist_mock
 
-    client.record_llm_duration(0.3, {"foo": object(), "bar": 2})
+    client.record_llm_duration(0.3, cast(dict[str, str], {"foo": object(), "bar": 2}))
     _, attrs = hist_mock.record.call_args.args
     assert isinstance(attrs["foo"], str)
     assert attrs["bar"] == 2
@@ -497,7 +519,7 @@ def test_record_trace_duration_converts_attributes() -> None:
     hist_mock = MagicMock(name="hist_trace_duration")
     client.hist_trace_duration = hist_mock
 
-    client.record_trace_duration(1.0, {"meta": object(), "ok": True})
+    client.record_trace_duration(1.0, cast(dict[str, str], {"meta": object(), "ok": True}))
     _, attrs = hist_mock.record.call_args.args
     assert isinstance(attrs["meta"], str)
     assert attrs["ok"] is True
@@ -513,7 +535,7 @@ def test_record_trace_duration_converts_attributes() -> None:
     ],
 )
 def test_record_methods_handle_exceptions(
-    method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object]
+    method: str, attr_name: str, args: tuple[object, ...], patch_core_components: PatchedCoreComponents
 ) -> None:
     client = _build_client()
     hist_mock = MagicMock(name=attr_name)
@@ -528,35 +550,38 @@ def test_record_methods_handle_exceptions(
 def test_metrics_initializes_grpc_metric_exporter() -> None:
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyGrpcMetricExporter, metric_reader.exporter)
 
-    assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter)
+    assert isinstance(exporter, DummyGrpcMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317"
-    assert metric_reader.exporter.kwargs["insecure"] is False
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
+    assert exporter.kwargs["endpoint"] == "trace.example.com:4317"
+    assert exporter.kwargs["insecure"] is False
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
 
 
 def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf")
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyHttpMetricExporter, metric_reader.exporter)
 
-    assert isinstance(metric_reader.exporter, DummyHttpMetricExporter)
+    assert isinstance(exporter, DummyHttpMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
+    assert exporter.kwargs["endpoint"] == client.endpoint
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
 
 
 def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json")
     client = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyJsonMetricExporter, metric_reader.exporter)
 
-    assert isinstance(metric_reader.exporter, DummyJsonMetricExporter)
+    assert isinstance(exporter, DummyJsonMetricExporter)
     assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000
-    assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint
-    assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token"
-    assert "preferred_temporality" in metric_reader.exporter.kwargs
+    assert exporter.kwargs["endpoint"] == client.endpoint
+    assert cast(dict[str, dict[str, str]], exporter.kwargs)["headers"]["authorization"] == "Bearer token"
+    assert "preferred_temporality" in exporter.kwargs
 
 
 def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -565,9 +590,10 @@ def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkey
     monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality)
     _ = _build_client()
     metric_reader = metric_reader_instances[-1]
+    exporter = cast(DummyJsonMetricExporterNoTemporality, metric_reader.exporter)
 
-    assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality)
-    assert "preferred_temporality" not in metric_reader.exporter.kwargs
+    assert isinstance(exporter, DummyJsonMetricExporterNoTemporality)
+    assert "preferred_temporality" not in exporter.kwargs
 
 
 def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None:
diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
similarity index 89%
rename from api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
index 696f859b6f..e850a801f3 100644
--- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py
+++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_span_builder.py
@@ -1,17 +1,7 @@
 from datetime import datetime
 from unittest.mock import MagicMock, patch
 
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from opentelemetry.trace import StatusCode
-
-from core.ops.entities.trace_entity import (
-    DatasetRetrievalTraceInfo,
-    MessageTraceInfo,
-    ToolTraceInfo,
-    WorkflowTraceInfo,
-)
-from core.ops.tencent_trace.entities.semconv import (
+from dify_trace_tencent.entities.semconv import (
     GEN_AI_IS_ENTRY,
     GEN_AI_IS_STREAMING_REQUEST,
     GEN_AI_MODEL_NAME,
@@ -25,13 +15,23 @@ from core.ops.tencent_trace.entities.semconv import (
     TOOL_PARAMETERS,
     GenAISpanKind,
 )
-from core.ops.tencent_trace.span_builder import TencentSpanBuilder
+from dify_trace_tencent.span_builder import TencentSpanBuilder
+from opentelemetry.trace import StatusCode
+
+from core.ops.entities.trace_entity import (
+    DatasetRetrievalTraceInfo,
+    MessageTraceInfo,
+    ToolTraceInfo,
+    WorkflowTraceInfo,
+)
 from core.rag.models.document import Document
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 
 
 class TestTencentSpanBuilder:
     def test_get_time_nanoseconds(self):
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert:
             mock_convert.return_value = 123456789
             dt = datetime.now()
             result = TencentSpanBuilder._get_time_nanoseconds(dt)
@@ -48,7 +48,7 @@ class TestTencentSpanBuilder:
         trace_info.workflow_run_outputs = {"answer": "world"}
{"conversation_id": "conv_id"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -70,7 +70,7 @@ class TestTencentSpanBuilder: trace_info.workflow_run_outputs = {} trace_info.metadata = {} # No conversation_id - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 1 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") @@ -98,7 +98,7 @@ class TestTencentSpanBuilder: } node_execution.outputs = {"text": "world"} - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -123,7 +123,7 @@ class TestTencentSpanBuilder: "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, } - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 456 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) @@ -142,7 +142,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {"conversation_id": "conv_id"} trace_info.is_streaming_request = True - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -162,7 +162,7 @@ class TestTencentSpanBuilder: trace_info.metadata = {} trace_info.is_streaming_request = False - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 789 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") @@ -182,7 +182,7 @@ class TestTencentSpanBuilder: trace_info.tool_inputs = {"i": 2} trace_info.tool_outputs = "result" - with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: mock_convert_id.return_value = 101 with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1) @@ -204,7 
@@ -204,7 +204,7 @@ class TestTencentSpanBuilder:
         )
         trace_info.documents = [doc]
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 202
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -222,7 +222,7 @@
         trace_info.end_time = datetime.now()
         trace_info.documents = []
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 202
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1)
@@ -264,7 +264,7 @@
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 303
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -286,7 +286,7 @@
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 303
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution)
@@ -307,7 +307,7 @@
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 404
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -329,7 +329,7 @@
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 404
             with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100):
                 span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution)
@@ -350,7 +350,7 @@
         node_execution.created_at = datetime.now()
         node_execution.finished_at = datetime.now()
 
-        with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
+        with patch("dify_trace_tencent.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id:
             mock_convert_id.return_value = 505
"_get_time_nanoseconds", return_value=100): span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py similarity index 86% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py index f67abba807..54524b09ca 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py @@ -1,11 +1,12 @@ +import gc import logging -from unittest.mock import MagicMock, patch +import warnings +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from graphon.entities import WorkflowNodeExecution -from graphon.enums import BuiltinNodeTypes +from dify_trace_tencent.config import TencentConfig +from dify_trace_tencent.tencent_trace import TencentDataTrace -from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -15,7 +16,8 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from core.ops.tencent_trace.tencent_trace import TencentDataTrace +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) @@ -28,19 +30,19 @@ def tencent_config(): @pytest.fixture def mock_trace_client(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceClient") as mock: yield mock @pytest.fixture def mock_span_builder(): - with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentSpanBuilder") as mock: yield mock @pytest.fixture def mock_trace_utils(): - with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock: + with patch("dify_trace_tencent.tencent_trace.TencentTraceUtils") as mock: yield mock @@ -198,9 +200,9 @@ class TestTencentDataTrace: trace_info.workflow_run_id = "run-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.workflow_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace") @@ -230,9 +232,9 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.message_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace") @@ -262,9 +264,9 @@ class TestTencentDataTrace: trace_info.message_id = 
"msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.tool_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace") @@ -294,22 +296,22 @@ class TestTencentDataTrace: trace_info.message_id = "msg-id" with patch( - "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") ): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.dataset_retrieval_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace") def test_suggested_question_trace(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") def test_suggested_question_trace_exception(self, tencent_data_trace): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace.suggested_question_trace(trace_info) mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") @@ -342,7 +344,7 @@ class TestTencentDataTrace: with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) # The exception should be caught by the outer handler since convert_to_span_id is called first mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -351,7 +353,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: tencent_data_trace._process_workflow_nodes(trace_info, 123) mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") @@ -381,7 +383,7 @@ class TestTencentDataTrace: node.id = "n1" mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with 
patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456) assert result is None mock_log.assert_called_once() @@ -403,15 +405,13 @@ class TestTencentDataTrace: mock_executions = [MagicMock()] - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.side_effect = [app, account, tenant_join] - with patch( - "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" - ) as mock_repo: + with patch("dify_trace_tencent.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository") as mock_repo: mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) @@ -423,7 +423,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {} - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -432,14 +432,14 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {"app_id": "app-1"} - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() # Ensure init_app is mocked mock_db.engine = "engine" - with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + with patch("dify_trace_tencent.tencent_trace.Session") as mock_session_ctx: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.return_value = None - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: results = tencent_data_trace._get_workflow_node_executions(trace_info) assert results == [] mock_log.assert_called_once() @@ -449,8 +449,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "tenant-1" trace_info.metadata = {"user_id": "user-1"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")): - with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")): + with patch("dify_trace_tencent.tencent_trace.db") as mock_db: mock_db.init_app = MagicMock() mock_db.engine = MagicMock() @@ -476,8 +476,8 @@ class TestTencentDataTrace: trace_info.tenant_id = "t" trace_info.metadata = {"user_id": "u"} - with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")): - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")): + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: user_id = tencent_data_trace._get_user_id(trace_info) assert user_id == "unknown" mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID") @@ 
-519,7 +519,7 @@ class TestTencentDataTrace: node.process_data = None node.outputs = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_llm_metrics(node) # Should not crash @@ -557,7 +557,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.metadata = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_llm_metrics(trace_info) # Should not crash @@ -609,7 +609,7 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_workflow_trace_duration(trace_info) def test_record_message_trace_duration(self, tencent_data_trace): @@ -631,16 +631,41 @@ class TestTencentDataTrace: trace_info = MagicMock(spec=MessageTraceInfo) trace_info.start_time = None - with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: tencent_data_trace._record_message_trace_duration(trace_info) - def test_del(self, tencent_data_trace): + def test_close(self, tencent_data_trace): client = tencent_data_trace.trace_client - tencent_data_trace.__del__() + tencent_data_trace.close() client.shutdown.assert_called_once() - def test_del_exception(self, tencent_data_trace): + def test_close_is_idempotent(self, tencent_data_trace): + client = tencent_data_trace.trace_client + + tencent_data_trace.close() + tencent_data_trace.close() + + client.shutdown.assert_called_once() + + def test_close_exception(self, tencent_data_trace): tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") - with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.__del__() + with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.close() mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") + + def test_close_handles_async_shutdown_mock(self, tencent_data_trace): + shutdown = AsyncMock() + tencent_data_trace.trace_client.shutdown = shutdown + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + tencent_data_trace.close() + gc.collect() + + shutdown.assert_called_once() + assert not [ + warning + for warning in caught + if issubclass(warning.category, RuntimeWarning) + and "AsyncMockMixin._execute_mock_call" in str(warning.message) + ] diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py similarity index 88% rename from api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py rename to api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py index ef28d18e20..63c6d680d7 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace_utils.py @@ -8,10 +8,9 @@ from datetime import UTC, datetime from 
unittest.mock import patch import pytest +from dify_trace_tencent.utils import TencentTraceUtils from opentelemetry.trace import Link, TraceFlags -from core.ops.tencent_trace.utils import TencentTraceUtils - def test_convert_to_trace_id_with_valid_uuid() -> None: uuid_str = "12345678-1234-5678-1234-567812345678" @@ -20,7 +19,7 @@ def test_convert_to_trace_id_with_valid_uuid() -> None: def test_convert_to_trace_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int uuid4_mock.assert_called_once() @@ -45,7 +44,7 @@ def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None: def test_convert_to_span_id_uses_uuid4_when_none() -> None: expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: span_id = TencentTraceUtils.convert_to_span_id(None, "workflow") assert isinstance(span_id, int) uuid4_mock.assert_called_once() @@ -58,7 +57,7 @@ def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None: def test_generate_span_id_skips_invalid_span_id() -> None: with patch( - "core.ops.tencent_trace.utils.random.getrandbits", + "dify_trace_tencent.utils.random.getrandbits", side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42], ) as bits_mock: assert TencentTraceUtils.generate_span_id() == 42 @@ -75,7 +74,7 @@ def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None: fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) expected = int(fixed.timestamp() * 1e9) - with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock: + with patch("dify_trace_tencent.utils.datetime") as datetime_mock: datetime_mock.now.return_value = fixed assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected datetime_mock.now.assert_called_once() @@ -100,7 +99,7 @@ def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: i @pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None]) def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None: fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") - with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: + with patch("dify_trace_tencent.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type] assert link.context.trace_id == fallback_uuid.int uuid4_mock.assert_called_once() diff --git a/api/providers/trace/trace-weave/pyproject.toml b/api/providers/trace/trace-weave/pyproject.toml new file mode 100644 index 0000000000..ba449f2a93 --- /dev/null +++ b/api/providers/trace/trace-weave/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "dify-trace-weave" +version = "0.0.1" +dependencies = [ + "weave>=0.52.36", +] +description = "Dify ops tracing provider (Weave)." 
+ +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/elasticsearch/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/elasticsearch/__init__.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/__init__.py diff --git a/api/providers/trace/trace-weave/src/dify_trace_weave/config.py b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py new file mode 100644 index 0000000000..5942bd57fe --- /dev/null +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/config.py @@ -0,0 +1,29 @@ +from pydantic import ValidationInfo, field_validator + +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.utils import validate_url + + +class WeaveConfig(BaseTracingConfig): + """ + Model class for Weave tracing config. + """ + + api_key: str + entity: str | None = None + project: str + endpoint: str = "https://trace.wandb.ai" + host: str | None = None + + @field_validator("endpoint") + @classmethod + def endpoint_validator(cls, v, info: ValidationInfo): + # Weave only allows HTTPS for endpoint + return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",)) + + @field_validator("host") + @classmethod + def host_validator(cls, v, info: ValidationInfo): + if v is not None and v.strip() != "": + return validate_url(v, v, allowed_schemes=("https", "http")) + return v diff --git a/api/core/rag/datasource/vdb/hologres/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/hologres/__init__.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/__init__.py diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py similarity index 100% rename from api/core/ops/weave_trace/entities/weave_trace_entity.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/entities/weave_trace_entity.py diff --git a/api/core/rag/datasource/vdb/huawei/__init__.py b/api/providers/trace/trace-weave/src/dify_trace_weave/py.typed similarity index 100% rename from api/core/rag/datasource/vdb/huawei/__init__.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/py.typed diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py similarity index 99% rename from api/core/ops/weave_trace/weave_trace.py rename to api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py index 8d9ba4694d..4292cbf0f1 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/providers/trace/trace-weave/src/dify_trace_weave/weave_trace.py @@ -6,7 +6,6 @@ from typing import Any, cast import wandb import weave -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -18,7 +17,6 @@ from weave.trace_server.trace_server_interface import ( ) from core.ops.base_trace_instance import BaseTraceInstance -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( BaseTraceInfo, DatasetRetrievalTraceInfo, @@ -30,9 +28,11 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import 
DifyCoreRepositoryFactory +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel from extensions.ext_database import db +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) diff --git a/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py new file mode 100644 index 0000000000..377c768198 --- /dev/null +++ b/api/providers/trace/trace-weave/tests/unit_tests/test_config_entity.py @@ -0,0 +1,61 @@ +import pytest +from dify_trace_weave.config import WeaveConfig +from pydantic import ValidationError + + +class TestWeaveConfig: + """Test cases for WeaveConfig""" + + def test_valid_config(self): + """Test valid Weave configuration""" + config = WeaveConfig( + api_key="test_key", + entity="test_entity", + project="test_project", + endpoint="https://custom.wandb.ai", + host="https://custom.host.com", + ) + assert config.api_key == "test_key" + assert config.entity == "test_entity" + assert config.project == "test_project" + assert config.endpoint == "https://custom.wandb.ai" + assert config.host == "https://custom.host.com" + + def test_default_values(self): + """Test default values are set correctly""" + config = WeaveConfig(api_key="key", project="project") + assert config.entity is None + assert config.endpoint == "https://trace.wandb.ai" + assert config.host is None + + def test_missing_required_fields(self): + """Test that required fields are enforced""" + with pytest.raises(ValidationError): + WeaveConfig.model_validate({}) + + with pytest.raises(ValidationError): + WeaveConfig.model_validate({"api_key": "key"}) + + with pytest.raises(ValidationError): + WeaveConfig.model_validate({"project": "project"}) + + def test_endpoint_validation_https_only(self): + """Test endpoint validation only allows HTTPS""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") + + def test_host_validation_optional(self): + """Test host validation is optional but validates when provided""" + config = WeaveConfig(api_key="key", project="project", host=None) + assert config.host is None + + config = WeaveConfig(api_key="key", project="project", host="") + assert config.host == "" + + config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") + assert config.host == "https://valid.host.com" + + def test_host_validation_invalid_scheme(self): + """Test host validation rejects invalid schemes when provided""" + with pytest.raises(ValidationError, match="URL scheme must be one of"): + WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py similarity index 97% rename from api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py rename to api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py index 5014f40afc..6028d0c550 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/providers/trace/trace-weave/tests/unit_tests/weave_trace/test_weave_trace.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" +"""Comprehensive tests for 
dify_trace_weave.weave_trace module.""" from __future__ import annotations @@ -7,10 +7,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from dify_trace_weave.config import WeaveConfig +from dify_trace_weave.entities.weave_trace_entity import WeaveTraceModel +from dify_trace_weave.weave_trace import WeaveDataTrace from weave.trace_server.trace_server_interface import TraceStatus -from core.ops.entities.config_entity import WeaveConfig from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, GenerateNameTraceInfo, @@ -21,8 +22,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.ops.weave_trace.weave_trace import WeaveDataTrace +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -191,14 +191,14 @@ def _make_node(**overrides): @pytest.fixture def mock_wandb(): - with patch("core.ops.weave_trace.weave_trace.wandb") as mock: + with patch("dify_trace_weave.weave_trace.wandb") as mock: mock.login.return_value = True yield mock @pytest.fixture def mock_weave(): - with patch("core.ops.weave_trace.weave_trace.weave") as mock: + with patch("dify_trace_weave.weave_trace.weave") as mock: client = MagicMock() client.entity = "my-entity" client.project = "my-project" @@ -307,7 +307,7 @@ class TestGetProjectUrl: monkeypatch.setattr(trace_instance, "entity", None) monkeypatch.setattr(trace_instance, "project_name", None) # Force an error by making string formatting fail - with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger: + with patch("dify_trace_weave.weave_trace.logger") as mock_logger: # Simulate exception via property original_entity = trace_instance.entity trace_instance.entity = None @@ -594,9 +594,9 @@ class TestWorkflowTrace: mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo - monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) return repo def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch): @@ -703,8 +703,8 @@ class TestWorkflowTrace: def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch): """Raises ValueError when app_id is missing from metadata.""" - monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr("dify_trace_weave.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", MagicMock(engine="engine")) trace_info = _make_workflow_trace_info( message_id=None, @@ -802,7 +802,7 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm 
child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.get", + "dify_trace_weave.weave_trace.db.session.get", lambda model, pk: None, ) @@ -824,7 +824,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -846,7 +846,7 @@ class TestMessageTrace: mock_db = MagicMock() mock_db.session.get.return_value = end_user - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -866,7 +866,7 @@ class TestMessageTrace: """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -884,7 +884,7 @@ class TestMessageTrace: """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock() @@ -899,7 +899,7 @@ class TestMessageTrace: """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() mock_db.session.get.return_value = None - monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + monkeypatch.setattr("dify_trace_weave.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() trace_instance.finish_call = MagicMock()
diff --git a/api/providers/vdb/README.md b/api/providers/vdb/README.md
new file mode 100644
index 0000000000..b5b4197f63
--- /dev/null
+++ b/api/providers/vdb/README.md
@@ -0,0 +1,58 @@
+# VDB providers
+
+This directory contains all VDB providers.
+
+## Architecture
+
+1. **Core** (`api/core/rag/datasource/vdb/`) defines the contracts and loads plugins.
+2. **Each provider** (`api/providers/vdb/<name>/`) implements those contracts and registers an entry point.
+3. At runtime, **`importlib.metadata.entry_points`** resolves the backend name (e.g. `pgvector`) to a factory class. The registry caches loaded classes (see `vector_backend_registry.py`).
+
+### Interfaces
+
+| Piece | Role |
+|--------|----------|
+| `AbstractVectorFactory` | You subclass this. Implement `init_vector(dataset, attributes, embeddings) -> BaseVector`. Optionally use `gen_index_struct_dict()` for new datasets. |
+| `BaseVector` | Your store class subclasses this: `create`, `add_texts`, `search_by_vector`, `delete`, etc. |
+| `VectorType` | `StrEnum` of supported backend **string ids**. Add a member when you introduce a new backend that should be selectable like existing ones. |
+| Discovery | Loads `dify.vector_backends` entry points; `get_vector_factory_class(vector_type)` caches the loaded classes. |
+
+The high-level caller is `Vector` in `vector_factory.py`: it reads the configured or dataset-specific vector type, calls `get_vector_factory_class`, instantiates the factory, and uses the returned `BaseVector` implementation.
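+
+A minimal sketch of that lookup path, assuming only the names this README already uses (`get_vector_factory_class` and the `dify.vector_backends` group); the cache variable and error handling here are illustrative, and the real implementation lives in `vector_backend_registry.py`:
+
+```python
+from importlib.metadata import entry_points
+
+_factory_cache: dict[str, type] = {}  # illustrative process-wide cache
+
+
+def get_vector_factory_class(vector_type: str) -> type:
+    """Resolve a backend id such as "pgvector" to its factory class."""
+    if vector_type in _factory_cache:
+        return _factory_cache[vector_type]
+    # Scan installed distributions for an entry point whose name matches vector_type.
+    matches = entry_points().select(group="dify.vector_backends", name=vector_type)
+    try:
+        ep = next(iter(matches))
+    except StopIteration:
+        raise ValueError(f"no vector backend registered for {vector_type!r}") from None
+    factory_class = ep.load()  # the entry point must point at the class, not an instance
+    _factory_cache[vector_type] = factory_class
+    return factory_class
+```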
+
+### Entry point name must match the vector type string
+
+Entry points are registered under the group **`dify.vector_backends`**. The **entry point name** (left-hand side) must be exactly the string used as `vector_type` everywhere else—typically the **`VectorType` enum value** (e.g. `PGVECTOR = "pgvector"` → entry point name `pgvector`; `TIDB_ON_QDRANT = "tidb_on_qdrant"` → `tidb_on_qdrant`).
+
+In `pyproject.toml`:
+
+```toml
+[project.entry-points."dify.vector_backends"]
+pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory"
+```
+
+The value is **`module:attribute`**: an importable module path and the class implementing `AbstractVectorFactory`.
+
+### How registration works
+
+1. On first use, `get_vector_factory_class(vector_type)` looks up `vector_type` in a process cache.
+2. If missing, it scans **`entry_points().select(group="dify.vector_backends")`** for an entry whose **`name` equals `vector_type`**.
+3. It loads that entry (`ep.load()`), which must return the **factory class** (not an instance).
+4. There is an optional internal map `_BUILTIN_VECTOR_FACTORY_TARGETS` for non-distribution builtins; **normal VDB plugins use entry points only**.
+
+After you change a provider’s `pyproject.toml` (entry points or dependencies), run **`uv sync`** in `api/` so the installed environment’s dist-info matches the project metadata.
+
+### Package layout (VDB)
+
+Each backend usually follows:
+
+- `api/providers/vdb/<name>/pyproject.toml` — project name `dify-vdb-<name>`, dependencies, entry points.
+- `api/providers/vdb/<name>/src/dify_vdb_<name>/` — implementation (e.g. `PGVector`, `PGVectorFactory`).
+
+See `vdb/pgvector/` as a reference implementation.
+
+### Wiring a new backend into the API workspace
+
+The API uses a **uv workspace** (`api/pyproject.toml`); the three touch points are listed here, with a sketch after the list:
+
+1. **`[tool.uv.workspace]`** — `members = ["providers/vdb/*"]` already includes every subdirectory under `vdb/`; new folders there are workspace members.
+2. **`[tool.uv.sources]`** — add a line for your package: `dify-vdb-mine = { workspace = true }`.
+3. **`[project.optional-dependencies]`** — add a group such as `vdb-mine = ["dify-vdb-mine"]`, and list `dify-vdb-mine` under `vdb-all` if it should install with the default bundle.
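+
+A hypothetical sketch of those three sections for a backend named `mine` (the rest of `api/pyproject.toml` is omitted, and the real member lists are longer in practice):
+
+```toml
+[tool.uv.workspace]
+members = ["providers/vdb/*"]  # already matches providers/vdb/vdb-mine
+
+[tool.uv.sources]
+dify-vdb-mine = { workspace = true }
+
+[project.optional-dependencies]
+vdb-mine = ["dify-vdb-mine"]
+vdb-all = ["dify-vdb-mine"]  # append to the existing vdb-all list
+```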
\ No newline at end of file diff --git a/api/providers/vdb/conftest.py b/api/providers/vdb/conftest.py new file mode 100644 index 0000000000..c4b1cdef29 --- /dev/null +++ b/api/providers/vdb/conftest.py @@ -0,0 +1,22 @@ +from unittest.mock import MagicMock + +import pytest + +from extensions import ext_redis + + +@pytest.fixture(autouse=True) +def _init_mock_redis(): + """Ensure redis_client has a backing client so __getattr__ never raises.""" + if ext_redis.redis_client._client is None: + ext_redis.redis_client.initialize(MagicMock()) + + +@pytest.fixture +def setup_mock_redis(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(ext_redis.redis_client, "get", MagicMock(return_value=None)) + monkeypatch.setattr(ext_redis.redis_client, "set", MagicMock(return_value=None)) + mock_redis_lock = MagicMock() + mock_redis_lock.__enter__ = MagicMock() + mock_redis_lock.__exit__ = MagicMock() + monkeypatch.setattr(ext_redis.redis_client, "lock", mock_redis_lock) diff --git a/api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml b/api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml new file mode 100644 index 0000000000..bbc0e06ffa --- /dev/null +++ b/api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "dify-vdb-alibabacloud-mysql" +version = "0.0.1" +dependencies = [ + "mysql-connector-python>=9.3.0", +] +description = "Dify vector store backend (dify-vdb-alibabacloud-mysql)." + +[project.entry-points."dify.vector_backends"] +alibabacloud_mysql = "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector:AlibabaCloudMySQLVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/iris/__init__.py b/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/iris/__init__.py rename to api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/__init__.py diff --git a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py b/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py rename to api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py index 6e76827a42..37ffd11063 100644 --- a/api/core/rag/datasource/vdb/alibabacloud_mysql/alibabacloud_mysql_vector.py +++ b/api/providers/vdb/vdb-alibabacloud-mysql/src/dify_vdb_alibabacloud_mysql/alibabacloud_mysql_vector.py @@ -35,7 +35,7 @@ class AlibabaCloudMySQLVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values.get("host"): raise ValueError("config ALIBABACLOUD_MYSQL_HOST is required") if not values.get("port"): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py similarity index 94% rename from api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py rename to api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py index e063a49f22..a907f918c3 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_factory.py +++ 
b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +import dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module import pytest - -import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module -from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory +from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory def test_validate_distance_function_accepts_supported_values(): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py similarity index 87% rename from api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py rename to api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py index 8ccd739e64..54eeb78ca9 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/alibabacloud_mysql/test_alibabacloud_mysql_vector.py +++ b/api/providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py @@ -3,11 +3,11 @@ import unittest from unittest.mock import MagicMock, patch import pytest - -from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import ( +from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import ( AlibabaCloudMySQLVector, AlibabaCloudMySQLVectorConfig, ) + from core.rag.models.document import Document try: @@ -49,9 +49,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): # Sample embeddings self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_init(self, mock_pool_class): """Test AlibabaCloudMySQLVector initialization.""" # Mock the connection pool @@ -76,10 +74,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert alibabacloud_mysql_vector.distance_function == "cosine" assert alibabacloud_mysql_vector.pool is not None - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) - @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") def test_create_collection(self, mock_redis, mock_pool_class): """Test collection creation.""" # Mock Redis operations @@ -110,9 +106,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes mock_redis.set.assert_called_once() - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_vector_support_check_success(self, mock_pool_class): """Test successful vector support check.""" # Mock the connection pool @@ -129,9 +123,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): 
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config) assert vector_store is not None - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_vector_support_check_failure(self, mock_pool_class): """Test vector support check failure.""" # Mock the connection pool @@ -149,9 +141,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "RDS MySQL Vector functions are not available" in str(context.value) - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_vector_support_check_function_error(self, mock_pool_class): """Test vector support check with function not found error.""" # Mock the connection pool @@ -170,10 +160,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "RDS MySQL Vector functions are not available" in str(context.value) - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) - @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client") def test_create_documents(self, mock_redis, mock_pool_class): """Test creating documents with embeddings.""" # Setup mocks @@ -186,9 +174,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "doc1" in result assert "doc2" in result - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_add_texts(self, mock_pool_class): """Test adding texts to the vector store.""" # Mock the connection pool @@ -207,9 +193,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert len(result) == 2 mock_cursor.executemany.assert_called_once() - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_text_exists(self, mock_pool_class): """Test checking if text exists.""" # Mock the connection pool @@ -236,9 +220,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "SELECT id FROM" in last_call[0][0] assert last_call[0][1] == ("doc1",) - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_text_not_exists(self, mock_pool_class): """Test checking if text does not exist.""" # Mock the connection pool @@ -260,9 +242,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert not exists - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def 
test_get_by_ids(self, mock_pool_class): """Test getting documents by IDs.""" # Mock the connection pool @@ -288,9 +268,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert docs[0].page_content == "Test document 1" assert docs[1].page_content == "Test document 2" - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_get_by_ids_empty_list(self, mock_pool_class): """Test getting documents with empty ID list.""" # Mock the connection pool @@ -308,9 +286,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert len(docs) == 0 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_by_ids(self, mock_pool_class): """Test deleting documents by IDs.""" # Mock the connection pool @@ -334,9 +310,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "DELETE FROM" in delete_call[0][0] assert delete_call[0][1] == ["doc1", "doc2"] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_by_ids_empty_list(self, mock_pool_class): """Test deleting with empty ID list.""" # Mock the connection pool @@ -357,9 +331,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): delete_calls = [call for call in execute_calls if "DELETE" in str(call)] assert len(delete_calls) == 0 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_by_ids_table_not_exists(self, mock_pool_class): """Test deleting when table doesn't exist.""" # Mock the connection pool @@ -384,9 +356,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): # Should not raise an exception vector_store.delete_by_ids(["doc1"]) - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_by_metadata_field(self, mock_pool_class): """Test deleting documents by metadata field.""" # Mock the connection pool @@ -410,9 +380,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0] assert delete_call[0][1] == ("$.document_id", "dataset1") - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_cosine(self, mock_pool_class): """Test vector search with cosine distance.""" # Mock the connection pool @@ -437,9 +405,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9 assert docs[0].metadata["distance"] == 0.1 - @patch( - 
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_euclidean(self, mock_pool_class): """Test vector search with euclidean distance.""" config = AlibabaCloudMySQLVectorConfig( @@ -472,9 +438,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert len(docs) == 1 assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_with_filter(self, mock_pool_class): """Test vector search with document ID filter.""" # Mock the connection pool @@ -499,9 +463,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): search_call = search_calls[0] assert "WHERE JSON_UNQUOTE" in search_call[0][0] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_with_score_threshold(self, mock_pool_class): """Test vector search with score threshold.""" # Mock the connection pool @@ -536,9 +498,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].page_content == "High similarity document" - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_vector_invalid_top_k(self, mock_pool_class): """Test vector search with invalid top_k.""" # Mock the connection pool @@ -560,9 +520,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): with pytest.raises(ValueError): vector_store.search_by_vector(query_vector, top_k="invalid") - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_full_text(self, mock_pool_class): """Test full-text search.""" # Mock the connection pool @@ -591,9 +549,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): assert docs[0].page_content == "This document contains machine learning content" assert docs[0].metadata["score"] == 1.5 - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_full_text_with_filter(self, mock_pool_class): """Test full-text search with document ID filter.""" # Mock the connection pool @@ -617,9 +573,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): search_call = search_calls[0] assert "AND JSON_UNQUOTE" in search_call[0][0] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_search_by_full_text_invalid_top_k(self, mock_pool_class): 
"""Test full-text search with invalid top_k.""" # Mock the connection pool @@ -640,9 +594,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): with pytest.raises(ValueError): vector_store.search_by_full_text("test", top_k="invalid") - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_delete_collection(self, mock_pool_class): """Test deleting the entire collection.""" # Mock the connection pool @@ -665,9 +617,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase): drop_call = drop_calls[0] assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0] - @patch( - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool" - ) + @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool") def test_unsupported_distance_function(self, mock_pool_class): """Test that Pydantic validation rejects unsupported distance functions.""" # Test that creating config with unsupported distance function raises ValidationError diff --git a/api/providers/vdb/vdb-analyticdb/pyproject.toml b/api/providers/vdb/vdb-analyticdb/pyproject.toml new file mode 100644 index 0000000000..af5def3061 --- /dev/null +++ b/api/providers/vdb/vdb-analyticdb/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "dify-vdb-analyticdb" +version = "0.0.1" +dependencies = [ + "alibabacloud_gpdb20160503~=5.2.0", + "alibabacloud_tea_openapi~=0.4.3", + "clickhouse-connect~=0.15.0", +] +description = "Dify vector store backend (dify-vdb-analyticdb)." + +[project.entry-points."dify.vector_backends"] +analyticdb = "dify_vdb_analyticdb.analyticdb_vector:AnalyticdbVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/lindorm/__init__.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/lindorm/__init__.py rename to api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/__init__.py diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py similarity index 95% rename from api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py rename to api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py index 79cc5f0344..e56bb74ba3 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector.py @@ -2,16 +2,16 @@ import json from typing import Any from configs import dify_config -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( - AnalyticdbVectorOpenAPI, - AnalyticdbVectorOpenAPIConfig, -) -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from dify_vdb_analyticdb.analyticdb_vector_openapi import ( + AnalyticdbVectorOpenAPI, + AnalyticdbVectorOpenAPIConfig, +) +from dify_vdb_analyticdb.analyticdb_vector_sql 
import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py similarity index 99% rename from api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py rename to api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py index 726ee8c050..f13d9c0817 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_openapi.py @@ -34,7 +34,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["access_key_id"]: raise ValueError("config ANALYTICDB_KEY_ID is required") if not values["access_key_secret"]: diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py similarity index 98% rename from api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py rename to api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py index 41c33a3ab1..11398efb58 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/providers/vdb/vdb-analyticdb/src/dify_vdb_analyticdb/analyticdb_vector_sql.py @@ -1,6 +1,6 @@ import json import uuid -from collections.abc import Iterator +from collections.abc import Generator from contextlib import contextmanager from typing import Any @@ -24,7 +24,7 @@ class AnalyticdbVectorBySqlConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config ANALYTICDB_HOST is required") if not values["port"]: @@ -75,7 +75,7 @@ class AnalyticdbVectorBySql: ) @contextmanager - def _get_cursor(self) -> Iterator[Any]: + def _get_cursor(self) -> Generator[Any, None, None]: assert self.pool is not None, "Connection pool is not initialized" conn = self.pool.getconn() cur = conn.cursor() diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py similarity index 79% rename from api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py rename to api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py index 0981523809..2bb413dcc1 100644 --- a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py +++ b/api/providers/vdb/vdb-analyticdb/tests/integration_tests/test_analyticdb.py @@ -1,9 +1,8 @@ -from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest +from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector +from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from
core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest class AnalyticdbVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector.py similarity index 93% rename from api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py rename to api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector.py index d4fa4b3e8e..d1d471761d 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector.py +++ b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector.py @@ -1,12 +1,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +import dify_vdb_analyticdb.analyticdb_vector as analyticdb_module import pytest +from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory +from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig +from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig -import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module -from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig from core.rag.models.document import Document diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_openapi.py similarity index 98% rename from api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py rename to api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_openapi.py index 4f8653a926..d2d735ae3e 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_openapi.py +++ b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_openapi.py @@ -4,13 +4,13 @@ import types from types import SimpleNamespace from unittest.mock import MagicMock +import dify_vdb_analyticdb.analyticdb_vector_openapi as openapi_module import pytest - -import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import ( +from dify_vdb_analyticdb.analyticdb_vector_openapi import ( AnalyticdbVectorOpenAPI, AnalyticdbVectorOpenAPIConfig, ) + from core.rag.models.document import Document diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_sql.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py rename to api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_sql.py index f798ef8bd1..49a2ae72d0 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/analyticdb/test_analyticdb_vector_sql.py +++ b/api/providers/vdb/vdb-analyticdb/tests/unit_tests/test_analyticdb_vector_sql.py @@ -2,14 +2,14 @@ from contextlib import contextmanager from types import SimpleNamespace from unittest.mock import MagicMock +import dify_vdb_analyticdb.analyticdb_vector_sql as sql_module import psycopg2.errors 
import pytest - -import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module -from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import ( +from dify_vdb_analyticdb.analyticdb_vector_sql import ( AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig, ) + from core.rag.models.document import Document diff --git a/api/providers/vdb/vdb-baidu/pyproject.toml b/api/providers/vdb/vdb-baidu/pyproject.toml new file mode 100644 index 0000000000..bacff08793 --- /dev/null +++ b/api/providers/vdb/vdb-baidu/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "dify-vdb-baidu" +version = "0.0.1" +dependencies = [ + "pymochow==2.4.0", +] +description = "Dify vector store backend (dify-vdb-baidu)." + +[project.entry-points."dify.vector_backends"] +baidu = "dify_vdb_baidu.baidu_vector:BaiduVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/matrixone/__init__.py b/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/matrixone/__init__.py rename to api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/__init__.py diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/baidu/baidu_vector.py rename to api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py index 99ab0d82f2..bdd5a42c87 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/providers/vdb/vdb-baidu/src/dify_vdb_baidu/baidu_vector.py @@ -59,7 +59,7 @@ class BaiduConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["endpoint"]: raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required") if not values["account"]: diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/providers/vdb/vdb-baidu/tests/integration_tests/conftest.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/baiduvectordb.py rename to api/providers/vdb/vdb-baidu/tests/integration_tests/conftest.py diff --git a/api/tests/integration_tests/vdb/baidu/test_baidu.py b/api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py similarity index 73% rename from api/tests/integration_tests/vdb/baidu/test_baidu.py rename to api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py index 716f88af67..2c1d0e3554 100644 --- a/api/tests/integration_tests/vdb/baidu/test_baidu.py +++ b/api/providers/vdb/vdb-baidu/tests/integration_tests/test_baidu.py @@ -1,10 +1,6 @@ -from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.baiduvectordb", -) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text class BaiduVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py b/api/providers/vdb/vdb-baidu/tests/unit_tests/test_baidu_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py rename to api/providers/vdb/vdb-baidu/tests/unit_tests/test_baidu_vector.py index 
487d021697..851c09f47a 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/baidu/test_baidu_vector.py +++ b/api/providers/vdb/vdb-baidu/tests/unit_tests/test_baidu_vector.py @@ -124,7 +124,7 @@ def _build_fake_pymochow_modules(): def baidu_module(monkeypatch): for name, module in _build_fake_pymochow_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.baidu.baidu_vector as module + import dify_vdb_baidu.baidu_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-chroma/pyproject.toml b/api/providers/vdb/vdb-chroma/pyproject.toml new file mode 100644 index 0000000000..b37ee2a588 --- /dev/null +++ b/api/providers/vdb/vdb-chroma/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "dify-vdb-chroma" +version = "0.0.1" +dependencies = [ + "chromadb==0.5.20", +] +description = "Dify vector store backend (dify-vdb-chroma)." + +[project.entry-points."dify.vector_backends"] +chroma = "dify_vdb_chroma.chroma_vector:ChromaVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/milvus/__init__.py b/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/milvus/__init__.py rename to api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/__init__.py diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/chroma/chroma_vector.py rename to api/providers/vdb/vdb-chroma/src/dify_vdb_chroma/chroma_vector.py diff --git a/api/tests/integration_tests/vdb/chroma/test_chroma.py b/api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py similarity index 80% rename from api/tests/integration_tests/vdb/chroma/test_chroma.py rename to api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py index 52beba9979..87c259f3d0 100644 --- a/api/tests/integration_tests/vdb/chroma/test_chroma.py +++ b/api/providers/vdb/vdb-chroma/tests/integration_tests/test_chroma.py @@ -1,13 +1,11 @@ import chromadb +from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector -from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector -from tests.integration_tests.vdb.test_vector_store import ( +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_text, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class ChromaVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py b/api/providers/vdb/vdb-chroma/tests/unit_tests/test_chroma_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py rename to api/providers/vdb/vdb-chroma/tests/unit_tests/test_chroma_vector.py index 44427b7d87..b209c9df96 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/chroma/test_chroma_vector.py +++ b/api/providers/vdb/vdb-chroma/tests/unit_tests/test_chroma_vector.py @@ -47,7 +47,7 @@ def _build_fake_chroma_modules(): def chroma_module(monkeypatch): fake_chroma = _build_fake_chroma_modules() monkeypatch.setitem(sys.modules, "chromadb", fake_chroma) - import core.rag.datasource.vdb.chroma.chroma_vector as module + import dify_vdb_chroma.chroma_vector as module return importlib.reload(module) diff --git a/api/core/rag/datasource/vdb/clickzetta/README.md 
b/api/providers/vdb/vdb-clickzetta/README.md similarity index 99% rename from api/core/rag/datasource/vdb/clickzetta/README.md rename to api/providers/vdb/vdb-clickzetta/README.md index 969d4e40a0..faa76707ce 100644 --- a/api/core/rag/datasource/vdb/clickzetta/README.md +++ b/api/providers/vdb/vdb-clickzetta/README.md @@ -198,4 +198,4 @@ Clickzetta supports advanced full-text search with multiple analyzers: - [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search) - [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index) -- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference) +- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference) \ No newline at end of file diff --git a/api/providers/vdb/vdb-clickzetta/pyproject.toml b/api/providers/vdb/vdb-clickzetta/pyproject.toml new file mode 100644 index 0000000000..aea94fdb2a --- /dev/null +++ b/api/providers/vdb/vdb-clickzetta/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-clickzetta" +version = "0.0.1" + +dependencies = [ + "clickzetta-connector-python>=0.8.102", +] +description = "Dify vector store backend (dify-vdb-clickzetta)." + +[project.entry-points."dify.vector_backends"] +clickzetta = "dify_vdb_clickzetta.clickzetta_vector:ClickzettaVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/clickzetta/__init__.py b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/clickzetta/__init__.py rename to api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/__init__.py diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py rename to api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py index a4dddc68f0..72b8c5e9eb 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/providers/vdb/vdb-clickzetta/src/dify_vdb_clickzetta/clickzetta_vector.py @@ -51,7 +51,7 @@ class ClickzettaConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): """ Validate the configuration values. 
""" diff --git a/api/tests/integration_tests/vdb/clickzetta/README.md b/api/providers/vdb/vdb-clickzetta/tests/README.md similarity index 100% rename from api/tests/integration_tests/vdb/clickzetta/README.md rename to api/providers/vdb/vdb-clickzetta/tests/README.md diff --git a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py b/api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_clickzetta.py similarity index 92% rename from api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py rename to api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_clickzetta.py index 21de8be6e3..1c6819f9f1 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py +++ b/api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_clickzetta.py @@ -2,10 +2,10 @@ import contextlib import os import pytest +from dify_vdb_clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector -from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis class TestClickzettaVector(AbstractVectorTest): @@ -14,9 +14,8 @@ class TestClickzettaVector(AbstractVectorTest): """ @pytest.fixture - def vector_store(self): + def vector_store(self, setup_mock_redis): """Create a Clickzetta vector store instance for testing.""" - # Skip test if Clickzetta credentials are not configured if not os.getenv("CLICKZETTA_USERNAME"): pytest.skip("CLICKZETTA_USERNAME is not configured") if not os.getenv("CLICKZETTA_PASSWORD"): @@ -32,21 +31,19 @@ class TestClickzettaVector(AbstractVectorTest): workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"), vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"), schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"), - batch_size=10, # Small batch size for testing + batch_size=10, enable_inverted_index=True, analyzer_type="chinese", analyzer_mode="smart", vector_distance_function="cosine_distance", ) - with setup_mock_redis(): - vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config) + vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config) - yield vector + yield vector - # Cleanup: delete the test collection - with contextlib.suppress(Exception): - vector.delete() + with contextlib.suppress(Exception): + vector.delete() def test_clickzetta_vector_basic_operations(self, vector_store): """Test basic CRUD operations on Clickzetta vector store.""" diff --git a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py b/api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_docker_integration.py similarity index 55% rename from api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py rename to api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_docker_integration.py index 60e3f30f26..a5d32f5e81 100644 --- a/api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py +++ b/api/providers/vdb/vdb-clickzetta/tests/integration_tests/test_docker_integration.py @@ -3,16 +3,19 @@ Test Clickzetta integration in Docker environment """ +import logging import os import time import httpx from clickzetta import connect +logger = logging.getLogger(__name__) + def test_clickzetta_connection(): """Test direct connection to Clickzetta""" 
- print("=== Testing direct Clickzetta connection ===") + logger.info("=== Testing direct Clickzetta connection ===") try: conn = connect( username=os.getenv("CLICKZETTA_USERNAME", "test_user"), @@ -25,100 +28,93 @@ def test_clickzetta_connection(): ) with conn.cursor() as cursor: - # Test basic connectivity cursor.execute("SELECT 1 as test") result = cursor.fetchone() - print(f"✓ Connection test: {result}") + logger.info("✓ Connection test: %s", result) - # Check if our test table exists cursor.execute("SHOW TABLES IN dify") tables = cursor.fetchall() - print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}") + logger.info("✓ Existing tables: %s", [t[1] for t in tables if t[0] == "dify"]) - # Check if test collection exists test_collection = "collection_test_dataset" if test_collection in [t[1] for t in tables if t[0] == "dify"]: cursor.execute(f"DESCRIBE dify.{test_collection}") columns = cursor.fetchall() - print(f"✓ Table structure for {test_collection}:") + logger.info("✓ Table structure for %s:", test_collection) for col in columns: - print(f" - {col[0]}: {col[1]}") + logger.info(" - %s: %s", col[0], col[1]) - # Check for indexes cursor.execute(f"SHOW INDEXES IN dify.{test_collection}") indexes = cursor.fetchall() - print(f"✓ Indexes on {test_collection}:") + logger.info("✓ Indexes on %s:", test_collection) for idx in indexes: - print(f" - {idx}") + logger.info(" - %s", idx) return True - except Exception as e: - print(f"✗ Connection test failed: {e}") + except Exception: + logger.exception("✗ Connection test failed") return False def test_dify_api(): """Test Dify API with Clickzetta backend""" - print("\n=== Testing Dify API ===") + logger.info("\n=== Testing Dify API ===") base_url = "http://localhost:5001" - # Wait for API to be ready max_retries = 30 for i in range(max_retries): try: response = httpx.get(f"{base_url}/console/api/health") if response.status_code == 200: - print("✓ Dify API is ready") + logger.info("✓ Dify API is ready") break except: if i == max_retries - 1: - print("✗ Dify API is not responding") + logger.exception("✗ Dify API is not responding") return False time.sleep(2) - # Check vector store configuration try: - # This is a simplified check - in production, you'd use proper auth - print("✓ Dify is configured to use Clickzetta as vector store") + logger.info("✓ Dify is configured to use Clickzetta as vector store") return True - except Exception as e: - print(f"✗ API test failed: {e}") + except Exception: + logger.exception("✗ API test failed") return False def verify_table_structure(): """Verify the table structure meets Dify requirements""" - print("\n=== Verifying Table Structure ===") + logger.info("\n=== Verifying Table Structure ===") expected_columns = { "id": "VARCHAR", "page_content": "VARCHAR", - "metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta + "metadata": "VARCHAR", "vector": "ARRAY", } expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"] - print("✓ Expected table structure:") + logger.info("✓ Expected table structure:") for col, dtype in expected_columns.items(): - print(f" - {col}: {dtype}") + logger.info(" - %s: %s", col, dtype) - print("\n✓ Required metadata fields:") + logger.info("\n✓ Required metadata fields:") for field in expected_metadata_fields: - print(f" - {field}") + logger.info(" - %s", field) - print("\n✓ Index requirements:") - print(" - Vector index (HNSW) on 'vector' column") - print(" - Full-text index on 'page_content' (optional)") - print(" - Functional index on 
metadata->>'$.doc_id' (recommended)") - print(" - Functional index on metadata->>'$.document_id' (recommended)") + logger.info("\n✓ Index requirements:") + logger.info(" - Vector index (HNSW) on 'vector' column") + logger.info(" - Full-text index on 'page_content' (optional)") + logger.info(" - Functional index on metadata->>'$.doc_id' (recommended)") + logger.info(" - Functional index on metadata->>'$.document_id' (recommended)") return True def main(): """Run all tests""" - print("Starting Clickzetta integration tests for Dify Docker\n") + logger.info("Starting Clickzetta integration tests for Dify Docker\n") tests = [ ("Direct Clickzetta Connection", test_clickzetta_connection), @@ -131,33 +127,34 @@ def main(): try: success = test_func() results.append((test_name, success)) - except Exception as e: - print(f"\n✗ {test_name} crashed: {e}") + except Exception: + logger.exception("\n✗ %s crashed", test_name) results.append((test_name, False)) - # Summary - print("\n" + "=" * 50) - print("Test Summary:") - print("=" * 50) + logger.info("\n%s", "=" * 50) + logger.info("Test Summary:") + logger.info("=" * 50) passed = sum(1 for _, success in results if success) total = len(results) for test_name, success in results: status = "✅ PASSED" if success else "❌ FAILED" - print(f"{test_name}: {status}") + logger.info("%s: %s", test_name, status) - print(f"\nTotal: {passed}/{total} tests passed") + logger.info("\nTotal: %s/%s tests passed", passed, total) if passed == total: - print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.") - print("\nNext steps:") - print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d") - print("2. Access Dify at http://localhost:3000") - print("3. Create a dataset and test vector storage with Clickzetta") + logger.info("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.") + logger.info("\nNext steps:") + logger.info( + "1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d" + ) + logger.info("2. Access Dify at http://localhost:3000") + logger.info("3. Create a dataset and test vector storage with Clickzetta") return 0 else: - print("\n⚠️ Some tests failed. Please check the errors above.") + logger.error("\n⚠️ Some tests failed. 
Please check the errors above.") return 1 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py b/api/providers/vdb/vdb-clickzetta/tests/unit_tests/test_clickzetta_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py rename to api/providers/vdb/vdb-clickzetta/tests/unit_tests/test_clickzetta_vector.py index 0ce5c04dd6..a7473f1b91 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/clickzetta/test_clickzetta_vector.py +++ b/api/providers/vdb/vdb-clickzetta/tests/unit_tests/test_clickzetta_vector.py @@ -47,7 +47,7 @@ def _build_fake_clickzetta_module(): @pytest.fixture def clickzetta_module(monkeypatch): monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module()) - import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module + import dify_vdb_clickzetta.clickzetta_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-couchbase/pyproject.toml b/api/providers/vdb/vdb-couchbase/pyproject.toml new file mode 100644 index 0000000000..6bc348b2eb --- /dev/null +++ b/api/providers/vdb/vdb-couchbase/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-couchbase" +version = "0.0.1" + +dependencies = [ + "couchbase~=4.6.0", +] +description = "Dify vector store backend (dify-vdb-couchbase)." + +[project.entry-points."dify.vector_backends"] +couchbase = "dify_vdb_couchbase.couchbase_vector:CouchbaseVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/myscale/__init__.py b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/myscale/__init__.py rename to api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/__init__.py diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py similarity index 98% rename from api/core/rag/datasource/vdb/couchbase/couchbase_vector.py rename to api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py index 9a4a65cf6f..bab176e285 100644 --- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py +++ b/api/providers/vdb/vdb-couchbase/src/dify_vdb_couchbase/couchbase_vector.py @@ -36,7 +36,7 @@ class CouchbaseConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values.get("connection_string"): raise ValueError("config COUCHBASE_CONNECTION_STRING is required") if not values.get("user"): @@ -59,7 +59,7 @@ class CouchbaseVector(BaseVector): auth = PasswordAuthenticator(config.user, config.password) options = ClusterOptions(auth) - self._cluster = Cluster(config.connection_string, options) + self._cluster = Cluster(config.connection_string, options) # pyright: ignore[reportArgumentType] self._bucket = self._cluster.bucket(config.bucket_name) self._scope = self._bucket.scope(config.scope_name) self._bucket_name = config.bucket_name @@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector): def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 4) try: - CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) + CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query)) # pyright: ignore[reportCallIssue] search_iter = self._scope.search( self._collection_name + 
"_search", CBrequest, SearchOptions(limit=top_k, fields=["*"]) ) diff --git a/api/tests/integration_tests/vdb/couchbase/test_couchbase.py b/api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py similarity index 80% rename from api/tests/integration_tests/vdb/couchbase/test_couchbase.py rename to api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py index 0371f04233..918dae328f 100644 --- a/api/tests/integration_tests/vdb/couchbase/test_couchbase.py +++ b/api/providers/vdb/vdb-couchbase/tests/integration_tests/test_couchbase.py @@ -1,12 +1,14 @@ +import logging import subprocess import time -from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +logger = logging.getLogger(__name__) def wait_for_healthy_container(service_name="couchbase-server", timeout=300): @@ -16,10 +18,10 @@ def wait_for_healthy_container(service_name="couchbase-server", timeout=300): ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True ) if result.stdout.strip() == "healthy": - print(f"{service_name} is healthy!") + logger.info("%s is healthy!", service_name) return True else: - print(f"Waiting for {service_name} to be healthy...") + logger.info("Waiting for %s to be healthy...", service_name) time.sleep(10) raise TimeoutError(f"{service_name} did not become healthy in time") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py b/api/providers/vdb/vdb-couchbase/tests/unit_tests/test_couchbase_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py rename to api/providers/vdb/vdb-couchbase/tests/unit_tests/test_couchbase_vector.py index 9fea187615..7e5c40b8f2 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/couchbase/test_couchbase_vector.py +++ b/api/providers/vdb/vdb-couchbase/tests/unit_tests/test_couchbase_vector.py @@ -154,7 +154,7 @@ def couchbase_module(monkeypatch): for name, module in _build_fake_couchbase_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.couchbase.couchbase_vector as module + import dify_vdb_couchbase.couchbase_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-elasticsearch/pyproject.toml b/api/providers/vdb/vdb-elasticsearch/pyproject.toml new file mode 100644 index 0000000000..d40908f92d --- /dev/null +++ b/api/providers/vdb/vdb-elasticsearch/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "dify-vdb-elasticsearch" +version = "0.0.1" + +dependencies = [ + "elasticsearch==8.14.0", +] +description = "Dify vector store backend (dify-vdb-elasticsearch)." 
+ +[project.entry-points."dify.vector_backends"] +elasticsearch = "dify_vdb_elasticsearch.elasticsearch_vector:ElasticSearchVectorFactory" +elasticsearch-ja = "dify_vdb_elasticsearch.elasticsearch_ja_vector:ElasticSearchJaVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/oceanbase/__init__.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/oceanbase/__init__.py rename to api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/__init__.py diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py similarity index 97% rename from api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py rename to api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py index 1e7fe52666..e2f390402a 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_ja_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_ja_vector.py @@ -4,14 +4,14 @@ from typing import Any from flask import current_app -from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ( +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from dify_vdb_elasticsearch.elasticsearch_vector import ( ElasticSearchConfig, ElasticSearchVector, ElasticSearchVectorFactory, ) -from core.rag.datasource.vdb.field import Field -from core.rag.datasource.vdb.vector_type import VectorType -from core.rag.embedding.embedding_base import Embeddings from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -23,7 +23,7 @@ class ElasticSearchJaVector(ElasticSearchVector): self, embeddings: list[list[float]], metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py rename to api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py index 1470713b88..11463b6c58 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/src/dify_vdb_elasticsearch/elasticsearch_vector.py @@ -43,7 +43,7 @@ class ElasticSearchConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): use_cloud = values.get("use_cloud", False) cloud_url = values.get("cloud_url") @@ -258,7 +258,7 @@ class ElasticSearchVector(BaseVector): self, embeddings: list[list[float]], metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py b/api/providers/vdb/vdb-elasticsearch/tests/integration_tests/test_elasticsearch.py 
similarity index 71% rename from api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py rename to api/providers/vdb/vdb-elasticsearch/tests/integration_tests/test_elasticsearch.py index 970d2cce1a..c8b679e021 100644 --- a/api/tests/integration_tests/vdb/elasticsearch/test_elasticsearch.py +++ b/api/providers/vdb/vdb-elasticsearch/tests/integration_tests/test_elasticsearch.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class ElasticSearchVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py b/api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_ja_vector.py similarity index 96% rename from api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py rename to api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_ja_vector.py index edd29a4649..f81ed6beea 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_ja_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_ja_vector.py @@ -32,8 +32,8 @@ def elasticsearch_ja_module(monkeypatch): for name, module in _build_fake_elasticsearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module - import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module + import dify_vdb_elasticsearch.elasticsearch_ja_vector as ja_module + import dify_vdb_elasticsearch.elasticsearch_vector as base_module importlib.reload(base_module) return importlib.reload(ja_module) diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py b/api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py rename to api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_vector.py index 9ecf0caa24..48f1f6dc26 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/elasticsearch/test_elasticsearch_vector.py +++ b/api/providers/vdb/vdb-elasticsearch/tests/unit_tests/test_elasticsearch_vector.py @@ -42,7 +42,7 @@ def elasticsearch_module(monkeypatch): for name, module in _build_fake_elasticsearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module + import dify_vdb_elasticsearch.elasticsearch_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-hologres/pyproject.toml b/api/providers/vdb/vdb-hologres/pyproject.toml new file mode 100644 index 0000000000..88044bf6d6 --- /dev/null +++ b/api/providers/vdb/vdb-hologres/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-hologres" +version = "0.0.1" + +dependencies = [ + "holo-search-sdk>=0.4.2", +] +description = "Dify vector store backend (dify-vdb-hologres)." 
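These unit-test fixtures share one trick: a `_build_fake_*_modules()` helper fabricates stub module objects, `monkeypatch.setitem(sys.modules, ...)` installs them under the real SDK's import name, and the backend module is then re-imported with `importlib.reload` so its top-level imports bind to the stubs. A minimal sketch with hypothetical module names:

```python
# Sketch of the fake-SDK fixture pattern used across these unit tests.
# "heavy_sdk" and "my_backend.vector" are hypothetical placeholders.
import importlib
import sys
import types

import pytest


def _build_fake_sdk() -> types.ModuleType:
    module = types.ModuleType("heavy_sdk")
    module.connect = lambda **kwargs: object()  # stand-in for the real client
    return module


@pytest.fixture
def vector_module(monkeypatch):
    # Install the stub before touching the module under test.
    monkeypatch.setitem(sys.modules, "heavy_sdk", _build_fake_sdk())

    import my_backend.vector as module

    # Reload so the module-level `import heavy_sdk` resolves to the stub.
    return importlib.reload(module)
```

`monkeypatch` undoes the `sys.modules` entry at teardown, and the reload guarantees the stub is picked up even when an earlier test already imported the module.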
+ +[project.entry-points."dify.vector_backends"] +hologres = "dify_vdb_hologres.hologres_vector:HologresVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/opengauss/__init__.py b/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/opengauss/__init__.py rename to api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/__init__.py diff --git a/api/core/rag/datasource/vdb/hologres/hologres_vector.py b/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py similarity index 97% rename from api/core/rag/datasource/vdb/hologres/hologres_vector.py rename to api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py index 13d48b5668..80c0ed582e 100644 --- a/api/core/rag/datasource/vdb/hologres/hologres_vector.py +++ b/api/providers/vdb/vdb-hologres/src/dify_vdb_hologres/hologres_vector.py @@ -1,7 +1,7 @@ import json import logging import time -from typing import Any +from typing import Any, cast import holo_search_sdk as holo # type: ignore from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType @@ -43,7 +43,7 @@ class HologresVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values.get("host"): raise ValueError("config HOLOGRES_HOST is required") if not values.get("database"): @@ -351,9 +351,9 @@ class HologresVectorFactory(AbstractVectorFactory): access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "", access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "", schema_name=dify_config.HOLOGRES_SCHEMA, - tokenizer=dify_config.HOLOGRES_TOKENIZER, - distance_method=dify_config.HOLOGRES_DISTANCE_METHOD, - base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE, + tokenizer=cast(TokenizerType, dify_config.HOLOGRES_TOKENIZER), + distance_method=cast(DistanceType, dify_config.HOLOGRES_DISTANCE_METHOD), + base_quantization_type=cast(BaseQuantizationType, dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE), max_degree=dify_config.HOLOGRES_MAX_DEGREE, ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION, ), diff --git a/api/tests/integration_tests/vdb/__mock/hologres.py b/api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py similarity index 82% rename from api/tests/integration_tests/vdb/__mock/hologres.py rename to api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py index b60cf358c0..d28ded0187 100644 --- a/api/tests/integration_tests/vdb/__mock/hologres.py +++ b/api/providers/vdb/vdb-hologres/tests/integration_tests/conftest.py @@ -7,13 +7,10 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from psycopg import sql as psql -# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}} _mock_tables: dict[str, dict[str, dict[str, Any]]] = {} class MockSearchQuery: - """Mock query builder for search_vector and search_text results.""" - def __init__(self, table_name: str, search_type: str): self._table_name = table_name self._search_type = search_type @@ -32,17 +29,13 @@ class MockSearchQuery: return self def _apply_filter(self, row: dict[str, Any]) -> bool: - """Apply the filter SQL to check if a row matches.""" if self._filter_sql is None: return True - # Extract literals (the document IDs) from the filter SQL - # Filter format: meta->>'document_id' IN ('doc1', 'doc2') literals = [v for t, v in 
_extract_identifiers_and_literals(self._filter_sql) if t == "literal"] if not literals: return True - # Get the document_id from the row's meta field meta = row.get("meta", "{}") if isinstance(meta, str): meta = json.loads(meta) @@ -54,22 +47,17 @@ class MockSearchQuery: data = _mock_tables.get(self._table_name, {}) results = [] for row in list(data.values())[: self._limit_val]: - # Apply filter if present if not self._apply_filter(row): continue if self._search_type == "vector": - # row format expected by _process_vector_results: (distance, id, text, meta) results.append((0.1, row["id"], row["text"], row["meta"])) else: - # row format expected by _process_full_text_results: (id, text, meta, embedding, score) results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9)) return results class MockTable: - """Mock table object returned by client.open_table().""" - def __init__(self, table_name: str): self._table_name = table_name @@ -97,7 +85,6 @@ class MockTable: def _extract_sql_template(query) -> str: - """Extract the SQL template string from a psycopg Composed object.""" if isinstance(query, psql.Composed): for part in query: if isinstance(part, psql.SQL): @@ -108,7 +95,6 @@ def _extract_sql_template(query) -> str: def _extract_identifiers_and_literals(query) -> list[Any]: - """Extract Identifier and Literal values from a psycopg Composed object.""" values: list[Any] = [] if isinstance(query, psql.Composed): for part in query: @@ -117,7 +103,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]: elif isinstance(part, psql.Literal): values.append(("literal", part._obj)) elif isinstance(part, psql.Composed): - # Handles SQL(...).join(...) for IN clauses for sub in part: if isinstance(sub, psql.Literal): values.append(("literal", sub._obj)) @@ -125,8 +110,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]: class MockHologresClient: - """Mock holo_search_sdk client that stores data in memory.""" - def connect(self): pass @@ -141,21 +124,18 @@ class MockHologresClient: params = _extract_identifiers_and_literals(query) if "CREATE TABLE" in template.upper(): - # Extract table name from first identifier table_name = next((v for t, v in params if t == "ident"), "unknown") if table_name not in _mock_tables: _mock_tables[table_name] = {} return None if "SELECT 1" in template: - # text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1 table_name = next((v for t, v in params if t == "ident"), "") doc_id = next((v for t, v in params if t == "literal"), "") data = _mock_tables.get(table_name, {}) return [(1,)] if doc_id in data else [] if "SELECT id" in template: - # get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value} table_name = next((v for t, v in params if t == "ident"), "") literals = [v for t, v in params if t == "literal"] key = literals[0] if len(literals) > 0 else "" @@ -166,12 +146,10 @@ class MockHologresClient: if "DELETE" in template.upper(): table_name = next((v for t, v in params if t == "ident"), "") if "id IN" in template: - # delete_by_ids ids_to_delete = [v for t, v in params if t == "literal"] for did in ids_to_delete: _mock_tables.get(table_name, {}).pop(did, None) elif "meta->>" in template: - # delete_by_metadata_field literals = [v for t, v in params if t == "literal"] key = literals[0] if len(literals) > 0 else "" value = literals[1] if len(literals) > 1 else "" @@ -190,7 +168,6 @@ class MockHologresClient: def mock_connect(**kwargs): - """Replacement for holo_search_sdk.connect() that returns a mock 
client.""" return MockHologresClient() diff --git a/api/tests/integration_tests/vdb/hologres/test_hologres.py b/api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py similarity index 94% rename from api/tests/integration_tests/vdb/hologres/test_hologres.py rename to api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py index d81e18841e..04024be4ae 100644 --- a/api/tests/integration_tests/vdb/hologres/test_hologres.py +++ b/api/providers/vdb/vdb-hologres/tests/integration_tests/test_hologres.py @@ -2,16 +2,11 @@ import os import uuid from typing import cast +from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType -from core.rag.datasource.vdb.hologres.hologres_vector import HologresVector, HologresVectorConfig +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text - -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.hologres", -) MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py b/api/providers/vdb/vdb-hologres/tests/unit_tests/test_hologres_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py rename to api/providers/vdb/vdb-hologres/tests/unit_tests/test_hologres_vector.py index 5d9e744ded..f9a557ecce 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/hologres/test_hologres_vector.py +++ b/api/providers/vdb/vdb-hologres/tests/unit_tests/test_hologres_vector.py @@ -42,7 +42,7 @@ def hologres_module(monkeypatch): for name, module in _build_fake_hologres_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.hologres.hologres_vector as module + import dify_vdb_hologres.hologres_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-huawei-cloud/pyproject.toml b/api/providers/vdb/vdb-huawei-cloud/pyproject.toml new file mode 100644 index 0000000000..71af56786c --- /dev/null +++ b/api/providers/vdb/vdb-huawei-cloud/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-huawei-cloud" +version = "0.0.1" + +dependencies = [ + "elasticsearch==8.14.0", +] +description = "Dify vector store backend (dify-vdb-huawei-cloud)." 
+ +[project.entry-points."dify.vector_backends"] +huawei_cloud = "dify_vdb_huawei_cloud.huawei_cloud_vector:HuaweiCloudVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/opensearch/__init__.py b/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/opensearch/__init__.py rename to api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/__init__.py diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py similarity index 98% rename from api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py rename to api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py index 90d6d98c63..d51075d2e8 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/providers/vdb/vdb-huawei-cloud/src/dify_vdb_huawei_cloud/huawei_cloud_vector.py @@ -44,7 +44,7 @@ class HuaweiCloudVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["hosts"]: raise ValueError("config HOSTS is required") return values @@ -169,7 +169,7 @@ class HuaweiCloudVector(BaseVector): self, embeddings: list[list[float]], metadatas: list[dict[Any, Any]] | None = None, - index_params: dict | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py b/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/conftest.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py rename to api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/conftest.py diff --git a/api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py b/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py similarity index 69% rename from api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py rename to api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py index 01f511358a..bb5f5b72ef 100644 --- a/api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py +++ b/api/providers/vdb/vdb-huawei-cloud/tests/integration_tests/test_huawei_cloud.py @@ -1,10 +1,6 @@ -from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_huawei_cloud.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.huaweicloudvectordb", -) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text class HuaweiCloudVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py b/api/providers/vdb/vdb-huawei-cloud/tests/unit_tests/test_huawei_cloud_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py rename to api/providers/vdb/vdb-huawei-cloud/tests/unit_tests/test_huawei_cloud_vector.py index 9d23dfcf63..ba3f14912b 100644 --- 
a/api/tests/unit_tests/core/rag/datasource/vdb/huawei/test_huawei_cloud_vector.py +++ b/api/providers/vdb/vdb-huawei-cloud/tests/unit_tests/test_huawei_cloud_vector.py @@ -33,7 +33,7 @@ def huawei_module(monkeypatch): for name, module in _build_fake_elasticsearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module + import dify_vdb_huawei_cloud.huawei_cloud_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-iris/pyproject.toml b/api/providers/vdb/vdb-iris/pyproject.toml new file mode 100644 index 0000000000..6dd7a8e073 --- /dev/null +++ b/api/providers/vdb/vdb-iris/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-iris" +version = "0.0.1" + +dependencies = [ + "intersystems-irispython>=5.1.0", +] +description = "Dify vector store backend (dify-vdb-iris)." + +[project.entry-points."dify.vector_backends"] +iris = "dify_vdb_iris.iris_vector:IrisVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/oracle/__init__.py b/api/providers/vdb/vdb-iris/src/dify_vdb_iris/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/oracle/__init__.py rename to api/providers/vdb/vdb-iris/src/dify_vdb_iris/__init__.py diff --git a/api/core/rag/datasource/vdb/iris/iris_vector.py b/api/providers/vdb/vdb-iris/src/dify_vdb_iris/iris_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/iris/iris_vector.py rename to api/providers/vdb/vdb-iris/src/dify_vdb_iris/iris_vector.py diff --git a/api/tests/integration_tests/vdb/iris/test_iris.py b/api/providers/vdb/vdb-iris/tests/integration_tests/test_iris.py similarity index 85% rename from api/tests/integration_tests/vdb/iris/test_iris.py rename to api/providers/vdb/vdb-iris/tests/integration_tests/test_iris.py index 4b2da8387b..8281e89c8a 100644 --- a/api/tests/integration_tests/vdb/iris/test_iris.py +++ b/api/providers/vdb/vdb-iris/tests/integration_tests/test_iris.py @@ -1,12 +1,11 @@ """Integration tests for IRIS vector database.""" -from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_iris.iris_vector import IrisVector, IrisVectorConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class IrisVectorTest(AbstractVectorTest): """Test suite for IRIS vector store implementation.""" diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py b/api/providers/vdb/vdb-iris/tests/unit_tests/test_iris_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py rename to api/providers/vdb/vdb-iris/tests/unit_tests/test_iris_vector.py index 63338ca809..8c038e82b9 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/iris/test_iris_vector.py +++ b/api/providers/vdb/vdb-iris/tests/unit_tests/test_iris_vector.py @@ -26,7 +26,7 @@ def _build_fake_iris_module(): def iris_module(monkeypatch): monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module()) - import core.rag.datasource.vdb.iris.iris_vector as module + import dify_vdb_iris.iris_vector as module reloaded = importlib.reload(module) reloaded._pool_instance = None diff --git a/api/providers/vdb/vdb-lindorm/pyproject.toml b/api/providers/vdb/vdb-lindorm/pyproject.toml new file mode 100644 index 
0000000000..0cffc67491 --- /dev/null +++ b/api/providers/vdb/vdb-lindorm/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "dify-vdb-lindorm" +version = "0.0.1" + +dependencies = [ + "opensearch-py==3.1.0", + "tenacity>=8.0.0", +] +description = "Dify vector store backend (dify-vdb-lindorm)." + +[project.entry-points."dify.vector_backends"] +lindorm = "dify_vdb_lindorm.lindorm_vector:LindormVectorStoreFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/__init__.py b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/pgvecto_rs/__init__.py rename to api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/__init__.py diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py similarity index 98% rename from api/core/rag/datasource/vdb/lindorm/lindorm_vector.py rename to api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py index fbe0bcad02..9187ca943d 100644 --- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py +++ b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py @@ -44,7 +44,7 @@ class LindormVectorStoreConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["hosts"]: raise ValueError("config URL is required") if not values["username"]: @@ -336,7 +336,10 @@ class LindormVectorStore(BaseVector): return docs def create_collection( - self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None + self, + embeddings: list[list[float]], + metadatas: list[dict[str, Any]] | None = None, + index_params: dict[str, Any] | None = None, ): if not embeddings: raise ValueError(f"Embeddings list cannot be empty for collection create '{self._collection_name}'") diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py similarity index 88% rename from api/tests/integration_tests/vdb/lindorm/test_lindorm.py rename to api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py index b24498fdfd..0a0c2d2d59 100644 --- a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py +++ b/api/providers/vdb/vdb-lindorm/tests/integration_tests/test_lindorm.py @@ -1,9 +1,8 @@ import os -from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest +from dify_vdb_lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest class Config: diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py b/api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py rename to api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py index 34357d5907..238145c1d6 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/lindorm/test_lindorm_vector.py +++ b/api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py @@ -51,7 +51,7 @@ def lindorm_module(monkeypatch): for name, module in 
_build_fake_opensearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.lindorm.lindorm_vector as module + import dify_vdb_lindorm.lindorm_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-matrixone/pyproject.toml b/api/providers/vdb/vdb-matrixone/pyproject.toml new file mode 100644 index 0000000000..53363ed7d9 --- /dev/null +++ b/api/providers/vdb/vdb-matrixone/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-matrixone" +version = "0.0.1" + +dependencies = [ + "mo-vector~=0.1.13", +] +description = "Dify vector store backend (dify-vdb-matrixone)." + +[project.entry-points."dify.vector_backends"] +matrixone = "dify_vdb_matrixone.matrixone_vector:MatrixoneVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/pgvector/__init__.py b/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/pgvector/__init__.py rename to api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/__init__.py diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/matrixone/matrixone_vector.py rename to api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py index c6ebccd204..75fb54e6f4 100644 --- a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py +++ b/api/providers/vdb/vdb-matrixone/src/dify_vdb_matrixone/matrixone_vector.py @@ -43,7 +43,7 @@ class MatrixoneConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config host is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/matrixone/test_matrixone.py b/api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py similarity index 74% rename from api/tests/integration_tests/vdb/matrixone/test_matrixone.py rename to api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py index fe592f6699..d6f4781e65 100644 --- a/api/tests/integration_tests/vdb/matrixone/test_matrixone.py +++ b/api/providers/vdb/vdb-matrixone/tests/integration_tests/test_matrixone.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class MatrixoneVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py b/api/providers/vdb/vdb-matrixone/tests/unit_tests/test_matrixone_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py rename to api/providers/vdb/vdb-matrixone/tests/unit_tests/test_matrixone_vector.py index 55e7b9112e..c22f4304e5 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/matrixone/test_matrixone_vector.py +++ b/api/providers/vdb/vdb-matrixone/tests/unit_tests/test_matrixone_vector.py @@ -36,7 +36,7 @@ def matrixone_module(monkeypatch): for name, module in 
_build_fake_mo_vector_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.matrixone.matrixone_vector as module + import dify_vdb_matrixone.matrixone_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-milvus/pyproject.toml b/api/providers/vdb/vdb-milvus/pyproject.toml new file mode 100644 index 0000000000..57385a4431 --- /dev/null +++ b/api/providers/vdb/vdb-milvus/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-milvus" +version = "0.0.1" + +dependencies = [ + "pymilvus~=2.6.12", +] +description = "Dify vector store backend (dify-vdb-milvus)." + +[project.entry-points."dify.vector_backends"] +milvus = "dify_vdb_milvus.milvus_vector:MilvusVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/pyvastbase/__init__.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/pyvastbase/__init__.py rename to api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/__init__.py diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py similarity index 96% rename from api/core/rag/datasource/vdb/milvus/milvus_vector.py rename to api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py index 7cdb2d3a99..823b877707 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/providers/vdb/vdb-milvus/src/dify_vdb_milvus/milvus_vector.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, TypedDict +from typing import Any, TypedDict, cast from packaging import version from pydantic import BaseModel, model_validator @@ -45,7 +45,7 @@ class MilvusConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): """ Validate the configuration values. Raises ValueError if required fields are missing. @@ -92,7 +92,7 @@ class MilvusVector(BaseVector): def _load_collection_fields(self, fields: list[str] | None = None): if fields is None: # Load collection fields from remote server - collection_info = self._client.describe_collection(self._collection_name) + collection_info = cast(dict[str, Any], self._client.describe_collection(self._collection_name)) fields = [field["name"] for field in collection_info["fields"]] # Since primary field is auto-id, no need to track it self._fields = [f for f in fields if f != Field.PRIMARY_KEY] @@ -106,7 +106,8 @@ class MilvusVector(BaseVector): return False try: - milvus_version = self._client.get_server_version() + milvus_version_raw = self._client.get_server_version() + milvus_version = milvus_version_raw if isinstance(milvus_version_raw, str) else str(milvus_version_raw) # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility if "Zilliz Cloud" in milvus_version: return True @@ -302,7 +303,10 @@ class MilvusVector(BaseVector): ) def create_collection( - self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None + self, + embeddings: list[list[float]], + metadatas: list[dict[str, Any]] | None = None, + index_params: dict[str, Any] | None = None, ): """ Create a new collection in Milvus with the specified schema and index parameters. 
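The Milvus hunk above narrows loosely typed SDK returns before use: `describe_collection` is cast to `dict[str, Any]`, and `get_server_version` is coerced to `str` when the client returns something else. A sketch of both moves against a hypothetical client object:

```python
# Sketch of the type-narrowing pattern from the Milvus changes; `client`
# is any object exposing these two methods, not a specific SDK type.
from typing import Any, cast


def collection_field_names(client: Any, name: str) -> list[str]:
    # cast() is a static-only assertion; it performs no runtime check.
    info = cast(dict[str, Any], client.describe_collection(name))
    return [field["name"] for field in info["fields"]]


def server_version(client: Any) -> str:
    raw = client.get_server_version()
    return raw if isinstance(raw, str) else str(raw)
```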
diff --git a/api/tests/integration_tests/vdb/milvus/test_milvus.py b/api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py similarity index 80% rename from api/tests/integration_tests/vdb/milvus/test_milvus.py rename to api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py index b5fc4b4d10..084d808bed 100644 --- a/api/tests/integration_tests/vdb/milvus/test_milvus.py +++ b/api/providers/vdb/vdb-milvus/tests/integration_tests/test_milvus.py @@ -1,11 +1,10 @@ -from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_milvus.milvus_vector import MilvusConfig, MilvusVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_text, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class MilvusVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py rename to api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py index 2ac2c40d38..36c0ed8f6f 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py +++ b/api/providers/vdb/vdb-milvus/tests/unit_tests/test_milvus.py @@ -103,7 +103,7 @@ def milvus_module(monkeypatch): for name, module in _build_fake_pymilvus_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.milvus.milvus_vector as module + import dify_vdb_milvus.milvus_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-myscale/pyproject.toml b/api/providers/vdb/vdb-myscale/pyproject.toml new file mode 100644 index 0000000000..13e0f35d23 --- /dev/null +++ b/api/providers/vdb/vdb-myscale/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-myscale" +version = "0.0.1" + +dependencies = [ + "clickhouse-connect~=0.15.0", +] +description = "Dify vector store backend (dify-vdb-myscale)." 
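The integration tests repeat one migration: the `pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)` hook goes away, and `AbstractVectorTest` plus `get_example_text` are imported directly from `core.rag.datasource.vdb.vector_integration_test_support`. A sketch of the resulting test shape; the backend classes are hypothetical, and `get_example_text` is assumed to take no arguments:

```python
# Sketch of the migrated integration-test shape. ExampleVector and its
# config are hypothetical stand-ins for a backend's real classes.
from core.rag.datasource.vdb.vector_integration_test_support import (
    AbstractVectorTest,
    get_example_text,
)


class ExampleVectorTest(AbstractVectorTest):
    def __init__(self):
        super().__init__()
        # A real subclass constructs its backend's vector store here, e.g.
        # self.vector = ExampleVector(collection_name=..., config=...)
        self.sample_text = get_example_text()  # assumed no-arg helper
```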
+ +[project.entry-points."dify.vector_backends"] +myscale = "dify_vdb_myscale.myscale_vector:MyScaleVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/qdrant/__init__.py b/api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/qdrant/__init__.py rename to api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/__init__.py diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/myscale_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/myscale/myscale_vector.py rename to api/providers/vdb/vdb-myscale/src/dify_vdb_myscale/myscale_vector.py diff --git a/api/tests/integration_tests/vdb/myscale/test_myscale.py b/api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py similarity index 76% rename from api/tests/integration_tests/vdb/myscale/test_myscale.py rename to api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py index 74cefad2af..8ea42d5f45 100644 --- a/api/tests/integration_tests/vdb/myscale/test_myscale.py +++ b/api/providers/vdb/vdb-myscale/tests/integration_tests/test_myscale.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_myscale.myscale_vector import MyScaleConfig, MyScaleVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class MyScaleVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py b/api/providers/vdb/vdb-myscale/tests/unit_tests/test_myscale_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py rename to api/providers/vdb/vdb-myscale/tests/unit_tests/test_myscale_vector.py index a75ba82238..228ea92639 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py +++ b/api/providers/vdb/vdb-myscale/tests/unit_tests/test_myscale_vector.py @@ -42,7 +42,7 @@ def myscale_module(monkeypatch): fake_module = _build_fake_clickhouse_connect_module() monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module) - import core.rag.datasource.vdb.myscale.myscale_vector as module + import dify_vdb_myscale.myscale_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-oceanbase/pyproject.toml b/api/providers/vdb/vdb-oceanbase/pyproject.toml new file mode 100644 index 0000000000..887869a41c --- /dev/null +++ b/api/providers/vdb/vdb-oceanbase/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "dify-vdb-oceanbase" +version = "0.0.1" + +dependencies = [ + "pyobvector~=0.2.17", + "mysql-connector-python>=9.3.0", +] +description = "Dify vector store backend (dify-vdb-oceanbase)." 
+ +[project.entry-points."dify.vector_backends"] +oceanbase = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory" +seekdb = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/relyt/__init__.py b/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/relyt/__init__.py rename to api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/__init__.py diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py rename to api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py index 82f419871c..69dc42169a 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/providers/vdb/vdb-oceanbase/src/dify_vdb_oceanbase/oceanbase_vector.py @@ -49,7 +49,7 @@ class OceanBaseVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config OCEANBASE_VECTOR_HOST is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/bench_oceanbase.py similarity index 87% rename from api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py rename to api/providers/vdb/vdb-oceanbase/tests/integration_tests/bench_oceanbase.py index 8b57be08c5..50f6736942 100644 --- a/api/tests/integration_tests/vdb/oceanbase/bench_oceanbase.py +++ b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/bench_oceanbase.py @@ -2,11 +2,12 @@ Benchmark: OceanBase vector store — old (single-row) vs new (batch) insertion, metadata query with/without functional index, and vector search across metrics. 
-Usage: - uv run --project api python -m tests.integration_tests.vdb.oceanbase.bench_oceanbase +Usage (from repo root): + uv run --project api python api/providers/vdb/vdb-oceanbase/tests/integration_tests/bench_oceanbase.py """ import json +import logging import random import statistics import time @@ -16,6 +17,8 @@ from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_d from sqlalchemy import JSON, Column, String, text from sqlalchemy.dialects.mysql import LONGTEXT +logger = logging.getLogger(__name__) + # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- @@ -114,7 +117,7 @@ def bench_metadata_query(client, table, doc_id, with_index=False): try: client.perform_raw_text_sql(f"CREATE INDEX idx_metadata_doc_id ON `{table}` ((metadata->>'$.document_id'))") except Exception: - pass # already exists + logger.debug("Index idx_metadata_doc_id already exists, skipping creation") sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.document_id' = :val") times = [] @@ -164,11 +167,11 @@ def main(): client = _make_client() client_pooled = _make_client(pool_size=5, max_overflow=10, pool_recycle=3600, pool_pre_ping=True) - print("=" * 70) - print("OceanBase Vector Store — Performance Benchmark") - print(f" Endpoint : {HOST}:{PORT}") - print(f" Vec dim : {VEC_DIM}") - print("=" * 70) + logger.info("=" * 70) + logger.info("OceanBase Vector Store — Performance Benchmark") + logger.info(" Endpoint : %s:%s", HOST, PORT) + logger.info(" Vec dim : %s", VEC_DIM) + logger.info("=" * 70) # ------------------------------------------------------------------ # 1. Insertion benchmark # ------------------------------------------------------------------ @@ -187,10 +190,10 @@ def main(): t_batch = bench_insert_batch(client_pooled, tbl_batch, rows, batch_size=100) speedup = t_single / t_batch if t_batch > 0 else float("inf") - print(f"\n[Insert {n_docs} docs]") - print(f" Single-row : {t_single:.2f}s") - print(f" Batch(100) : {t_batch:.2f}s") - print(f" Speedup : {speedup:.1f}x") + logger.info("\n[Insert %s docs]", n_docs) + logger.info(" Single-row : %.2fs", t_single) + logger.info(" Batch(100) : %.2fs", t_batch) + logger.info(" Speedup : %.1fx", speedup) # ------------------------------------------------------------------ # 2. Metadata query benchmark (use the 1000-doc batch table) # ------------------------------------------------------------------ @@ -203,16 +206,16 @@ def main(): res = conn.execute(text(f"SELECT metadata->>'$.document_id' FROM `{tbl_meta}` LIMIT 1")) doc_id_1000 = res.fetchone()[0] - print("\n[Metadata filter query — 1000 rows, by document_id]") + logger.info("\n[Metadata filter query — 1000 rows, by document_id]") times_no_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=False) - print(f" Without index : {_fmt(times_no_idx)}") + logger.info(" Without index : %s", _fmt(times_no_idx)) times_with_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=True) - print(f" With index : {_fmt(times_with_idx)}") + logger.info(" With index : %s", _fmt(times_with_idx)) # ------------------------------------------------------------------ # 3.
Vector search benchmark — across metrics # ------------------------------------------------------------------ - print("\n[Vector search — top-10, 20 queries each, on 1000 rows]") + logger.info("\n[Vector search — top-10, 20 queries each, on 1000 rows]") for metric in ["l2", "cosine", "inner_product"]: tbl_vs = f"bench_vs_{metric}" @@ -222,7 +225,7 @@ def main(): rows_vs, _ = _gen_rows(1000) bench_insert_batch(client_pooled, tbl_vs, rows_vs, batch_size=100) times = bench_vector_search(client_pooled, tbl_vs, metric, topk=10, n_queries=20) - print(f" {metric:15s}: {_fmt(times)}") + logger.info(" %-15s: %s", metric, _fmt(times)) _drop(client_pooled, tbl_vs) # ------------------------------------------------------------------ @@ -232,9 +235,9 @@ def main(): _drop(client, f"bench_single_{n}") _drop(client, f"bench_batch_{n}") - print("\n" + "=" * 70) - print("Benchmark complete.") - print("=" * 70) + logger.info("\n%s", "=" * 70) + logger.info("Benchmark complete.") + logger.info("=" * 70) if __name__ == "__main__": diff --git a/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py similarity index 82% rename from api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py rename to api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py index 410de2c5ad..28f22d3cbc 100644 --- a/api/tests/integration_tests/vdb/oceanbase/test_oceanbase.py +++ b/api/providers/vdb/vdb-oceanbase/tests/integration_tests/test_oceanbase.py @@ -1,15 +1,13 @@ import pytest - -from core.rag.datasource.vdb.oceanbase.oceanbase_vector import ( +from dify_vdb_oceanbase.oceanbase_vector import ( OceanBaseVector, OceanBaseVectorConfig, ) -from tests.integration_tests.vdb.test_vector_store import ( + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - @pytest.fixture def oceanbase_vector(): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py b/api/providers/vdb/vdb-oceanbase/tests/unit_tests/test_oceanbase_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py rename to api/providers/vdb/vdb-oceanbase/tests/unit_tests/test_oceanbase_vector.py index 27d8198ec0..31f9ff3e56 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/oceanbase/test_oceanbase_vector.py +++ b/api/providers/vdb/vdb-oceanbase/tests/unit_tests/test_oceanbase_vector.py @@ -56,7 +56,7 @@ def _build_fake_pyobvector_module(): def oceanbase_module(monkeypatch): monkeypatch.setitem(sys.modules, "pyobvector", _build_fake_pyobvector_module()) - import core.rag.datasource.vdb.oceanbase.oceanbase_vector as module + import dify_vdb_oceanbase.oceanbase_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-opengauss/pyproject.toml b/api/providers/vdb/vdb-opengauss/pyproject.toml new file mode 100644 index 0000000000..79be94b9e3 --- /dev/null +++ b/api/providers/vdb/vdb-opengauss/pyproject.toml @@ -0,0 +1,12 @@ +[project] +name = "dify-vdb-opengauss" +version = "0.0.1" + +dependencies = [] +description = "Dify vector store backend (dify-vdb-opengauss)." 
+ +[project.entry-points."dify.vector_backends"] +opengauss = "dify_vdb_opengauss.opengauss:OpenGaussFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/tablestore/__init__.py b/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/tablestore/__init__.py rename to api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/__init__.py diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py similarity index 99% rename from api/core/rag/datasource/vdb/opengauss/opengauss.py rename to api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py index f9dbfbeeaf..acd2471cf6 100644 --- a/api/core/rag/datasource/vdb/opengauss/opengauss.py +++ b/api/providers/vdb/vdb-opengauss/src/dify_vdb_opengauss/opengauss.py @@ -29,7 +29,7 @@ class OpenGaussConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config OPENGAUSS_HOST is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/opengauss/test_opengauss.py b/api/providers/vdb/vdb-opengauss/tests/integration_tests/test_opengauss.py similarity index 82% rename from api/tests/integration_tests/vdb/opengauss/test_opengauss.py rename to api/providers/vdb/vdb-opengauss/tests/integration_tests/test_opengauss.py index 78436a19ee..8b444527d7 100644 --- a/api/tests/integration_tests/vdb/opengauss/test_opengauss.py +++ b/api/providers/vdb/vdb-opengauss/tests/integration_tests/test_opengauss.py @@ -1,14 +1,12 @@ import time import psycopg2 +from dify_vdb_opengauss.opengauss import OpenGauss, OpenGaussConfig -from core.rag.datasource.vdb.opengauss.opengauss import OpenGauss, OpenGaussConfig -from tests.integration_tests.vdb.test_vector_store import ( +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class OpenGaussTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py b/api/providers/vdb/vdb-opengauss/tests/unit_tests/test_opengauss.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py rename to api/providers/vdb/vdb-opengauss/tests/unit_tests/test_opengauss.py index 6641dbe4a0..09abd625fc 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/opengauss/test_opengauss.py +++ b/api/providers/vdb/vdb-opengauss/tests/unit_tests/test_opengauss.py @@ -41,7 +41,7 @@ def opengauss_module(monkeypatch): for name, module in _build_fake_psycopg2_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.opengauss.opengauss as module + import dify_vdb_opengauss.opengauss as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-opensearch/pyproject.toml b/api/providers/vdb/vdb-opensearch/pyproject.toml new file mode 100644 index 0000000000..56f303fdf5 --- /dev/null +++ b/api/providers/vdb/vdb-opensearch/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-opensearch" +version = "0.0.1" + +dependencies = [ + "opensearch-py==3.1.0", +] +description = "Dify vector store backend (dify-vdb-opensearch)." 
+ +[project.entry-points."dify.vector_backends"] +opensearch = "dify_vdb_opensearch.opensearch_vector:OpenSearchVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/tencent/__init__.py b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/tencent/__init__.py rename to api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/__init__.py diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py similarity index 98% rename from api/core/rag/datasource/vdb/opensearch/opensearch_vector.py rename to api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py index 50d18cdc4c..843c495d82 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py @@ -49,7 +49,7 @@ class OpenSearchConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") if not values.get("port"): @@ -252,7 +252,10 @@ class OpenSearchVector(BaseVector): return docs def create_collection( - self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None + self, + embeddings: list[list[float]], + metadatas: list[dict[str, Any]] | None = None, + index_params: dict[str, Any] | None = None, ): lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): diff --git a/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch.py b/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch.py new file mode 100644 index 0000000000..f2ed7cb6fb --- /dev/null +++ b/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch.py @@ -0,0 +1,332 @@ +import importlib +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.rag.datasource.vdb.field import Field +from core.rag.models.document import Document +from extensions import ext_redis + + +def _build_fake_opensearch_modules(): + """Build fake opensearchpy modules to avoid the ``from events import Events`` + namespace collision (opensearch-py #756).""" + opensearchpy = types.ModuleType("opensearchpy") + opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") + + class BulkIndexError(Exception): + def __init__(self, errors): + super().__init__("bulk error") + self.errors = errors + + class Urllib3AWSV4SignerAuth: + def __init__(self, credentials, region, service): + self.credentials = credentials + self.region = region + self.service = service + + class Urllib3HttpConnection: + pass + + class _IndicesClient: + def __init__(self): + self.exists = MagicMock(return_value=False) + self.create = MagicMock() + self.delete = MagicMock() + + class OpenSearch: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.indices = _IndicesClient() + self.search = MagicMock(return_value={"hits": {"hits": []}}) + self.get = MagicMock() + + helpers = SimpleNamespace(bulk=MagicMock()) + + opensearchpy.OpenSearch = OpenSearch + opensearchpy.Urllib3AWSV4SignerAuth = Urllib3AWSV4SignerAuth + opensearchpy.Urllib3HttpConnection = Urllib3HttpConnection + opensearchpy.helpers = helpers + 
opensearchpy_helpers.BulkIndexError = BulkIndexError + + return { + "opensearchpy": opensearchpy, + "opensearchpy.helpers": opensearchpy_helpers, + } + + +@pytest.fixture +def opensearch_module(monkeypatch): + for name, module in _build_fake_opensearch_modules().items(): + monkeypatch.setitem(sys.modules, name, module) + + import dify_vdb_opensearch.opensearch_vector as module + + return importlib.reload(module) + + +def _config(module, **overrides): + values = { + "host": "localhost", + "port": 9200, + "secure": False, + "user": "admin", + "password": "password", + } + values.update(overrides) + return module.OpenSearchConfig.model_validate(values) + + +def get_example_text() -> str: + return "This is a sample text for testing purposes." + + +class TestOpenSearchConfig: + def test_to_opensearch_params(self, opensearch_module): + config = _config(opensearch_module, secure=True) + params = config.to_opensearch_params() + + assert params["hosts"] == [{"host": "localhost", "port": 9200}] + assert params["use_ssl"] is True + assert params["verify_certs"] is True + assert params["connection_class"].__name__ == "Urllib3HttpConnection" + assert params["http_auth"] == ("admin", "password") + + def test_to_opensearch_params_with_aws_managed_iam(self, opensearch_module, monkeypatch): + class _Session: + def get_credentials(self): + return "creds" + + boto3 = types.ModuleType("boto3") + boto3.Session = _Session + monkeypatch.setitem(sys.modules, "boto3", boto3) + + config = _config( + opensearch_module, + secure=True, + auth_method="aws_managed_iam", + aws_region="ap-southeast-2", + aws_service="aoss", + host="aoss-endpoint.ap-southeast-2.aoss.amazonaws.com", + port=9201, + ) + params = config.to_opensearch_params() + + assert params["hosts"] == [{"host": "aoss-endpoint.ap-southeast-2.aoss.amazonaws.com", "port": 9201}] + assert params["use_ssl"] is True + assert params["verify_certs"] is True + assert params["connection_class"].__name__ == "Urllib3HttpConnection" + assert params["http_auth"].credentials == "creds" + assert params["http_auth"].region == "ap-southeast-2" + assert params["http_auth"].service == "aoss" + + +class TestOpenSearchVector: + COLLECTION_NAME = "test_collection" + EXAMPLE_DOC_ID = "example_doc_id" + + def _make_vector(self, module): + vector = module.OpenSearchVector(self.COLLECTION_NAME, _config(module)) + vector._client = MagicMock() + return vector + + @pytest.mark.parametrize( + ("search_response", "expected_length", "expected_doc_id"), + [ + ( + { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + "page_content": get_example_text(), + "metadata": {"document_id": "example_doc_id"}, + } + } + ], + } + }, + 1, + "example_doc_id", + ), + ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None), + ], + ) + def test_search_by_full_text(self, opensearch_module, search_response, expected_length, expected_doc_id): + vector = self._make_vector(opensearch_module) + vector._client.search.return_value = search_response + + hits = vector.search_by_full_text(query=get_example_text()) + assert len(hits) == expected_length + if expected_length > 0: + assert hits[0].metadata["document_id"] == expected_doc_id + + def test_search_by_vector(self, opensearch_module): + vector = self._make_vector(opensearch_module) + query_vector = [0.1] * 128 + mock_response = { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + Field.CONTENT_KEY: get_example_text(), + Field.METADATA_KEY: {"document_id": self.EXAMPLE_DOC_ID}, + }, + "_score": 1.0, + } + ], + } + } + 
vector._client.search.return_value = mock_response + + hits = vector.search_by_vector(query_vector=query_vector) + + assert len(hits) > 0 + assert hits[0].metadata["document_id"] == self.EXAMPLE_DOC_ID + + def test_get_ids_by_metadata_field(self, opensearch_module): + vector = self._make_vector(opensearch_module) + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + vector._client.search.return_value = mock_response + + doc = Document(page_content="Test content", metadata={"document_id": self.EXAMPLE_DOC_ID}) + embedding = [0.1] * 128 + + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts([doc], [embedding]) + + ids = vector.get_ids_by_metadata_field(key="document_id", value=self.EXAMPLE_DOC_ID) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_add_texts(self, opensearch_module): + vector = self._make_vector(opensearch_module) + vector._client.index.return_value = {"result": "created"} + + doc = Document(page_content="Test content", metadata={"document_id": self.EXAMPLE_DOC_ID}) + embedding = [0.1] * 128 + + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts([doc], [embedding]) + + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + vector._client.search.return_value = mock_response + + ids = vector.get_ids_by_metadata_field(key="document_id", value=self.EXAMPLE_DOC_ID) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_delete_nonexistent_index(self, opensearch_module): + """ignore_unavailable=True handles non-existent indices gracefully.""" + vector = self._make_vector(opensearch_module) + vector.delete() + + vector._client.indices.delete.assert_called_once_with( + index=self.COLLECTION_NAME.lower(), ignore_unavailable=True + ) + + def test_delete_existing_index(self, opensearch_module): + vector = self._make_vector(opensearch_module) + vector.delete() + + vector._client.indices.delete.assert_called_once_with( + index=self.COLLECTION_NAME.lower(), ignore_unavailable=True + ) + + +@pytest.fixture(scope="module") +def setup_mock_redis(): + ext_redis.redis_client.get = MagicMock(return_value=None) + ext_redis.redis_client.set = MagicMock(return_value=None) + + mock_redis_lock = MagicMock() + mock_redis_lock.__enter__ = MagicMock() + mock_redis_lock.__exit__ = MagicMock() + ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock) + + +@pytest.mark.usefixtures("setup_mock_redis") +class TestOpenSearchVectorWithRedis: + COLLECTION_NAME = "test_collection" + EXAMPLE_DOC_ID = "example_doc_id" + + def _make_vector(self, module): + vector = module.OpenSearchVector(self.COLLECTION_NAME, _config(module)) + vector._client = MagicMock() + return vector + + def test_search_by_full_text(self, opensearch_module): + vector = self._make_vector(opensearch_module) + search_response = { + "hits": { + "total": {"value": 1}, + "hits": [ + {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}} + ], + } + } + vector._client.search.return_value = search_response + + hits = vector.search_by_full_text(query=get_example_text()) + assert len(hits) == 1 + assert hits[0].metadata["document_id"] == "example_doc_id" + + def test_get_ids_by_metadata_field(self, opensearch_module): + vector = self._make_vector(opensearch_module) + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + vector._client.search.return_value = mock_response + + doc = Document(page_content="Test content", metadata={"document_id": self.EXAMPLE_DOC_ID}) + 
embedding = [0.1] * 128 + + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts([doc], [embedding]) + + ids = vector.get_ids_by_metadata_field(key="document_id", value=self.EXAMPLE_DOC_ID) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_add_texts(self, opensearch_module): + vector = self._make_vector(opensearch_module) + vector._client.index.return_value = {"result": "created"} + + doc = Document(page_content="Test content", metadata={"document_id": self.EXAMPLE_DOC_ID}) + embedding = [0.1] * 128 + + opensearch_module.helpers.bulk.reset_mock() + vector.add_texts([doc], [embedding]) + + mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} + vector._client.search.return_value = mock_response + + ids = vector.get_ids_by_metadata_field(key="document_id", value=self.EXAMPLE_DOC_ID) + assert len(ids) == 1 + assert ids[0] == "mock_id" + + def test_search_by_vector(self, opensearch_module): + vector = self._make_vector(opensearch_module) + query_vector = [0.1] * 128 + mock_response = { + "hits": { + "total": {"value": 1}, + "hits": [ + { + "_source": { + Field.CONTENT_KEY: get_example_text(), + Field.METADATA_KEY: {"document_id": self.EXAMPLE_DOC_ID}, + }, + "_score": 1.0, + } + ], + } + } + vector._client.search.return_value = mock_response + + hits = vector.search_by_vector(query_vector=query_vector) + assert len(hits) > 0 + assert hits[0].metadata["document_id"] == self.EXAMPLE_DOC_ID diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py b/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch_vector.py similarity index 98% rename from api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py rename to api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch_vector.py index 1030158dd1..1c2921f85b 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/opensearch/test_opensearch_vector.py +++ b/api/providers/vdb/vdb-opensearch/tests/unit_tests/test_opensearch_vector.py @@ -10,6 +10,8 @@ from pydantic import ValidationError from core.rag.models.document import Document +# TODO(wylswz): There's a known issue with namespace collision +# https://github.com/langgenius/dify/issues/34732 def _build_fake_opensearch_modules(): opensearchpy = types.ModuleType("opensearchpy") opensearchpy_helpers = types.ModuleType("opensearchpy.helpers") @@ -60,7 +62,7 @@ def opensearch_module(monkeypatch): for name, module in _build_fake_opensearch_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.opensearch.opensearch_vector as module + import dify_vdb_opensearch.opensearch_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-oracle/pyproject.toml b/api/providers/vdb/vdb-oracle/pyproject.toml new file mode 100644 index 0000000000..6747485041 --- /dev/null +++ b/api/providers/vdb/vdb-oracle/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-oracle" +version = "0.0.1" + +dependencies = [ + "oracledb==3.4.2", +] +description = "Dify vector store backend (dify-vdb-oracle)." 
+ +[project.entry-points."dify.vector_backends"] +oracle = "dify_vdb_oracle.oraclevector:OracleVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py rename to api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/__init__.py diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py similarity index 95% rename from api/core/rag/datasource/vdb/oracle/oraclevector.py rename to api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py index cb05c22b55..5d9ab38529 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/providers/vdb/vdb-oracle/src/dify_vdb_oracle/oraclevector.py @@ -3,7 +3,7 @@ import json import logging import re import uuid -from typing import Any +from typing import Any, TypedDict import jieba.posseg as pseg # type: ignore import numpy @@ -25,6 +25,18 @@ logger = logging.getLogger(__name__) oracledb.defaults.fetch_lobs = False +class _OraclePoolParams(TypedDict, total=False): + user: str + password: str + dsn: str + min: int + max: int + increment: int + config_dir: str | None + wallet_location: str | None + wallet_password: str | None + + class OracleVectorConfig(BaseModel): user: str password: str @@ -36,7 +48,7 @@ class OracleVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["user"]: raise ValueError("config ORACLE_USER is required") if not values["password"]: @@ -127,22 +139,18 @@ class OracleVector(BaseVector): return connection def _create_connection_pool(self, config: OracleVectorConfig): - pool_params = { - "user": config.user, - "password": config.password, - "dsn": config.dsn, - "min": 1, - "max": 5, - "increment": 1, - } + pool_params = _OraclePoolParams( + user=config.user, + password=config.password, + dsn=config.dsn, + min=1, + max=5, + increment=1, + ) if config.is_autonomous: - pool_params.update( - { - "config_dir": config.config_dir, - "wallet_location": config.wallet_location, - "wallet_password": config.wallet_password, - } - ) + pool_params["config_dir"] = config.config_dir + pool_params["wallet_location"] = config.wallet_location + pool_params["wallet_password"] = config.wallet_password return oracledb.create_pool(**pool_params) def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): diff --git a/api/tests/integration_tests/vdb/oracle/test_oraclevector.py b/api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py similarity index 76% rename from api/tests/integration_tests/vdb/oracle/test_oraclevector.py rename to api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py index 8920dc97eb..aceb41289c 100644 --- a/api/tests/integration_tests/vdb/oracle/test_oraclevector.py +++ b/api/providers/vdb/vdb-oracle/tests/integration_tests/test_oraclevector.py @@ -1,11 +1,10 @@ -from core.rag.datasource.vdb.oracle.oraclevector import OracleVector, OracleVectorConfig -from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_oracle.oraclevector import OracleVector, OracleVectorConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_text, ) - 
-pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from core.rag.models.document import Document class OracleVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py b/api/providers/vdb/vdb-oracle/tests/unit_tests/test_oraclevector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py rename to api/providers/vdb/vdb-oracle/tests/unit_tests/test_oraclevector.py index 817a7d342b..678cf876b0 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/oracle/test_oraclevector.py +++ b/api/providers/vdb/vdb-oracle/tests/unit_tests/test_oraclevector.py @@ -55,7 +55,7 @@ def oracle_module(monkeypatch): for name, module in _build_fake_oracle_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.oracle.oraclevector as module + import dify_vdb_oracle.oraclevector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-pgvecto-rs/pyproject.toml b/api/providers/vdb/vdb-pgvecto-rs/pyproject.toml new file mode 100644 index 0000000000..9a25442e9e --- /dev/null +++ b/api/providers/vdb/vdb-pgvecto-rs/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-pgvecto-rs" +version = "0.0.1" + +dependencies = [ + "pgvecto-rs[sqlalchemy]~=0.2.2", +] +description = "Dify vector store backend (dify-vdb-pgvecto-rs)." + +[project.entry-points."dify.vector_backends"] +pgvecto-rs = "dify_vdb_pgvecto_rs.pgvecto_rs:PGVectoRSFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/tidb_vector/__init__.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/tidb_vector/__init__.py rename to api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/__init__.py diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/collection.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py similarity index 80% rename from api/core/rag/datasource/vdb/pgvecto_rs/collection.py rename to api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py index c335bc610d..e087ec30a5 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/collection.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/collection.py @@ -1,3 +1,4 @@ +from typing import Any from uuid import UUID from numpy import ndarray @@ -8,5 +9,5 @@ class CollectionORM(DeclarativeBase): __tablename__: str id: Mapped[UUID] text: Mapped[str] - meta: Mapped[dict] + meta: Mapped[dict[str, Any]] vector: Mapped[ndarray] diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py similarity index 98% rename from api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py rename to api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py index 387e918c76..9c721c8bde 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/src/dify_vdb_pgvecto_rs/pgvecto_rs.py @@ -12,12 +12,12 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column, sessionmaker from configs import dify_config -from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from 
core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from dify_vdb_pgvecto_rs.collection import CollectionORM from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -33,7 +33,7 @@ class PgvectoRSConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") if not values["port"]: @@ -67,7 +67,7 @@ class PGVectoRS(BaseVector): primary_key=True, ) text: Mapped[str] - meta: Mapped[dict] = mapped_column(postgresql.JSONB) + meta: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB) vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) self._table = _Table diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py similarity index 82% rename from api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py rename to api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py index 6210613d42..9fc8627851 100644 --- a/api/tests/integration_tests/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/tests/integration_tests/test_pgvecto_rs.py @@ -1,11 +1,10 @@ -from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_pgvecto_rs.pgvecto_rs import PGVectoRS, PgvectoRSConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_text, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class PGVectoRSVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py b/api/providers/vdb/vdb-pgvecto-rs/tests/unit_tests/test_pgvecto_rs.py similarity index 98% rename from api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py rename to api/providers/vdb/vdb-pgvecto-rs/tests/unit_tests/test_pgvecto_rs.py index 5b9ec8002a..c3291f7f12 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvecto_rs/test_pgvecto_rs.py +++ b/api/providers/vdb/vdb-pgvecto-rs/tests/unit_tests/test_pgvecto_rs.py @@ -83,8 +83,8 @@ def pgvecto_module(monkeypatch): for name, module in _build_fake_pgvecto_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.pgvecto_rs.collection as collection_module - import core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs as module + import dify_vdb_pgvecto_rs.collection as collection_module + import dify_vdb_pgvecto_rs.pgvecto_rs as module return importlib.reload(module), importlib.reload(collection_module) diff --git a/api/providers/vdb/vdb-pgvector/pyproject.toml b/api/providers/vdb/vdb-pgvector/pyproject.toml new file mode 100644 index 0000000000..2a972aa277 --- /dev/null +++ b/api/providers/vdb/vdb-pgvector/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-pgvector" +version = "0.0.1" + +dependencies = [ + "pgvector==0.4.2", +] +description = "Dify vector store backend (dify-vdb-pgvector)." 
+ +[project.entry-points."dify.vector_backends"] +pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/upstash/__init__.py b/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/upstash/__init__.py rename to api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/__init__.py diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py similarity index 99% rename from api/core/rag/datasource/vdb/pgvector/pgvector.py rename to api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py index 0615b8312c..b1bdce0ad4 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/providers/vdb/vdb-pgvector/src/dify_vdb_pgvector/pgvector.py @@ -34,7 +34,7 @@ class PGVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config PGVECTOR_HOST is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py b/api/providers/vdb/vdb-pgvector/tests/integration_tests/test_pgvector.py similarity index 73% rename from api/tests/integration_tests/vdb/pgvector/test_pgvector.py rename to api/providers/vdb/vdb-pgvector/tests/integration_tests/test_pgvector.py index 4fdeca5a3a..974657510e 100644 --- a/api/tests/integration_tests/vdb/pgvector/test_pgvector.py +++ b/api/providers/vdb/vdb-pgvector/tests/integration_tests/test_pgvector.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.pgvector.pgvector import PGVector, PGVectorConfig -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_pgvector.pgvector import PGVector, PGVectorConfig + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class PGVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py b/api/providers/vdb/vdb-pgvector/tests/unit_tests/test_pgvector.py similarity index 92% rename from api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py rename to api/providers/vdb/vdb-pgvector/tests/unit_tests/test_pgvector.py index 7505262eb7..99a6e00c16 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/test_pgvector.py +++ b/api/providers/vdb/vdb-pgvector/tests/unit_tests/test_pgvector.py @@ -2,13 +2,10 @@ from contextlib import contextmanager from types import SimpleNamespace from unittest.mock import MagicMock, patch +import dify_vdb_pgvector.pgvector as pgvector_module import pytest +from dify_vdb_pgvector.pgvector import PGVector, PGVectorConfig -import core.rag.datasource.vdb.pgvector.pgvector as pgvector_module -from core.rag.datasource.vdb.pgvector.pgvector import ( - PGVector, - PGVectorConfig, -) from core.rag.models.document import Document @@ -26,7 +23,7 @@ class TestPGVector: ) self.collection_name = "test_collection" - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") def test_init(self, mock_pool_class): """Test PGVector initialization.""" mock_pool = MagicMock() @@ -41,7 +38,7 @@ class TestPGVector: assert pgvector.pg_bigm is False assert pgvector.index_hash is 
not None - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") def test_init_with_pg_bigm(self, mock_pool_class): """Test PGVector initialization with pg_bigm enabled.""" config = PGVectorConfig( @@ -61,8 +58,8 @@ class TestPGVector: assert pgvector.pg_bigm is True - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_basic(self, mock_redis, mock_pool_class): """Test basic collection creation.""" # Mock Redis operations @@ -104,8 +101,8 @@ class TestPGVector: # Verify Redis cache was set mock_redis.set.assert_called_once() - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_with_large_dimension(self, mock_redis, mock_pool_class): """Test collection creation with dimension > 2000 (no HNSW index).""" # Mock Redis operations @@ -139,8 +136,8 @@ class TestPGVector: hnsw_index_calls = [call for call in mock_cursor.execute.call_args_list if "hnsw" in str(call)] assert len(hnsw_index_calls) == 0 - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_with_pg_bigm(self, mock_redis, mock_pool_class): """Test collection creation with pg_bigm enabled.""" config = PGVectorConfig( @@ -180,8 +177,8 @@ class TestPGVector: bigm_index_calls = [call for call in mock_cursor.execute.call_args_list if "gin_bigm_ops" in str(call)] assert len(bigm_index_calls) == 1 - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_creates_vector_extension(self, mock_redis, mock_pool_class): """Test that vector extension is created if it doesn't exist.""" # Mock Redis operations @@ -213,8 +210,8 @@ class TestPGVector: ] assert len(create_extension_calls) == 1 - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_with_cache_hit(self, mock_redis, mock_pool_class): """Test that collection creation is skipped when cache exists.""" # Mock Redis operations - cache exists @@ -240,8 +237,8 @@ class TestPGVector: # Check that no SQL was executed (early return due to cache) assert mock_cursor.execute.call_count == 0 - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") - @patch("core.rag.datasource.vdb.pgvector.pgvector.redis_client") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + 
@patch("dify_vdb_pgvector.pgvector.redis_client") def test_create_collection_with_redis_lock(self, mock_redis, mock_pool_class): """Test that Redis lock is used during collection creation.""" # Mock Redis operations @@ -273,7 +270,7 @@ class TestPGVector: mock_lock.__enter__.assert_called_once() mock_lock.__exit__.assert_called_once() - @patch("core.rag.datasource.vdb.pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") + @patch("dify_vdb_pgvector.pgvector.psycopg2.pool.SimpleConnectionPool") def test_get_cursor_context_manager(self, mock_pool_class): """Test that _get_cursor properly manages connection lifecycle.""" mock_pool = MagicMock() diff --git a/api/providers/vdb/vdb-qdrant/pyproject.toml b/api/providers/vdb/vdb-qdrant/pyproject.toml new file mode 100644 index 0000000000..6dd0b9560b --- /dev/null +++ b/api/providers/vdb/vdb-qdrant/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-qdrant" +version = "0.0.1" + +dependencies = [ + "qdrant-client==1.9.0", +] +description = "Dify vector store backend (dify-vdb-qdrant)." + +[project.entry-points."dify.vector_backends"] +qdrant = "dify_vdb_qdrant.qdrant_vector:QdrantVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/vikingdb/__init__.py b/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/vikingdb/__init__.py rename to api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/__init__.py diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/qdrant/qdrant_vector.py rename to api/providers/vdb/vdb-qdrant/src/dify_vdb_qdrant/qdrant_vector.py diff --git a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py b/api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py similarity index 95% rename from api/tests/integration_tests/vdb/qdrant/test_qdrant.py rename to api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py index 709cc2e14e..e0badeb5de 100644 --- a/api/tests/integration_tests/vdb/qdrant/test_qdrant.py +++ b/api/providers/vdb/vdb-qdrant/tests/integration_tests/test_qdrant.py @@ -1,12 +1,11 @@ import uuid -from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector -from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_qdrant.qdrant_vector import QdrantConfig, QdrantVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) - -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +from core.rag.models.document import Document class QdrantVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py b/api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py rename to api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py index 0408506563..0ed5491fbe 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/qdrant/test_qdrant_vector.py +++ b/api/providers/vdb/vdb-qdrant/tests/unit_tests/test_qdrant_vector.py @@ -125,7 +125,7 @@ def qdrant_module(monkeypatch): for name, module in _build_fake_qdrant_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.qdrant.qdrant_vector 
as module + import dify_vdb_qdrant.qdrant_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-relyt/pyproject.toml b/api/providers/vdb/vdb-relyt/pyproject.toml new file mode 100644 index 0000000000..2a7c7fac87 --- /dev/null +++ b/api/providers/vdb/vdb-relyt/pyproject.toml @@ -0,0 +1,12 @@ +[project] +name = "dify-vdb-relyt" +version = "0.0.1" + +dependencies = [] +description = "Dify vector store backend (dify-vdb-relyt)." + +[project.entry-points."dify.vector_backends"] +relyt = "dify_vdb_relyt.relyt_vector:RelytVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/core/rag/datasource/vdb/weaviate/__init__.py b/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/__init__.py similarity index 100% rename from api/core/rag/datasource/vdb/weaviate/__init__.py rename to api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/__init__.py diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/relyt/relyt_vector.py rename to api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py index 64b45bf28b..336c2d3c8a 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/providers/vdb/vdb-relyt/src/dify_vdb_relyt/relyt_vector.py @@ -38,7 +38,7 @@ class RelytConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config RELYT_HOST is required") if not values["port"]: @@ -239,7 +239,7 @@ class RelytVector(BaseVector): self, embedding: list[float], k: int = 4, - filter: dict | None = None, + filter: dict[str, Any] | None = None, ) -> list[tuple[Document, float]]: # Add the filter if provided diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py b/api/providers/vdb/vdb-relyt/tests/unit_tests/test_relyt_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py rename to api/providers/vdb/vdb-relyt/tests/unit_tests/test_relyt_vector.py index 43cdb4948d..f97ad1400a 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/relyt/test_relyt_vector.py +++ b/api/providers/vdb/vdb-relyt/tests/unit_tests/test_relyt_vector.py @@ -63,7 +63,7 @@ def relyt_module(monkeypatch): for name, module in _build_fake_relyt_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.relyt.relyt_vector as module + import dify_vdb_relyt.relyt_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-tablestore/pyproject.toml b/api/providers/vdb/vdb-tablestore/pyproject.toml new file mode 100644 index 0000000000..fd1a2d54e0 --- /dev/null +++ b/api/providers/vdb/vdb-tablestore/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-tablestore" +version = "0.0.1" + +dependencies = [ + "tablestore==6.4.4", +] +description = "Dify vector store backend (dify-vdb-tablestore)." 
+ +[project.entry-points."dify.vector_backends"] +tablestore = "dify_vdb_tablestore.tablestore_vector:TableStoreVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/__mock/__init__.py b/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/__init__.py rename to api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/__init__.py diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/tablestore/tablestore_vector.py rename to api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py index 4a734232ec..f9deac11e5 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/providers/vdb/vdb-tablestore/src/dify_vdb_tablestore/tablestore_vector.py @@ -30,7 +30,7 @@ class TableStoreConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["access_key_id"]: raise ValueError("config ACCESS_KEY_ID is required") if not values["access_key_secret"]: diff --git a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py b/api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py similarity index 93% rename from api/tests/integration_tests/vdb/tablestore/test_tablestore.py rename to api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py index b60e26a881..97c9626ee1 100644 --- a/api/tests/integration_tests/vdb/tablestore/test_tablestore.py +++ b/api/providers/vdb/vdb-tablestore/tests/integration_tests/test_tablestore.py @@ -1,20 +1,21 @@ +import logging import os import uuid import tablestore from _pytest.python_api import approx - -from core.rag.datasource.vdb.tablestore.tablestore_vector import ( +from dify_vdb_tablestore.tablestore_vector import ( TableStoreConfig, TableStoreVector, ) -from tests.integration_tests.vdb.test_vector_store import ( + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, get_example_document, get_example_text, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) +logger = logging.getLogger(__name__) class TableStoreVectorTest(AbstractVectorTest): @@ -90,7 +91,7 @@ class TableStoreVectorTest(AbstractVectorTest): try: self.vector.delete() except Exception: - pass + logger.debug("Failed to delete vector store during test setup; it may not exist yet") return super().run_all_tests() diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py b/api/providers/vdb/vdb-tablestore/tests/unit_tests/test_tablestore_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py rename to api/providers/vdb/vdb-tablestore/tests/unit_tests/test_tablestore_vector.py index e3b6676d9b..62a11e0445 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tablestore/test_tablestore_vector.py +++ b/api/providers/vdb/vdb-tablestore/tests/unit_tests/test_tablestore_vector.py @@ -81,7 +81,7 @@ def tablestore_module(monkeypatch): fake_module = _build_fake_tablestore_module() monkeypatch.setitem(sys.modules, "tablestore", fake_module) - import core.rag.datasource.vdb.tablestore.tablestore_vector as module + import dify_vdb_tablestore.tablestore_vector as
module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-tencent/pyproject.toml b/api/providers/vdb/vdb-tencent/pyproject.toml new file mode 100644 index 0000000000..7bb761b169 --- /dev/null +++ b/api/providers/vdb/vdb-tencent/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-tencent" +version = "0.0.1" + +dependencies = [ + "tcvectordb~=2.1.0", +] +description = "Dify vector store backend (dify-vdb-tencent)." + +[project.entry-points."dify.vector_backends"] +tencent = "dify_vdb_tencent.tencent_vector:TencentVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/analyticdb/__init__.py b/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/analyticdb/__init__.py rename to api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/__init__.py diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/tencent/tencent_vector.py rename to api/providers/vdb/vdb-tencent/src/dify_vdb_tencent/tencent_vector.py diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/providers/vdb/vdb-tencent/tests/integration_tests/conftest.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/tcvectordb.py rename to api/providers/vdb/vdb-tencent/tests/integration_tests/conftest.py diff --git a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py b/api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py similarity index 76% rename from api/tests/integration_tests/vdb/tcvectordb/test_tencent.py rename to api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py index 3d6deff2a0..a53ec87f92 100644 --- a/api/tests/integration_tests/vdb/tcvectordb/test_tencent.py +++ b/api/providers/vdb/vdb-tencent/tests/integration_tests/test_tencent.py @@ -1,12 +1,8 @@ from unittest.mock import MagicMock -from core.rag.datasource.vdb.tencent.tencent_vector import TencentConfig, TencentVector -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_tencent.tencent_vector import TencentConfig, TencentVector -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.tcvectordb", -) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text mock_client = MagicMock() mock_client.list_databases.return_value = [{"name": "test"}] diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py b/api/providers/vdb/vdb-tencent/tests/unit_tests/test_tencent_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py rename to api/providers/vdb/vdb-tencent/tests/unit_tests/test_tencent_vector.py index d8f35a6019..299e40ee1e 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tencent/test_tencent_vector.py +++ b/api/providers/vdb/vdb-tencent/tests/unit_tests/test_tencent_vector.py @@ -140,7 +140,7 @@ def tencent_module(monkeypatch): for name, module in _build_fake_tencent_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.tencent.tencent_vector as module + import dify_vdb_tencent.tencent_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/pyproject.toml 
b/api/providers/vdb/vdb-tidb-on-qdrant/pyproject.toml
new file mode 100644
index 0000000000..5040fb38ba
--- /dev/null
+++ b/api/providers/vdb/vdb-tidb-on-qdrant/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-vdb-tidb-on-qdrant"
+version = "0.0.1"
+
+dependencies = [
+    "qdrant-client==1.9.0",
+]
+description = "Dify vector store backend (dify-vdb-tidb-on-qdrant)."
+
+[project.entry-points."dify.vector_backends"]
+tidb_on_qdrant = "dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector:TidbOnQdrantVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/tests/integration_tests/vdb/baidu/__init__.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/__init__.py
similarity index 100%
rename from api/tests/integration_tests/vdb/baidu/__init__.py
rename to api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/__init__.py
diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py
similarity index 88%
rename from api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
rename to api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py
index 605cc5a08f..abca55f540 100644
--- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
+++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_on_qdrant_vector.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 import uuid
 from collections.abc import Generator, Iterable, Sequence
@@ -7,6 +8,8 @@ from typing import TYPE_CHECKING, Any
 
 import httpx
 import qdrant_client
+
+logger = logging.getLogger(__name__)
 from flask import current_app
 from httpx import DigestAuth
 from pydantic import BaseModel
@@ -24,12 +27,12 @@ from sqlalchemy import select
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
-from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
 from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
+from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset, TidbAuthBinding
@@ -292,26 +295,30 @@ class TidbOnQdrantVector(BaseVector):
         if not ids:
             return
 
-        try:
-            filter = models.Filter(
-                must=[
-                    models.FieldCondition(
-                        key="metadata.doc_id",
-                        match=models.MatchAny(any=ids),
-                    ),
-                ],
-            )
-            self._client.delete(
-                collection_name=self._collection_name,
-                points_selector=FilterSelector(filter=filter),
-            )
-        except UnexpectedResponse as e:
-            # Collection does not exist, so return
-            if e.status_code == 404:
-                return
-            # Some other error occurred, so re-raise the exception
-            else:
-                raise e
+        # Delete in batches so each request payload stays bounded.
+        batch_size = 1000
+        for i in range(0, len(ids), batch_size):
+            batch = ids[i : i + batch_size]
+
+            try:
+                filter = models.Filter(
+                    must=[
+                        models.FieldCondition(
+                            key="metadata.doc_id",
+                            match=models.MatchAny(any=batch),
+                        ),
+                    ],
+                )
+                self._client.delete(
+                    collection_name=self._collection_name,
+                    points_selector=FilterSelector(filter=filter),
+                )
+            except UnexpectedResponse as e:
+                # Collection does not exist, so there is nothing left to delete
+                if e.status_code == 404:
+                    return
+                # Some other error occurred, so re-raise the exception
+                raise e
 
     def text_exists(self, id: str) -> 
bool: all_collection_name = [] @@ -420,13 +424,16 @@ class TidbOnQdrantVector(BaseVector): class TidbOnQdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: + logger.info("init_vector: tenant_id=%s, dataset_id=%s", dataset.tenant_id, dataset.id) stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) tidb_auth_binding = db.session.scalars(stmt).one_or_none() if not tidb_auth_binding: + logger.info("No existing TidbAuthBinding for tenant %s, acquiring lock", dataset.tenant_id) with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) tidb_auth_binding = db.session.scalars(stmt).one_or_none() if tidb_auth_binding: + logger.info("Found binding after lock: cluster_id=%s", tidb_auth_binding.cluster_id) TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" else: @@ -436,11 +443,18 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): .limit(1) ) if idle_tidb_auth_binding: + logger.info( + "Assigning idle cluster %s to tenant %s", + idle_tidb_auth_binding.cluster_id, + dataset.tenant_id, + ) idle_tidb_auth_binding.active = True idle_tidb_auth_binding.tenant_id = dataset.tenant_id db.session.commit() + tidb_auth_binding = idle_tidb_auth_binding TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" else: + logger.info("No idle clusters available, creating new cluster for tenant %s", dataset.tenant_id) new_cluster = TidbService.create_tidb_serverless_cluster( dify_config.TIDB_PROJECT_ID or "", dify_config.TIDB_API_URL or "", @@ -449,21 +463,39 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): dify_config.TIDB_PRIVATE_KEY or "", dify_config.TIDB_REGION or "", ) + logger.info( + "New cluster created: cluster_id=%s, qdrant_endpoint=%s", + new_cluster["cluster_id"], + new_cluster.get("qdrant_endpoint"), + ) new_tidb_auth_binding = TidbAuthBinding( cluster_id=new_cluster["cluster_id"], cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + qdrant_endpoint=new_cluster.get("qdrant_endpoint"), tenant_id=dataset.tenant_id, active=True, status=TidbAuthBindingStatus.ACTIVE, ) db.session.add(new_tidb_auth_binding) db.session.commit() + tidb_auth_binding = new_tidb_auth_binding TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" else: + logger.info("Existing binding found: cluster_id=%s", tidb_auth_binding.cluster_id) TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + qdrant_url = ( + (tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None) or dify_config.TIDB_ON_QDRANT_URL or "" + ) + logger.info( + "Using qdrant endpoint: %s (from_binding=%s, fallback_global=%s)", + qdrant_url, + tidb_auth_binding.qdrant_endpoint if tidb_auth_binding else None, + dify_config.TIDB_ON_QDRANT_URL, + ) + if dataset.index_struct_dict: class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix @@ -478,7 +510,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): collection_name=collection_name, group_id=dataset.id, config=TidbOnQdrantConfig( - endpoint=dify_config.TIDB_ON_QDRANT_URL or "", + endpoint=qdrant_url, api_key=TIDB_ON_QDRANT_API_KEY, root_path=str(config.root_path), timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, diff 
--git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py similarity index 70% rename from api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py rename to api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py index 37114be6e7..6283dbb986 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py @@ -1,3 +1,4 @@ +import logging import time import uuid from collections.abc import Sequence @@ -12,6 +13,8 @@ from extensions.ext_redis import redis_client from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus +logger = logging.getLogger(__name__) + # Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn _tidb_http_client: httpx.Client = get_pooled_http_client( "tidb:cloud", @@ -20,6 +23,46 @@ _tidb_http_client: httpx.Client = get_pooled_http_client( class TidbService: + @staticmethod + def extract_qdrant_endpoint(cluster_response: dict) -> str | None: + """Extract the qdrant endpoint URL from a Get Cluster API response. + + Reads ``endpoints.public.host`` (e.g. ``gateway01.xx.tidbcloud.com``), + prepends ``qdrant-`` and wraps it as an ``https://`` URL. + """ + endpoints = cluster_response.get("endpoints") or {} + public = endpoints.get("public") or {} + host = public.get("host") + if host: + return f"https://qdrant-{host}" + return None + + @staticmethod + def fetch_qdrant_endpoint(api_url: str, public_key: str, private_key: str, cluster_id: str) -> str | None: + """Call Get Cluster API and extract the qdrant endpoint. + + Use ``extract_qdrant_endpoint`` instead when you already have + the cluster response to avoid a redundant API call. 
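+
+        A minimal sketch of the mapping, assuming the response shape used in
+        this change's unit tests::
+
+            {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com"}}}
+            -> "https://qdrant-gateway01.us-east-1.tidbcloud.com"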
+ """ + try: + logger.info("Fetching qdrant endpoint for cluster %s", cluster_id) + cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) + if not cluster_response: + logger.warning("Empty response from Get Cluster API for cluster %s", cluster_id) + return None + qdrant_url = TidbService.extract_qdrant_endpoint(cluster_response) + if qdrant_url: + logger.info("Resolved qdrant endpoint for cluster %s: %s", cluster_id, qdrant_url) + return qdrant_url + logger.warning( + "No endpoints.public.host found for cluster %s, response keys: %s", + cluster_id, + list(cluster_response.keys()), + ) + except Exception: + logger.exception("Failed to fetch qdrant endpoint for cluster %s", cluster_id) + return None + @staticmethod def create_tidb_serverless_cluster( project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str @@ -57,6 +100,7 @@ class TidbService: "rootPassword": password, } + logger.info("Creating TiDB serverless cluster: display_name=%s, region=%s", display_name, region) response = _tidb_http_client.post( f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key) ) @@ -64,21 +108,39 @@ class TidbService: if response.status_code == 200: response_data = response.json() cluster_id = response_data["clusterId"] + logger.info("Cluster created, cluster_id=%s, waiting for ACTIVE state", cluster_id) retry_count = 0 max_retries = 30 while retry_count < max_retries: cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) if cluster_response["state"] == "ACTIVE": user_prefix = cluster_response["userPrefix"] + qdrant_endpoint = TidbService.extract_qdrant_endpoint(cluster_response) + logger.info( + "Cluster %s is ACTIVE, user_prefix=%s, qdrant_endpoint=%s", + cluster_id, + user_prefix, + qdrant_endpoint, + ) return { "cluster_id": cluster_id, "cluster_name": display_name, "account": f"{user_prefix}.root", "password": password, + "qdrant_endpoint": qdrant_endpoint, } - time.sleep(30) # wait 30 seconds before retrying + logger.info( + "Cluster %s state=%s, retry %d/%d", + cluster_id, + cluster_response["state"], + retry_count + 1, + max_retries, + ) + time.sleep(30) retry_count += 1 + logger.error("Cluster %s did not become ACTIVE after %d retries", cluster_id, max_retries) else: + logger.error("Failed to create cluster: status=%d, body=%s", response.status_code, response.text) response.raise_for_status() @staticmethod @@ -184,8 +246,18 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" + if not cluster_info.qdrant_endpoint: + cluster_info.qdrant_endpoint = TidbService.extract_qdrant_endpoint( + item + ) or TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"]) + if cluster_info.qdrant_endpoint: + cluster_info.status = TidbAuthBindingStatus.ACTIVE + else: + logger.warning( + "Cluster %s is ACTIVE but qdrant endpoint is not ready; will retry later", + item["clusterId"], + ) db.session.add(cluster_info) db.session.commit() else: @@ -243,19 +315,29 @@ class TidbService: if response.status_code == 200: response_data = response.json() cluster_infos = [] + logger.info("Batch created %d clusters", len(response_data.get("clusters", []))) for item in response_data["clusters"]: cache_key = 
f"tidb_serverless_cluster_password:{item['displayName']}" cached_password = redis_client.get(cache_key) if not cached_password: + logger.warning("No cached password for cluster %s, skipping", item["displayName"]) continue + qdrant_endpoint = TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"]) + logger.info( + "Batch cluster %s: qdrant_endpoint=%s", + item["clusterId"], + qdrant_endpoint, + ) cluster_info = { "cluster_id": item["clusterId"], "cluster_name": item["displayName"], "account": "root", "password": cached_password.decode("utf-8"), + "qdrant_endpoint": qdrant_endpoint, } cluster_infos.append(cluster_info) return cluster_infos else: + logger.error("Batch create failed: status=%d, body=%s", response.status_code, response.text) response.raise_for_status() return [] diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py similarity index 64% rename from api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py rename to api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py index c25af79ae4..76802de62e 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_on_qdrant_vector.py @@ -2,13 +2,12 @@ from unittest.mock import patch import httpx import pytest -from qdrant_client.http import models as rest -from qdrant_client.http.exceptions import UnexpectedResponse - -from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import ( +from dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector import ( TidbOnQdrantConfig, TidbOnQdrantVector, ) +from qdrant_client.http import models as rest +from qdrant_client.http.exceptions import UnexpectedResponse class TestTidbOnQdrantVectorDeleteByIds: @@ -22,7 +21,7 @@ class TestTidbOnQdrantVectorDeleteByIds: api_key="test_api_key", ) - with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"): + with patch("dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"): vector = TidbOnQdrantVector( collection_name="test_collection", group_id="test_group", @@ -115,14 +114,12 @@ class TestTidbOnQdrantVectorDeleteByIds: assert exc_info.value.status_code == 500 - def test_delete_by_ids_with_large_batch(self, vector_instance): - """Test deletion with a large batch of IDs.""" - # Create 1000 IDs + def test_delete_by_ids_with_exactly_1000(self, vector_instance): + """Test deletion with exactly 1000 IDs triggers a single batch.""" ids = [f"doc_{i}" for i in range(1000)] vector_instance.delete_by_ids(ids) - # Verify single delete call with all IDs vector_instance._client.delete.assert_called_once() call_args = vector_instance._client.delete.call_args @@ -130,11 +127,28 @@ class TestTidbOnQdrantVectorDeleteByIds: filter_obj = filter_selector.filter field_condition = filter_obj.must[0] - # Verify all 1000 IDs are in the batch assert len(field_condition.match.any) == 1000 assert "doc_0" in field_condition.match.any assert "doc_999" in field_condition.match.any + def test_delete_by_ids_splits_into_batches(self, vector_instance): + """Test deletion with >1000 IDs triggers multiple batched calls.""" + ids = [f"doc_{i}" for i in range(2500)] + + vector_instance.delete_by_ids(ids) + + assert vector_instance._client.delete.call_count == 3 + + batches = [] + for call in 
vector_instance._client.delete.call_args_list:
+            filter_selector = call[1]["points_selector"]
+            field_condition = filter_selector.filter.must[0]
+            batches.append(field_condition.match.any)
+
+        assert len(batches[0]) == 1000
+        assert len(batches[1]) == 1000
+        assert len(batches[2]) == 500
+
     def test_delete_by_ids_filter_structure(self, vector_instance):
         """Test that the filter structure is correctly constructed."""
         ids = ["doc1", "doc2"]
@@ -158,3 +172,57 @@ class TestTidbOnQdrantVectorDeleteByIds:
         # Verify MatchAny structure
         assert isinstance(field_condition.match, rest.MatchAny)
         assert field_condition.match.any == ids
+
+
+class TestInitVectorEndpointSelection:
+    """Test that init_vector selects the correct qdrant endpoint.
+
+    We avoid calling init_vector itself (which needs a Flask app context and
+    a database) by exercising its endpoint-selection logic directly on TidbOnQdrantConfig.
+    """
+
+    def test_uses_binding_endpoint_when_present(self):
+        binding_endpoint = "https://qdrant-custom.tidb.com"
+        global_url = "https://qdrant-global.tidb.com"
+
+        qdrant_url = binding_endpoint or global_url or ""
+
+        assert qdrant_url == "https://qdrant-custom.tidb.com"
+        config = TidbOnQdrantConfig(endpoint=qdrant_url)
+        assert config.endpoint == "https://qdrant-custom.tidb.com"
+
+    def test_falls_back_to_global_when_binding_endpoint_is_none(self):
+        binding_endpoint = None
+        global_url = "https://qdrant-global.tidb.com"
+
+        qdrant_url = binding_endpoint or global_url or ""
+
+        assert qdrant_url == "https://qdrant-global.tidb.com"
+        config = TidbOnQdrantConfig(endpoint=qdrant_url)
+        assert config.endpoint == "https://qdrant-global.tidb.com"
+
+    def test_falls_back_to_empty_when_both_none(self):
+        binding_endpoint = None
+        global_url = None
+
+        qdrant_url = binding_endpoint or global_url or ""
+
+        assert qdrant_url == ""
+        config = TidbOnQdrantConfig(endpoint=qdrant_url)
+        assert config.endpoint == ""
+
+    def test_binding_endpoint_takes_precedence_over_global(self):
+        binding_endpoint = "https://qdrant-ap-southeast.tidb.com"
+        global_url = "https://qdrant-us-east.tidb.com"
+
+        qdrant_url = binding_endpoint or global_url or ""
+
+        assert qdrant_url == "https://qdrant-ap-southeast.tidb.com"
+
+    def test_empty_string_binding_endpoint_falls_back_to_global(self):
+        binding_endpoint = ""
+        global_url = "https://qdrant-global.tidb.com"
+
+        qdrant_url = binding_endpoint or global_url or ""
+
+        assert qdrant_url == "https://qdrant-global.tidb.com"
diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py
new file mode 100644
index 0000000000..20a42f6cc3
--- /dev/null
+++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py
@@ -0,0 +1,304 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+from dify_vdb_tidb_on_qdrant.tidb_service import TidbService
+
+from models.enums import TidbAuthBindingStatus
+
+
+class TestExtractQdrantEndpoint:
+    """Unit tests for TidbService.extract_qdrant_endpoint."""
+
+    def test_returns_endpoint_when_host_present(self):
+        response = {"endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}}}
+        result = TidbService.extract_qdrant_endpoint(response)
+        assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com"
+
+    def test_returns_none_when_host_missing(self):
+        response = {"endpoints": {"public": {}}}
+        assert TidbService.extract_qdrant_endpoint(response) is None
+
+    def 
test_returns_none_when_public_missing(self): + response = {"endpoints": {}} + assert TidbService.extract_qdrant_endpoint(response) is None + + def test_returns_none_when_endpoints_missing(self): + assert TidbService.extract_qdrant_endpoint({}) is None + + +class TestFetchQdrantEndpoint: + """Unit tests for TidbService.fetch_qdrant_endpoint.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_endpoint_when_host_present(self, mock_get_cluster): + mock_get_cluster.return_value = { + "endpoints": {"public": {"host": "gateway01.us-east-1.tidbcloud.com", "port": 4000}} + } + result = TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") + assert result == "https://qdrant-gateway01.us-east-1.tidbcloud.com" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_cluster_response_is_none(self, mock_get_cluster): + mock_get_cluster.return_value = None + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_host_missing(self, mock_get_cluster): + mock_get_cluster.return_value = {"endpoints": {"public": {}}} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_when_endpoints_missing(self, mock_get_cluster): + mock_get_cluster.return_value = {} + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + @patch.object(TidbService, "get_tidb_serverless_cluster") + def test_returns_none_on_exception(self, mock_get_cluster): + mock_get_cluster.side_effect = RuntimeError("network error") + assert TidbService.fetch_qdrant_endpoint("url", "pub", "priv", "c-123") is None + + +class TestCreateTidbServerlessClusterQdrantEndpoint: + """Verify that create_tidb_serverless_cluster includes qdrant_endpoint in its result.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = { + "state": "ACTIVE", + "userPrefix": "pfx", + "endpoints": {"public": {"host": "gw.tidbcloud.com", "port": 4000}}, + } + + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_result_qdrant_endpoint_none_when_no_endpoints(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "ACTIVE", "userPrefix": "pfx"} + + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] is None + + +class TestBatchCreateTidbServerlessClusterQdrantEndpoint: + """Verify that batch_create includes qdrant_endpoint per cluster.""" + + 
@patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_batch_result_contains_qdrant_endpoint(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + cluster_name = "abc123" + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": cluster_name}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = b"password123" + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 1 + assert result[0]["qdrant_endpoint"] == "https://qdrant-gw.tidbcloud.com" + + +class TestCreateTidbServerlessClusterRetry: + """Cover retry/logging paths in create_tidb_serverless_cluster.""" + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_polls_until_active(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.side_effect = [ + {"state": "CREATING", "userPrefix": ""}, + {"state": "ACTIVE", "userPrefix": "pfx", "endpoints": {"public": {"host": "gw.tidb.com"}}}, + ] + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is not None + assert result["qdrant_endpoint"] == "https://qdrant-gw.tidb.com" + assert mock_get_cluster.call_count == 2 + + @patch.object(TidbService, "get_tidb_serverless_cluster") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_returns_none_after_max_retries(self, mock_config, mock_http, mock_get_cluster): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock(status_code=200, json=lambda: {"clusterId": "c-1"}) + mock_get_cluster.return_value = {"state": "CREATING", "userPrefix": ""} + + with patch("dify_vdb_tidb_on_qdrant.tidb_service.time.sleep"): + result = TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + assert result is None + + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=400, text="Bad Request") + mock_response.raise_for_status.side_effect = Exception("HTTP 400") + mock_http.post.return_value = mock_response + + with pytest.raises(Exception, match="HTTP 400"): + TidbService.create_tidb_serverless_cluster("proj", "url", "iam", "pub", "priv", "us-east-1") + + +class TestBatchCreateEdgeCases: + """Cover logging/edge-case branches in batch_create.""" + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + 
@patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_skips_cluster_when_no_cached_password(self, mock_config, mock_http, mock_redis, mock_fetch_ep): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_http.post.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "displayName": "name1"}]}, + ) + mock_redis.setex = MagicMock() + mock_redis.get.return_value = None + + result = TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + assert len(result) == 0 + mock_fetch_ep.assert_not_called() + + @patch("dify_vdb_tidb_on_qdrant.tidb_service.redis_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.dify_config") + def test_raises_on_post_failure(self, mock_config, mock_http, mock_redis): + mock_config.TIDB_SPEND_LIMIT = 10 + mock_response = MagicMock(status_code=500, text="Server Error") + mock_response.raise_for_status.side_effect = Exception("HTTP 500") + mock_http.post.return_value = mock_response + mock_redis.setex = MagicMock() + + with pytest.raises(Exception, match="HTTP 500"): + TidbService.batch_create_tidb_serverless_cluster( + batch_size=1, + project_id="proj", + api_url="url", + iam_url="iam", + public_key="pub", + private_key="priv", + region="us-east-1", + ) + + +class TestBatchUpdateTidbServerlessClusterStatus: + """Verify that status updates only expose clusters after qdrant endpoint is ready.""" + + @patch("dify_vdb_tidb_on_qdrant.tidb_service.db") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + def test_sets_active_when_batch_response_contains_endpoint(self, mock_http, mock_db): + binding = SimpleNamespace( + cluster_id="c-1", + status=TidbAuthBindingStatus.CREATING, + account="root", + qdrant_endpoint=None, + ) + mock_http.get.return_value = MagicMock( + status_code=200, + json=lambda: { + "clusters": [ + { + "clusterId": "c-1", + "state": "ACTIVE", + "userPrefix": "pfx", + "endpoints": {"public": {"host": "gw.tidbcloud.com"}}, + } + ] + }, + ) + + TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv") + + assert binding.account == "pfx.root" + assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com" + assert binding.status == TidbAuthBindingStatus.ACTIVE + mock_db.session.add.assert_called_once_with(binding) + mock_db.session.commit.assert_called_once() + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.db") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + def test_fetches_endpoint_when_batch_response_omits_it(self, mock_http, mock_db, mock_fetch_endpoint): + binding = SimpleNamespace( + cluster_id="c-1", + status=TidbAuthBindingStatus.CREATING, + account="root", + qdrant_endpoint=None, + ) + mock_http.get.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]}, + ) + + TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv") + + assert binding.account == "pfx.root" + assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com" + assert binding.status == TidbAuthBindingStatus.ACTIVE + mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1") + 
mock_db.session.add.assert_called_once_with(binding) + mock_db.session.commit.assert_called_once() + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) + @patch("dify_vdb_tidb_on_qdrant.tidb_service.db") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + def test_keeps_creating_when_endpoint_is_not_ready(self, mock_http, mock_db, mock_fetch_endpoint): + binding = SimpleNamespace( + cluster_id="c-1", + status=TidbAuthBindingStatus.CREATING, + account="root", + qdrant_endpoint=None, + ) + mock_http.get.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]}, + ) + + TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv") + + assert binding.account == "pfx.root" + assert binding.qdrant_endpoint is None + assert binding.status == TidbAuthBindingStatus.CREATING + mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1") + mock_db.session.add.assert_called_once_with(binding) + mock_db.session.commit.assert_called_once() diff --git a/api/providers/vdb/vdb-tidb-vector/pyproject.toml b/api/providers/vdb/vdb-tidb-vector/pyproject.toml new file mode 100644 index 0000000000..0e2f0ad88f --- /dev/null +++ b/api/providers/vdb/vdb-tidb-vector/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-tidb-vector" +version = "0.0.1" + +dependencies = [ + "tidb-vector==0.0.15", +] +description = "Dify vector store backend (dify-vdb-tidb-vector)." + +[project.entry-points."dify.vector_backends"] +tidb_vector = "dify_vdb_tidb_vector.tidb_vector:TiDBVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/chroma/__init__.py b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/chroma/__init__.py rename to api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/__init__.py diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py similarity index 99% rename from api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py rename to api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py index e321681093..c696a685dd 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/providers/vdb/vdb-tidb-vector/src/dify_vdb_tidb_vector/tidb_vector.py @@ -31,7 +31,7 @@ class TiDBVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") if not values["port"]: diff --git a/api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/check_tiflash_ready.py similarity index 72% rename from api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py rename to api/providers/vdb/vdb-tidb-vector/tests/integration_tests/check_tiflash_ready.py index f76700aa0e..97f8406e42 100644 --- a/api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py +++ b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/check_tiflash_ready.py @@ -1,9 +1,13 @@ +import logging import time import pymysql +logger = logging.getLogger(__name__) + def check_tiflash_ready() -> bool: + connection = None try: connection = pymysql.connect( 
host="localhost",
@@ -23,8 +27,8 @@ def check_tiflash_ready() -> bool:
             cursor.execute(select_tiflash_query)
             result = cursor.fetchall()
             return result is not None and len(result) > 0
-    except Exception as e:
-        print(f"TiFlash is not ready. Exception: {e}")
+    except Exception:
+        logger.exception("TiFlash is not ready.")
         return False
     finally:
         if connection:
@@ -38,20 +42,20 @@ def main():
     for attempt in range(max_attempts):
         try:
            is_tiflash_ready = check_tiflash_ready()
-        except Exception as e:
-            print(f"TiFlash is not ready. Exception: {e}")
+        except Exception:
+            logger.exception("TiFlash is not ready.")
             is_tiflash_ready = False
 
         if is_tiflash_ready:
             break
         else:
-            print(f"Attempt {attempt + 1} failed, retry in {retry_interval_seconds} seconds...")
+            logger.error("Attempt %s failed, retrying in %s seconds...", attempt + 1, retry_interval_seconds)
             time.sleep(retry_interval_seconds)
 
     if is_tiflash_ready:
-        print("TiFlash is ready in TiDB.")
+        logger.info("TiFlash is ready in TiDB.")
     else:
-        print(f"TiFlash is not ready in TiDB after {max_attempts} attempting checks.")
+        logger.error("TiFlash is not ready in TiDB after %s attempts.", max_attempts)
         exit(1)
diff --git a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py
similarity index 77%
rename from api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py
rename to api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py
index 14c6d1c67c..ac854acbf9 100644
--- a/api/tests/integration_tests/vdb/tidb_vector/test_tidb_vector.py
+++ b/api/providers/vdb/vdb-tidb-vector/tests/integration_tests/test_tidb_vector.py
@@ -1,10 +1,8 @@
 import pytest
+from dify_vdb_tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
 
-from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVector, TiDBVectorConfig
-from models.dataset import Document
-from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
-
-pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
+from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
+from core.rag.models.document import Document
 
 
 @pytest.fixture
diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py b/api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py
similarity index 99%
rename from api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py
rename to api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py
index 8e19a59af8..bdbed2f740 100644
--- a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_vector/test_tidb_vector.py
+++ b/api/providers/vdb/vdb-tidb-vector/tests/unit_tests/test_tidb_vector.py
@@ -12,7 +12,7 @@ from core.rag.models.document import Document
 
 @pytest.fixture
 def tidb_module():
-    import core.rag.datasource.vdb.tidb_vector.tidb_vector as module
+    import dify_vdb_tidb_vector.tidb_vector as module
 
     return importlib.reload(module)
diff --git a/api/providers/vdb/vdb-upstash/pyproject.toml b/api/providers/vdb/vdb-upstash/pyproject.toml
new file mode 100644
index 0000000000..f71773cdbb
--- /dev/null
+++ b/api/providers/vdb/vdb-upstash/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-vdb-upstash"
+version = "0.0.1"
+
+dependencies = [
+    "upstash-vector==0.8.0",
+]
+description = "Dify vector store backend (dify-vdb-upstash)."
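+
+# Illustrative only: the core can resolve this backend from the entry point
+# declared below, e.g.
+#   importlib.metadata.entry_points(group="dify.vector_backends")["upstash"].load()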
+ +[project.entry-points."dify.vector_backends"] +upstash = "dify_vdb_upstash.upstash_vector:UpstashVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/couchbase/__init__.py b/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/couchbase/__init__.py rename to api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/__init__.py diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py similarity index 98% rename from api/core/rag/datasource/vdb/upstash/upstash_vector.py rename to api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py index 289d971853..75d70a1964 100644 --- a/api/core/rag/datasource/vdb/upstash/upstash_vector.py +++ b/api/providers/vdb/vdb-upstash/src/dify_vdb_upstash/upstash_vector.py @@ -20,7 +20,7 @@ class UpstashVectorConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict): + def validate_config(cls, values: dict[str, Any]): if not values["url"]: raise ValueError("Upstash URL is required") if not values["token"]: diff --git a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py b/api/providers/vdb/vdb-upstash/tests/integration_tests/conftest.py similarity index 94% rename from api/tests/integration_tests/vdb/__mock/upstashvectordb.py rename to api/providers/vdb/vdb-upstash/tests/integration_tests/conftest.py index 70c85d4c98..adba0c150c 100644 --- a/api/tests/integration_tests/vdb/__mock/upstashvectordb.py +++ b/api/providers/vdb/vdb-upstash/tests/integration_tests/conftest.py @@ -6,7 +6,6 @@ from _pytest.monkeypatch import MonkeyPatch from upstash_vector import Index -# Mocking the Index class from upstash_vector class MockIndex: def __init__(self, url="", token=""): self.url = url @@ -37,7 +36,6 @@ class MockIndex: namespace: str = "", include_data: bool = False, ): - # Simple mock query, in real scenario you would calculate similarity mock_result = [] for vector_data in self.vectors: mock_result.append(vector_data) diff --git a/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py b/api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py similarity index 75% rename from api/tests/integration_tests/vdb/upstash/test_upstash_vector.py rename to api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py index 8cea0a05eb..f4a65030b6 100644 --- a/api/tests/integration_tests/vdb/upstash/test_upstash_vector.py +++ b/api/providers/vdb/vdb-upstash/tests/integration_tests/test_upstash_vector.py @@ -1,8 +1,7 @@ -from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVector, UpstashVectorConfig -from core.rag.models.document import Document -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_upstash.upstash_vector import UpstashVector, UpstashVectorConfig -pytest_plugins = ("tests.integration_tests.vdb.__mock.upstashvectordb",) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text +from core.rag.models.document import Document class UpstashVectorTest(AbstractVectorTest): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py b/api/providers/vdb/vdb-upstash/tests/unit_tests/test_upstash_vector.py similarity index 97% rename from api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py 
rename to api/providers/vdb/vdb-upstash/tests/unit_tests/test_upstash_vector.py
index ac8a63a44b..a884275c89 100644
--- a/api/tests/unit_tests/core/rag/datasource/vdb/upstash/test_upstash_vector.py
+++ b/api/providers/vdb/vdb-upstash/tests/unit_tests/test_upstash_vector.py
@@ -38,11 +38,11 @@ def _build_fake_upstash_module():
 @pytest.fixture
 def upstash_module(monkeypatch):
     # Remove patched modules if present
-    for modname in ["upstash_vector", "core.rag.datasource.vdb.upstash.upstash_vector"]:
+    for modname in ["upstash_vector", "dify_vdb_upstash.upstash_vector"]:
         if modname in sys.modules:
             monkeypatch.delitem(sys.modules, modname, raising=False)
     monkeypatch.setitem(sys.modules, "upstash_vector", _build_fake_upstash_module())
-    module = importlib.import_module("core.rag.datasource.vdb.upstash.upstash_vector")
+    module = importlib.import_module("dify_vdb_upstash.upstash_vector")
 
     return module
diff --git a/api/providers/vdb/vdb-vastbase/pyproject.toml b/api/providers/vdb/vdb-vastbase/pyproject.toml
new file mode 100644
index 0000000000..287eb147dc
--- /dev/null
+++ b/api/providers/vdb/vdb-vastbase/pyproject.toml
@@ -0,0 +1,14 @@
+[project]
+name = "dify-vdb-vastbase"
+version = "0.0.1"
+
+dependencies = [
+    "psycopg2-binary~=2.9.6",
+]
+description = "Dify vector store backend (dify-vdb-vastbase)."
+
+[project.entry-points."dify.vector_backends"]
+vastbase = "dify_vdb_vastbase.vastbase_vector:VastbaseVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/api/tests/integration_tests/vdb/elasticsearch/__init__.py b/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/__init__.py
similarity index 100%
rename from api/tests/integration_tests/vdb/elasticsearch/__init__.py
rename to api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/__init__.py
diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py
similarity index 99%
rename from api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
rename to api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py
index d080e8da58..ab00f9db28 100644
--- a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
+++ b/api/providers/vdb/vdb-vastbase/src/dify_vdb_vastbase/vastbase_vector.py
@@ -28,7 +28,7 @@ class VastbaseVectorConfig(BaseModel):
 
     @model_validator(mode="before")
     @classmethod
-    def validate_config(cls, values: dict):
+    def validate_config(cls, values: dict[str, Any]):
         if not values["host"]:
             raise ValueError("config VASTBASE_HOST is required")
         if not values["port"]:
diff --git a/api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py b/api/providers/vdb/vdb-vastbase/tests/integration_tests/test_vastbase_vector.py
similarity index 72%
rename from api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py
rename to api/providers/vdb/vdb-vastbase/tests/integration_tests/test_vastbase_vector.py
index a47f13625c..0467dec37a 100644
--- a/api/tests/integration_tests/vdb/pyvastbase/test_vastbase_vector.py
+++ b/api/providers/vdb/vdb-vastbase/tests/integration_tests/test_vastbase_vector.py
@@ -1,10 +1,9 @@
-from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig
-from tests.integration_tests.vdb.test_vector_store import (
+from dify_vdb_vastbase.vastbase_vector import VastbaseVector, VastbaseVectorConfig
+
+from core.rag.datasource.vdb.vector_integration_test_support import (
     AbstractVectorTest,
 )
-
-pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
-
class VastbaseVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py b/api/providers/vdb/vdb-vastbase/tests/unit_tests/test_vastbase_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py rename to api/providers/vdb/vdb-vastbase/tests/unit_tests/test_vastbase_vector.py index bd8df520ba..4dfb956c00 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/pyvastbase/test_vastbase_vector.py +++ b/api/providers/vdb/vdb-vastbase/tests/unit_tests/test_vastbase_vector.py @@ -41,7 +41,7 @@ def vastbase_module(monkeypatch): for name, module in _build_fake_psycopg2_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.pyvastbase.vastbase_vector as module + import dify_vdb_vastbase.vastbase_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-vikingdb/pyproject.toml b/api/providers/vdb/vdb-vikingdb/pyproject.toml new file mode 100644 index 0000000000..fdf59f76a4 --- /dev/null +++ b/api/providers/vdb/vdb-vikingdb/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-vikingdb" +version = "0.0.1" + +dependencies = [ + "volcengine-compat~=1.0.0", +] +description = "Dify vector store backend (dify-vdb-vikingdb)." + +[project.entry-points."dify.vector_backends"] +vikingdb = "dify_vdb_vikingdb.vikingdb_vector:VikingDBVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/hologres/__init__.py b/api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/hologres/__init__.py rename to api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/__init__.py diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py similarity index 100% rename from api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py rename to api/providers/vdb/vdb-vikingdb/src/dify_vdb_vikingdb/vikingdb_vector.py diff --git a/api/tests/integration_tests/vdb/__mock/vikingdb.py b/api/providers/vdb/vdb-vikingdb/tests/integration_tests/conftest.py similarity index 100% rename from api/tests/integration_tests/vdb/__mock/vikingdb.py rename to api/providers/vdb/vdb-vikingdb/tests/integration_tests/conftest.py diff --git a/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py b/api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py similarity index 78% rename from api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py rename to api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py index 56311acd25..5a3908d14b 100644 --- a/api/tests/integration_tests/vdb/vikingdb/test_vikingdb.py +++ b/api/providers/vdb/vdb-vikingdb/tests/integration_tests/test_vikingdb.py @@ -1,10 +1,6 @@ -from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector -from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text +from dify_vdb_vikingdb.vikingdb_vector import VikingDBConfig, VikingDBVector -pytest_plugins = ( - "tests.integration_tests.vdb.test_vector_store", - "tests.integration_tests.vdb.__mock.vikingdb", -) +from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text class VikingDBVectorTest(AbstractVectorTest): diff --git 
a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py b/api/providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py similarity index 99% rename from api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py rename to api/providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py index 9da92af2d0..544b8163be 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/vikingdb/test_vikingdb_vector.py +++ b/api/providers/vdb/vdb-vikingdb/tests/unit_tests/test_vikingdb_vector.py @@ -83,7 +83,7 @@ def vikingdb_module(monkeypatch): for name, module in _build_fake_vikingdb_modules().items(): monkeypatch.setitem(sys.modules, name, module) - import core.rag.datasource.vdb.vikingdb.vikingdb_vector as module + import dify_vdb_vikingdb.vikingdb_vector as module return importlib.reload(module) diff --git a/api/providers/vdb/vdb-weaviate/pyproject.toml b/api/providers/vdb/vdb-weaviate/pyproject.toml new file mode 100644 index 0000000000..035fbd396d --- /dev/null +++ b/api/providers/vdb/vdb-weaviate/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "dify-vdb-weaviate" +version = "0.0.1" + +dependencies = [ + "weaviate-client==4.20.5", +] +description = "Dify vector store backend (dify-vdb-weaviate)." + +[project.entry-points."dify.vector_backends"] +weaviate = "dify_vdb_weaviate.weaviate_vector:WeaviateVectorFactory" + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/api/tests/integration_tests/vdb/huawei/__init__.py b/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/huawei/__init__.py rename to api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/__init__.py diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py similarity index 90% rename from api/core/rag/datasource/vdb/weaviate/weaviate_vector.py rename to api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py index 25b65b82a9..902e6a03a8 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/providers/vdb/vdb-weaviate/src/dify_vdb_weaviate/weaviate_vector.py @@ -20,7 +20,7 @@ from pydantic import BaseModel, model_validator from weaviate.classes.data import DataObject from weaviate.classes.init import Auth from weaviate.classes.query import Filter, MetadataQuery -from weaviate.exceptions import UnexpectedStatusCodeError +from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateQueryError from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -82,7 +82,7 @@ class WeaviateConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_config(cls, values: dict) -> dict: + def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]: """Validates that required configuration values are present.""" if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") @@ -230,6 +230,8 @@ class WeaviateVector(BaseVector): wc.Property(name="doc_id", data_type=wc.DataType.TEXT), wc.Property(name="doc_type", data_type=wc.DataType.TEXT), wc.Property(name="chunk_index", data_type=wc.DataType.INT), + wc.Property(name="is_summary", data_type=wc.DataType.BOOL), + wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT), ], vector_config=wc.Configure.Vectors.self_provided(), ) @@ -262,6 +264,10 @@ class WeaviateVector(BaseVector): to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT)) 
if "chunk_index" not in existing: to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT)) + if "is_summary" not in existing: + to_add.append(wc.Property(name="is_summary", data_type=wc.DataType.BOOL)) + if "original_chunk_id" not in existing: + to_add.append(wc.Property(name="original_chunk_id", data_type=wc.DataType.TEXT)) for prop in to_add: try: @@ -400,15 +406,27 @@ class WeaviateVector(BaseVector): top_k = int(kwargs.get("top_k", 4)) score_threshold = float(kwargs.get("score_threshold") or 0.0) - res = col.query.near_vector( - near_vector=query_vector, - limit=top_k, - return_properties=props, - return_metadata=MetadataQuery(distance=True), - include_vector=False, - filters=where, - target_vector="default", - ) + try: + res = col.query.near_vector( + near_vector=query_vector, + limit=top_k, + return_properties=props, + return_metadata=MetadataQuery(distance=True), + include_vector=False, + filters=where, + target_vector="default", + ) + except WeaviateQueryError: + self._ensure_properties() + res = col.query.near_vector( + near_vector=query_vector, + limit=top_k, + return_properties=props, + return_metadata=MetadataQuery(distance=True), + include_vector=False, + filters=where, + target_vector="default", + ) docs: list[Document] = [] for obj in res.objects: @@ -446,14 +464,25 @@ class WeaviateVector(BaseVector): top_k = int(kwargs.get("top_k", 4)) - res = col.query.bm25( - query=query, - query_properties=[Field.TEXT_KEY.value], - limit=top_k, - return_properties=props, - include_vector=True, - filters=where, - ) + try: + res = col.query.bm25( + query=query, + query_properties=[Field.TEXT_KEY.value], + limit=top_k, + return_properties=props, + include_vector=True, + filters=where, + ) + except WeaviateQueryError: + self._ensure_properties() + res = col.query.bm25( + query=query, + query_properties=[Field.TEXT_KEY.value], + limit=top_k, + return_properties=props, + include_vector=True, + filters=where, + ) docs: list[Document] = [] for obj in res.objects: diff --git a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py b/api/providers/vdb/vdb-weaviate/tests/integration_tests/test_weaviate.py similarity index 72% rename from api/tests/integration_tests/vdb/weaviate/test_weaviate.py rename to api/providers/vdb/vdb-weaviate/tests/integration_tests/test_weaviate.py index a1d9850979..631d23d653 100644 --- a/api/tests/integration_tests/vdb/weaviate/test_weaviate.py +++ b/api/providers/vdb/vdb-weaviate/tests/integration_tests/test_weaviate.py @@ -1,10 +1,9 @@ -from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector -from tests.integration_tests.vdb.test_vector_store import ( +from dify_vdb_weaviate.weaviate_vector import WeaviateConfig, WeaviateVector + +from core.rag.datasource.vdb.vector_integration_test_support import ( AbstractVectorTest, ) -pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",) - class WeaviateVectorTest(AbstractVectorTest): def __init__(self): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weavaite.py similarity index 92% rename from api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py rename to api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weavaite.py index baf8c9e5f8..c773e4d552 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py +++ b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weavaite.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock, 
patch -from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector +from dify_vdb_weaviate.weaviate_vector import WeaviateConfig, WeaviateVector def test_init_client_with_valid_config(): diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py similarity index 84% rename from api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py rename to api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py index 69d1833001..b40f7e52ca 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py +++ b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py @@ -14,9 +14,9 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dify_vdb_weaviate import weaviate_vector as weaviate_vector_module +from dify_vdb_weaviate.weaviate_vector import WeaviateConfig, WeaviateVector -from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module -from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector from core.rag.models.document import Document @@ -40,7 +40,7 @@ class TestWeaviateVector(unittest.TestCase): with pytest.raises(ValueError, match="config WEAVIATE_ENDPOINT is required"): WeaviateConfig(endpoint="") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def _create_weaviate_vector(self, mock_weaviate_module): """Helper to create a WeaviateVector instance with mocked client.""" mock_client = MagicMock() @@ -66,7 +66,7 @@ class TestWeaviateVector(unittest.TestCase): mock_client.close.assert_called_once() mock_debug.assert_called_once() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate.connect_to_custom") def test_init_client_reuses_cached_client_without_reconnect(self, mock_connect): cached_client = MagicMock() cached_client.is_ready.return_value = True @@ -79,7 +79,7 @@ class TestWeaviateVector(unittest.TestCase): assert client is cached_client mock_connect.assert_not_called() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate.connect_to_custom") def test_init_client_reuses_cached_client_after_lock_recheck(self, mock_connect): cached_client = MagicMock() cached_client.is_ready.side_effect = [False, True] @@ -92,8 +92,8 @@ class TestWeaviateVector(unittest.TestCase): assert client is cached_client mock_connect.assert_not_called() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + @patch("dify_vdb_weaviate.weaviate_vector.Auth.api_key", return_value="auth-token") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate.connect_to_custom") def test_init_client_parses_custom_grpc_endpoint_without_scheme(self, mock_connect, mock_api_key): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -122,7 +122,7 @@ class TestWeaviateVector(unittest.TestCase): } mock_api_key.assert_called_once_with("test-key") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate.connect_to_custom") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate.connect_to_custom") def 
test_init_client_raises_when_database_not_ready(self, mock_connect): mock_client = MagicMock() mock_client.is_ready.return_value = False @@ -133,7 +133,7 @@ class TestWeaviateVector(unittest.TestCase): with pytest.raises(ConnectionError, match="Vector database is not ready"): wv._init_client(self.config) - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_init(self, mock_weaviate_module): """Test WeaviateVector initialization stores attributes including doc_type.""" mock_client = MagicMock() @@ -183,9 +183,9 @@ class TestWeaviateVector(unittest.TestCase): wv._create_collection.assert_called_once() wv.add_texts.assert_called_once_with([doc], [[0.1, 0.2]]) - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.redis_client") + @patch("dify_vdb_weaviate.weaviate_vector.dify_config") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_create_collection_includes_doc_type_property(self, mock_weaviate_module, mock_dify_config, mock_redis): """Test that _create_collection defines doc_type in the schema properties.""" # Mock Redis @@ -232,7 +232,7 @@ class TestWeaviateVector(unittest.TestCase): f"doc_type should be in collection schema properties, got: {property_names}" ) - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + @patch("dify_vdb_weaviate.weaviate_vector.redis_client") def test_create_collection_returns_early_when_cache_key_exists(self, mock_redis): mock_lock = MagicMock() mock_lock.__enter__ = MagicMock() @@ -251,7 +251,7 @@ class TestWeaviateVector(unittest.TestCase): wv._ensure_properties.assert_not_called() mock_redis.set.assert_not_called() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + @patch("dify_vdb_weaviate.weaviate_vector.redis_client") def test_create_collection_logs_and_reraises_errors(self, mock_redis): mock_lock = MagicMock() mock_lock.__enter__ = MagicMock() @@ -270,7 +270,7 @@ class TestWeaviateVector(unittest.TestCase): mock_exception.assert_called_once() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties adds doc_type when it's missing from existing schema.""" mock_client = MagicMock() @@ -305,7 +305,7 @@ class TestWeaviateVector(unittest.TestCase): added_names = [call.args[0].name for call in add_calls] assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_ensure_properties_adds_all_missing_core_properties(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -326,9 +326,9 @@ class TestWeaviateVector(unittest.TestCase): add_calls = mock_col.config.add_property.call_args_list added_names = [call.args[0].name for call in add_calls] - assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index"] + assert added_names == ["document_id", "doc_id", "doc_type", "chunk_index", "is_summary", "original_chunk_id"] - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + 
@patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): """Test that _ensure_properties does not add doc_type when it already exists.""" mock_client = MagicMock() @@ -346,6 +346,8 @@ class TestWeaviateVector(unittest.TestCase): SimpleNamespace(name="doc_id"), SimpleNamespace(name="doc_type"), SimpleNamespace(name="chunk_index"), + SimpleNamespace(name="is_summary"), + SimpleNamespace(name="original_chunk_id"), ] mock_cfg = MagicMock() mock_cfg.properties = existing_props @@ -361,7 +363,7 @@ class TestWeaviateVector(unittest.TestCase): # No properties should be added mock_col.config.add_property.assert_not_called() - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_ensure_properties_logs_warning_when_property_addition_fails(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -383,9 +385,9 @@ class TestWeaviateVector(unittest.TestCase): with patch.object(weaviate_vector_module.logger, "warning") as mock_warning: wv._ensure_properties() - assert mock_warning.call_count == 4 + assert mock_warning.call_count == 6 - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_vector returns doc_type in document metadata. @@ -432,7 +434,7 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_vector_uses_document_filter_and_default_distance(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -469,7 +471,7 @@ class TestWeaviateVector(unittest.TestCase): assert docs[0].metadata["score"] == 0.0 assert mock_col.query.near_vector.call_args.kwargs["filters"] is not None - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_vector_returns_empty_when_collection_is_missing(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -484,7 +486,57 @@ class TestWeaviateVector(unittest.TestCase): assert wv.search_by_vector(query_vector=[0.1] * 3) == [] - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") + def test_search_by_vector_retries_on_weaviate_query_error(self, mock_weaviate_module): + """Test that search_by_vector catches WeaviateQueryError, calls _ensure_properties, and retries.""" + from weaviate.exceptions import WeaviateQueryError + + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # First call raises WeaviateQueryError, second call succeeds + mock_obj = MagicMock() + mock_obj.properties = {"text": "retry result", "document_id": "doc-1"} + mock_obj.metadata.distance = 0.2 + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + + mock_col.query.near_vector.side_effect = [ + WeaviateQueryError("missing property", "gRPC"), + mock_result, + ] + + # 
Mock _ensure_properties dependencies + mock_cfg = MagicMock() + mock_cfg.properties = [ + SimpleNamespace(name="text"), + SimpleNamespace(name="document_id"), + SimpleNamespace(name="doc_id"), + SimpleNamespace(name="doc_type"), + SimpleNamespace(name="chunk_index"), + SimpleNamespace(name="is_summary"), + SimpleNamespace(name="original_chunk_id"), + ] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_vector(query_vector=[0.1] * 3, top_k=1) + + assert mock_col.query.near_vector.call_count == 2 + assert len(docs) == 1 + assert docs[0].metadata["score"] == pytest.approx(0.8) + + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): """Test that search_by_full_text also returns doc_type in document metadata.""" mock_client = MagicMock() @@ -526,7 +578,7 @@ class TestWeaviateVector(unittest.TestCase): assert len(docs) == 1 assert docs[0].metadata.get("doc_type") == "image" - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_full_text_uses_document_filter(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -554,7 +606,7 @@ class TestWeaviateVector(unittest.TestCase): assert docs[0].vector == [0.3, 0.4] assert mock_col.query.bm25.call_args.kwargs["filters"] is not None - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_search_by_full_text_returns_empty_when_collection_is_missing(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -569,7 +621,57 @@ class TestWeaviateVector(unittest.TestCase): assert wv.search_by_full_text(query="missing") == [] - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_retries_on_weaviate_query_error(self, mock_weaviate_module): + """Test that search_by_full_text catches WeaviateQueryError, calls _ensure_properties, and retries.""" + from weaviate.exceptions import WeaviateQueryError + + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # First call raises WeaviateQueryError, second call succeeds + mock_obj = MagicMock() + mock_obj.properties = {"text": "retry bm25 result", "doc_id": "segment-1"} + mock_obj.vector = {"default": [0.5, 0.6]} + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + + mock_col.query.bm25.side_effect = [ + WeaviateQueryError("missing property", "gRPC"), + mock_result, + ] + + # Mock _ensure_properties dependencies + mock_cfg = MagicMock() + mock_cfg.properties = [ + SimpleNamespace(name="text"), + SimpleNamespace(name="document_id"), + SimpleNamespace(name="doc_id"), + SimpleNamespace(name="doc_type"), + SimpleNamespace(name="chunk_index"), + SimpleNamespace(name="is_summary"), + SimpleNamespace(name="original_chunk_id"), + ] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = 
wv.search_by_full_text(query="retry", top_k=1) + + assert mock_col.query.bm25.call_count == 2 + assert len(docs) == 1 + assert docs[0].page_content == "retry bm25 result" + + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): """Test that add_texts includes doc_type from document metadata in stored properties.""" mock_client = MagicMock() @@ -611,7 +713,7 @@ class TestWeaviateVector(unittest.TestCase): stored_props = call_kwargs.kwargs.get("properties") assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" - @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + @patch("dify_vdb_weaviate.weaviate_vector.weaviate") def test_add_texts_falls_back_to_random_uuid_and_serializes_datetime_metadata(self, mock_weaviate_module): mock_client = MagicMock() mock_client.is_ready.return_value = True @@ -635,7 +737,7 @@ class TestWeaviateVector(unittest.TestCase): with ( patch.object(wv, "_get_uuids", return_value=["not-a-uuid"]), - patch("core.rag.datasource.vdb.weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), + patch("dify_vdb_weaviate.weaviate_vector._uuid.uuid4", return_value="fallback-uuid"), ): ids = wv.add_texts(documents=[doc], embeddings=[[]]) @@ -775,9 +877,7 @@ class TestWeaviateVectorFactory(unittest.TestCase): patch.object(weaviate_vector_module.dify_config, "WEAVIATE_GRPC_ENDPOINT", "localhost:50051"), patch.object(weaviate_vector_module.dify_config, "WEAVIATE_API_KEY", "api-key"), patch.object(weaviate_vector_module.dify_config, "WEAVIATE_BATCH_SIZE", 88), - patch( - "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" - ) as mock_vector, + patch("dify_vdb_weaviate.weaviate_vector.WeaviateVector", return_value="vector") as mock_vector, ): factory = weaviate_vector_module.WeaviateVectorFactory() result = factory.init_vector(dataset, attributes, MagicMock()) @@ -806,9 +906,7 @@ class TestWeaviateVectorFactory(unittest.TestCase): "gen_collection_name_by_id", return_value="GeneratedCollection_Node", ), - patch( - "core.rag.datasource.vdb.weaviate.weaviate_vector.WeaviateVector", return_value="vector" - ) as mock_vector, + patch("dify_vdb_weaviate.weaviate_vector.WeaviateVector", return_value="vector") as mock_vector, ): factory = weaviate_vector_module.WeaviateVectorFactory() result = factory.init_vector(dataset, attributes, MagicMock()) diff --git a/api/pyproject.toml b/api/pyproject.toml index 086ce5bb72..69add5c68d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,96 +1,53 @@ [project] name = "dify-api" -version = "1.13.3" +version = "1.14.0" requires-python = "~=3.12.0" dependencies = [ - "aliyun-log-python-sdk~=0.9.37", - "arize-phoenix-otel~=0.15.0", - "azure-identity==1.25.3", - "beautifulsoup4==4.14.3", - "boto3==1.42.83", - "bs4~=0.0.1", - "cachetools~=5.3.0", - "celery~=5.6.2", - "charset-normalizer>=3.4.4", - "flask~=3.1.2", - "flask-compress>=1.17,<1.25", - "flask-cors~=6.0.0", - "flask-login~=0.6.3", - "flask-migrate~=4.1.0", - "flask-orjson~=2.0.0", - "flask-sqlalchemy~=3.1.1", - "gevent~=25.9.1", - "gmpy2~=2.3.0", - "google-api-core>=2.19.1", - "google-api-python-client==2.193.0", - "google-auth>=2.47.0", - "google-auth-httplib2==0.3.1", - "google-cloud-aiplatform>=1.123.0", - "googleapis-common-protos>=1.65.0", - "graphon>=0.1.2", - "gunicorn~=25.3.0", - "httpx[socks]~=0.28.0", - "jieba==0.42.1", - "json-repair>=0.55.1", - "langfuse>=3.0.0,<5.0.0", - 
"langsmith~=0.7.16", - "markdown~=3.10.2", - "mlflow-skinny>=3.0.0", - "numpy~=1.26.4", - "openpyxl~=3.1.5", - "opik~=1.10.37", - "litellm==1.83.0", # Pinned to avoid madoka dependency issue - "opentelemetry-api==1.40.0", - "opentelemetry-distro==0.61b0", - "opentelemetry-exporter-otlp==1.40.0", - "opentelemetry-exporter-otlp-proto-common==1.40.0", - "opentelemetry-exporter-otlp-proto-grpc==1.40.0", - "opentelemetry-exporter-otlp-proto-http==1.40.0", - "opentelemetry-instrumentation==0.61b0", - "opentelemetry-instrumentation-celery==0.61b0", - "opentelemetry-instrumentation-flask==0.61b0", - "opentelemetry-instrumentation-httpx==0.61b0", - "opentelemetry-instrumentation-redis==0.61b0", - "opentelemetry-instrumentation-sqlalchemy==0.61b0", - "opentelemetry-propagator-b3==1.40.0", - "opentelemetry-proto==1.40.0", - "opentelemetry-sdk==1.40.0", - "opentelemetry-semantic-conventions==0.61b0", - "opentelemetry-util-http==0.61b0", - "pandas[excel,output-formatting,performance]~=3.0.1", - "psycogreen~=1.0.2", - "psycopg2-binary~=2.9.6", - "pycryptodome==3.23.0", - "pydantic~=2.12.5", - "pydantic-settings~=2.13.1", - "pyjwt~=2.12.0", - "pypdfium2==5.6.0", - "python-docx~=1.2.0", - "python-dotenv==1.2.2", - "pyyaml~=6.0.1", - "readabilipy~=0.3.0", - "redis[hiredis]~=7.4.0", - "resend~=2.26.0", - "sentry-sdk[flask]~=2.55.0", - "sqlalchemy~=2.0.29", - "starlette==1.0.0", - "tiktoken~=0.12.0", - "transformers~=5.3.0", - "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", - "pypandoc~=1.13", - "yarl~=1.23.0", - "sseclient-py~=1.9.0", + # Legacy: mature and widely deployed + "bleach>=6.3.0", + "boto3>=1.43.3", + "celery>=5.6.3", + "croniter>=6.2.2", + "flask>=3.1.3,<4.0.0", + "flask-cors>=6.0.2", + "gevent>=26.4.0", + "gevent-websocket>=0.10.1", + "gmpy2>=2.3.0", + "google-api-python-client>=2.195.0", + "gunicorn>=25.3.0", + "psycogreen>=1.0.2", + "psycopg2-binary>=2.9.12", + "python-socketio>=5.13.0", + "redis[hiredis]>=7.4.0", + "sendgrid>=6.12.5", + "sseclient-py>=1.8.0", + + # Stable: production-proven, cap below the next major + "aliyun-log-python-sdk>=0.9.44,<1.0.0", + "azure-identity>=1.25.3,<2.0.0", + "flask-compress>=1.24,<2.0.0", + "flask-login>=0.6.3,<1.0.0", + "flask-migrate>=4.1.0,<5.0.0", + "flask-orjson>=2.0.0,<3.0.0", + "flask-restx>=1.3.2,<2.0.0", + "google-cloud-aiplatform>=1.149.0,<2.0.0", + "httpx[socks]>=0.28.1,<1.0.0", + "opentelemetry-distro>=0.62b1,<1.0.0", + "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-flask>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-httpx>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-redis>=0.62b0,<1.0.0", + "opentelemetry-instrumentation-sqlalchemy>=0.62b0,<1.0.0", + "opentelemetry-propagator-b3>=1.41.1,<2.0.0", + "readabilipy>=0.3.0,<1.0.0", + "resend>=2.27.0,<3.0.0", + + # Emerging: newer and fast-moving, use compatible pins + "fastopenapi[flask]~=0.7.0", + "graphon~=0.2.2", "httpx-sse~=0.4.0", - "sendgrid~=6.12.3", - "flask-restx~=1.3.2", - "packaging~=23.2", - "croniter>=6.0.0", - "weaviate-client==4.20.4", - "apscheduler>=3.11.0", - "weave>=0.52.16", - "fastopenapi[flask]>=0.7.0", - "bleach~=6.3.0", + "json-repair~=0.59.4", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. 
@@ -98,9 +55,56 @@ dependencies = [ [tool.setuptools] packages = [] +[tool.uv.workspace] +members = ["providers/vdb/*", "providers/trace/*"] +exclude = ["providers/vdb/__pycache__", "providers/trace/__pycache__"] + +[tool.uv.sources] +dify-vdb-alibabacloud-mysql = { workspace = true } +dify-vdb-analyticdb = { workspace = true } +dify-vdb-baidu = { workspace = true } +dify-vdb-chroma = { workspace = true } +dify-vdb-clickzetta = { workspace = true } +dify-vdb-couchbase = { workspace = true } +dify-vdb-elasticsearch = { workspace = true } +dify-vdb-hologres = { workspace = true } +dify-vdb-huawei-cloud = { workspace = true } +dify-vdb-iris = { workspace = true } +dify-vdb-lindorm = { workspace = true } +dify-vdb-matrixone = { workspace = true } +dify-vdb-milvus = { workspace = true } +dify-vdb-myscale = { workspace = true } +dify-vdb-oceanbase = { workspace = true } +dify-vdb-opengauss = { workspace = true } +dify-vdb-opensearch = { workspace = true } +dify-vdb-oracle = { workspace = true } +dify-vdb-pgvecto-rs = { workspace = true } +dify-vdb-pgvector = { workspace = true } +dify-vdb-qdrant = { workspace = true } +dify-vdb-relyt = { workspace = true } +dify-vdb-tablestore = { workspace = true } +dify-vdb-tencent = { workspace = true } +dify-vdb-tidb-on-qdrant = { workspace = true } +dify-vdb-tidb-vector = { workspace = true } +dify-vdb-upstash = { workspace = true } +dify-vdb-vastbase = { workspace = true } +dify-vdb-vikingdb = { workspace = true } +dify-vdb-weaviate = { workspace = true } +dify-trace-aliyun = { workspace = true } +dify-trace-arize-phoenix = { workspace = true } +dify-trace-langfuse = { workspace = true } +dify-trace-langsmith = { workspace = true } +dify-trace-mlflow = { workspace = true } +dify-trace-opik = { workspace = true } +dify-trace-tencent = { workspace = true } +dify-trace-weave = { workspace = true } + [tool.uv] -default-groups = ["storage", "tools", "vdb"] +default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false +override-dependencies = [ + "pyarrow>=18.0.0", +] [dependency-groups] @@ -109,69 +113,69 @@ package = false # Required for development and running tests ############################################################ dev = [ - "coverage~=7.13.4", - "dotenv-linter~=0.7.0", - "faker~=40.12.0", - "lxml-stubs~=0.5.1", - "basedpyright~=1.39.0", - "ruff~=0.15.5", - "pytest~=9.0.2", - "pytest-benchmark~=5.2.3", - "pytest-cov~=7.1.0", - "pytest-env~=1.6.0", - "pytest-mock~=3.15.1", - "testcontainers~=4.14.1", - "types-aiofiles~=25.1.0", - "types-beautifulsoup4~=4.12.0", - "types-cachetools~=6.2.0", - "types-colorama~=0.4.15", - "types-defusedxml~=0.7.0", - "types-deprecated~=1.3.1", - "types-docutils~=0.22.3", - "types-flask-cors~=6.0.0", - "types-flask-migrate~=4.1.0", - "types-gevent~=25.9.0", - "types-greenlet~=3.3.0", - "types-html5lib~=1.1.11", - "types-markdown~=3.10.2", - "types-oauthlib~=3.3.0", - "types-objgraph~=3.6.0", - "types-olefile~=0.47.0", - "types-openpyxl~=3.1.5", - "types-pexpect~=4.9.0", - "types-protobuf~=7.34.1", - "types-psutil~=7.2.2", - "types-psycopg2~=2.9.21", - "types-pygments~=2.20.0", - "types-pymysql~=1.1.0", - "types-python-dateutil~=2.9.0", - "types-pywin32~=311.0.0", - "types-pyyaml~=6.0.12", - "types-regex~=2026.4.4", - "types-shapely~=2.1.0", - "types-simplejson>=3.20.0", - "types-six>=1.17.0", - "types-tensorflow>=2.18.0", - "types-tqdm>=4.67.0", + "coverage>=7.13.4", + "dotenv-linter>=0.7.0", + "faker>=40.15.0", + "lxml-stubs>=0.5.1", + "basedpyright>=1.39.3", + "ruff>=0.15.12", + "pytest>=9.0.3", + 
"pytest-benchmark>=5.2.3", + "pytest-cov>=7.1.0", + "pytest-env>=1.6.0", + "pytest-mock>=3.15.1", + "testcontainers>=4.14.2", + "types-aiofiles>=25.1.0", + "types-beautifulsoup4>=4.12.0", + "types-cachetools>=7.0.0.20260503", + "types-colorama>=0.4.15", + "types-defusedxml>=0.7.0", + "types-deprecated>=1.3.1", + "types-docutils>=0.22.3", + "types-flask-cors>=6.0.0", + "types-flask-migrate>=4.1.0", + "types-gevent>=26.4.0", + "types-greenlet>=3.5.0.20260428", + "types-html5lib>=1.1.11", + "types-markdown>=3.10.2", + "types-oauthlib>=3.3.0", + "types-objgraph>=3.6.0", + "types-olefile>=0.47.0", + "types-openpyxl>=3.1.5", + "types-pexpect>=4.9.0", + "types-protobuf>=7.34.1.20260503", + "types-psutil>=7.2.2", + "types-psycopg2>=2.9.21.20260422", + "types-pygments>=2.20.0", + "types-pymysql>=1.1.0", + "types-python-dateutil>=2.9.0", + "types-pywin32>=311.0.0", + "types-pyyaml>=6.0.12", + "types-regex>=2026.4.4", + "types-shapely>=2.1.0", + "types-simplejson>=3.20.0.20260408", + "types-six>=1.17.0.20260408", + "types-tensorflow>=2.18.0.20260408", + "types-tqdm>=4.67.3.20260408", "types-ujson>=5.10.0", - "boto3-stubs>=1.38.20", - "types-jmespath>=1.0.2.20240106", - "hypothesis>=6.131.15", + "boto3-stubs>=1.43.2", + "types-jmespath>=1.1.0.20260408", + "hypothesis>=6.152.4", "types_pyOpenSSL>=24.1.0", - "types_cffi>=1.17.0", - "types_setuptools>=80.9.0", - "pandas-stubs~=3.0.0", - "scipy-stubs>=1.15.3.0", - "types-python-http-client>=3.3.7.20240910", + "types_cffi>=2.0.0.20260429", + "types_setuptools>=82.0.0.20260408", + "pandas-stubs>=3.0.0", + "scipy-stubs>=1.17.1.4", + "types-python-http-client>=3.3.7.20260408", "import-linter>=2.3", "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", - "mypy~=1.20.0", + "mypy>=1.20.2", # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. - "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.60.0", + "pyrefly>=0.64.0", + "xinference-client>=2.7.0", ] ############################################################ @@ -179,54 +183,110 @@ dev = [ # Required for storage clients ############################################################ storage = [ - "azure-storage-blob==12.28.0", - "bce-python-sdk~=0.9.23", - "cos-python-sdk-v5==1.9.41", - "esdk-obs-python==3.26.2", - "google-cloud-storage>=3.0.0", - "opendal~=0.46.0", - "oss2==2.19.1", - "supabase~=2.18.1", - "tos~=2.9.0", + "azure-storage-blob>=12.28.0", + "bce-python-sdk>=0.9.71", + "cos-python-sdk-v5>=1.9.42", + "esdk-obs-python>=3.22.2", + "google-cloud-storage>=3.10.1", + "opendal>=0.46.0", + "oss2>=2.19.1", + "supabase>=2.29.0", + "tos>=2.9.0", ] ############################################################ # [ Tools ] dependency group ############################################################ -tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] +tools = ["cloudscraper>=1.2.71", "nltk>=3.9.1"] ############################################################ -# [ VDB ] dependency group -# Required by vector store clients +# [ VDB ] workspace plugins — hollow packages under providers/vdb/* +# Each declares its own third-party deps and registers dify.vector_backends entry points. 
+# Use: uv sync --group vdb-all | uv sync --group vdb-qdrant ############################################################ -vdb = [ - "alibabacloud_gpdb20160503~=5.2.0", - "alibabacloud_tea_openapi~=0.4.3", - "chromadb==0.5.20", - "clickhouse-connect~=0.15.0", - "clickzetta-connector-python>=0.8.102", - "couchbase~=4.6.0", - "elasticsearch==8.14.0", - "opensearch-py==3.1.0", - "oracledb==3.4.2", - "pgvecto-rs[sqlalchemy]~=0.2.1", - "pgvector==0.4.2", - "pymilvus~=2.6.10", - "pymochow==2.4.0", - "pyobvector~=0.2.17", - "qdrant-client==1.9.0", - "intersystems-irispython>=5.1.0", - "tablestore==6.4.3", - "tcvectordb~=2.1.0", - "tidb-vector==0.0.15", - "upstash-vector==0.8.0", - "volcengine-compat~=1.0.0", - "weaviate-client==4.20.4", - "xinference-client~=2.4.0", - "mo-vector~=0.1.13", - "mysql-connector-python>=9.3.0", - "holo-search-sdk>=0.4.1", +vdb-all = [ + "dify-vdb-alibabacloud-mysql", + "dify-vdb-analyticdb", + "dify-vdb-baidu", + "dify-vdb-chroma", + "dify-vdb-clickzetta", + "dify-vdb-couchbase", + "dify-vdb-elasticsearch", + "dify-vdb-hologres", + "dify-vdb-huawei-cloud", + "dify-vdb-iris", + "dify-vdb-lindorm", + "dify-vdb-matrixone", + "dify-vdb-milvus", + "dify-vdb-myscale", + "dify-vdb-oceanbase", + "dify-vdb-opengauss", + "dify-vdb-opensearch", + "dify-vdb-oracle", + "dify-vdb-pgvecto-rs", + "dify-vdb-pgvector", + "dify-vdb-qdrant", + "dify-vdb-relyt", + "dify-vdb-tablestore", + "dify-vdb-tencent", + "dify-vdb-tidb-on-qdrant", + "dify-vdb-tidb-vector", + "dify-vdb-upstash", + "dify-vdb-vastbase", + "dify-vdb-vikingdb", + "dify-vdb-weaviate", ] +vdb-alibabacloud-mysql = ["dify-vdb-alibabacloud-mysql"] +vdb-analyticdb = ["dify-vdb-analyticdb"] +vdb-baidu = ["dify-vdb-baidu"] +vdb-chroma = ["dify-vdb-chroma"] +vdb-clickzetta = ["dify-vdb-clickzetta"] +vdb-couchbase = ["dify-vdb-couchbase"] +vdb-elasticsearch = ["dify-vdb-elasticsearch"] +vdb-hologres = ["dify-vdb-hologres"] +vdb-huawei-cloud = ["dify-vdb-huawei-cloud"] +vdb-iris = ["dify-vdb-iris"] +vdb-lindorm = ["dify-vdb-lindorm"] +vdb-matrixone = ["dify-vdb-matrixone"] +vdb-milvus = ["dify-vdb-milvus"] +vdb-myscale = ["dify-vdb-myscale"] +vdb-oceanbase = ["dify-vdb-oceanbase"] +vdb-opengauss = ["dify-vdb-opengauss"] +vdb-opensearch = ["dify-vdb-opensearch"] +vdb-oracle = ["dify-vdb-oracle"] +vdb-pgvecto-rs = ["dify-vdb-pgvecto-rs"] +vdb-pgvector = ["dify-vdb-pgvector"] +vdb-qdrant = ["dify-vdb-qdrant"] +vdb-relyt = ["dify-vdb-relyt"] +vdb-tablestore = ["dify-vdb-tablestore"] +vdb-tencent = ["dify-vdb-tencent"] +vdb-tidb-on-qdrant = ["dify-vdb-tidb-on-qdrant"] +vdb-tidb-vector = ["dify-vdb-tidb-vector"] +vdb-upstash = ["dify-vdb-upstash"] +vdb-vastbase = ["dify-vdb-vastbase"] +vdb-vikingdb = ["dify-vdb-vikingdb"] +vdb-weaviate = ["dify-vdb-weaviate"] +# Optional client used by some tests / integrations (not a vector backend plugin) +vdb-xinference = ["xinference-client>=2.7.0"] + +trace-all = [ + "dify-trace-aliyun", + "dify-trace-arize-phoenix", + "dify-trace-langfuse", + "dify-trace-langsmith", + "dify-trace-mlflow", + "dify-trace-opik", + "dify-trace-tencent", + "dify-trace-weave", +] +trace-aliyun = ["dify-trace-aliyun"] +trace-arize-phoenix = ["dify-trace-arize-phoenix"] +trace-langfuse = ["dify-trace-langfuse"] +trace-langsmith = ["dify-trace-langsmith"] +trace-mlflow = ["dify-trace-mlflow"] +trace-opik = ["dify-trace-opik"] +trace-tencent = ["dify-trace-tencent"] +trace-weave = ["dify-trace-weave"] [tool.pyrefly] project-includes = ["."] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index 
43f604c2de..fbbca24558 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -34,42 +34,18 @@ core/external_data_tool/api/api.py core/llm_generator/llm_generator.py core/llm_generator/output_parser/structured_output.py core/mcp/mcp_client.py -core/ops/aliyun_trace/data_exporter/traceclient.py -core/ops/arize_phoenix_trace/arize_phoenix_trace.py -core/ops/mlflow_trace/mlflow_trace.py +providers/trace/trace-aliyun/src/dify_trace_aliyun/data_exporter/traceclient.py +providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py core/ops/ops_trace_manager.py -core/ops/tencent_trace/client.py -core/ops/tencent_trace/utils.py +providers/trace/trace-tencent/src/dify_trace_tencent/client.py +providers/trace/trace-tencent/src/dify_trace_tencent/utils.py core/plugin/backwards_invocation/base.py core/plugin/backwards_invocation/model.py core/prompt/utils/extract_thread_messages.py core/rag/datasource/keyword/jieba/jieba.py core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py -core/rag/datasource/vdb/analyticdb/analyticdb_vector.py -core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py -core/rag/datasource/vdb/baidu/baidu_vector.py -core/rag/datasource/vdb/chroma/chroma_vector.py -core/rag/datasource/vdb/clickzetta/clickzetta_vector.py -core/rag/datasource/vdb/couchbase/couchbase_vector.py -core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py -core/rag/datasource/vdb/huawei/huawei_cloud_vector.py -core/rag/datasource/vdb/lindorm/lindorm_vector.py -core/rag/datasource/vdb/matrixone/matrixone_vector.py -core/rag/datasource/vdb/milvus/milvus_vector.py -core/rag/datasource/vdb/myscale/myscale_vector.py -core/rag/datasource/vdb/oceanbase/oceanbase_vector.py -core/rag/datasource/vdb/opensearch/opensearch_vector.py -core/rag/datasource/vdb/oracle/oraclevector.py -core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py -core/rag/datasource/vdb/relyt/relyt_vector.py -core/rag/datasource/vdb/tablestore/tablestore_vector.py -core/rag/datasource/vdb/tencent/tencent_vector.py -core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py -core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py -core/rag/datasource/vdb/tidb_vector/tidb_vector.py -core/rag/datasource/vdb/upstash/upstash_vector.py -core/rag/datasource/vdb/vikingdb/vikingdb_vector.py -core/rag/datasource/vdb/weaviate/weaviate_vector.py +providers/vdb/** core/rag/extractor/csv_extractor.py core/rag/extractor/excel_extractor.py core/rag/extractor/firecrawl/firecrawl_app.py diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index a8b884ea81..ac0e2a3a53 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -4,7 +4,9 @@ "tests/", ".venv", "migrations/", - "core/rag" + "core/rag", + "providers/vdb/", + "providers/trace/*/tests", ], "typeCheckingMode": "strict", "allowedUntypedLibraries": [ @@ -36,7 +38,9 @@ "gmpy2", "sendgrid", "sendgrid.helpers.mail", - "holo_search_sdk.types" + "holo_search_sdk.types", + "dify_vdb_qdrant", + "dify_vdb_tidb_on_qdrant" ], "reportUnknownMemberType": "hint", "reportUnknownParameterType": "hint", @@ -47,7 +51,6 @@ "reportMissingTypeArgument": "hint", "reportUnnecessaryComparison": "hint", "reportUnnecessaryIsInstance": "hint", - "reportUntypedFunctionDecorator": "hint", "reportUnnecessaryTypeIgnoreComment": "hint", "reportAttributeAccessIssue": "hint", "pythonVersion": "3.12", diff --git a/api/repositories/api_workflow_run_repository.py 
b/api/repositories/api_workflow_run_repository.py index 100589804c..72b38e7906 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol, TypedDict -from graphon.entities.pause_reason import PauseReason -from graphon.enums import WorkflowType from sqlalchemy.orm import Session from core.repositories.factory import WorkflowExecutionRepository +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index d5c6a203b1..44735eb769 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast -from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 9267be2636..71a2554a60 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,22 +28,21 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa -from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from graphon.enums import WorkflowExecutionStatus, WorkflowType -from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold -from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm +from models.human_input import HumanInputForm, HumanInputFormRecipient from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import 
APIWorkflowRunRepository, RunsWithRelatedCountsDict from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -64,6 +63,7 @@ class _WorkflowRunError(Exception): def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, + recipients: Sequence[HumanInputFormRecipient] = (), ) -> HumanInputRequired: form_content = "" inputs = [] @@ -90,7 +90,7 @@ def _build_human_input_required_reason( resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - return HumanInputRequired( + reason = HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, @@ -99,6 +99,7 @@ def _build_human_input_required_reason( node_title=node_title, resolved_default_values=resolved_default_values, ) + return reason class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): @@ -744,12 +745,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): # Upload the state file # Create the pause record - pause_model = WorkflowPause() - pause_model.id = str(uuidv7()) - pause_model.workflow_id = workflow_run.workflow_id - pause_model.workflow_run_id = workflow_run.id - pause_model.state_object_key = state_obj_key - pause_model.created_at = naive_utc_now() + pause_model = WorkflowPause( + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_obj_key, + ) pause_reason_models = [] for reason in pause_reasons: if isinstance(reason, HumanInputRequired): @@ -806,12 +806,23 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form + recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = {} + if form_ids: + recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(recipient_stmt).all(): + recipients_by_form_id.setdefault(recipient.form_id, []).append(recipient) pause_reasons: list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - pause_reasons.append(_build_human_input_required_reason(reason, form_model)) + pause_reasons.append( + _build_human_input_required_reason( + reason, + form_model, + recipients_by_form_id.get(reason.form_id, ()), + ) + ) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index feba5f7eb6..67f8795d3f 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,9 +7,6 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any -from graphon.nodes.human_input.entities import FormDefinition -from graphon.nodes.human_input.enums import HumanInputFormStatus -from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -21,6 +18,9 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) +from graphon.nodes.human_input.entities import FormDefinition +from 
graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/repositories/workflow_collaboration_repository.py b/api/repositories/workflow_collaboration_repository.py new file mode 100644 index 0000000000..000f80496d --- /dev/null +++ b/api/repositories/workflow_collaboration_repository.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import json +from typing import TypedDict + +from extensions.ext_redis import redis_client + +SESSION_STATE_TTL_SECONDS = 3600 +WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:" +WORKFLOW_LEADER_PREFIX = "workflow_leader:" +WS_SID_MAP_PREFIX = "ws_sid_map:" + + +class WorkflowSessionInfo(TypedDict): + user_id: str + username: str + avatar: str | None + sid: str + connected_at: int + + +class SidMapping(TypedDict): + workflow_id: str + user_id: str + + +class WorkflowCollaborationRepository: + def __init__(self) -> None: + self._redis = redis_client + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(redis_client={self._redis})" + + @staticmethod + def workflow_key(workflow_id: str) -> str: + return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}" + + @staticmethod + def leader_key(workflow_id: str) -> str: + return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}" + + @staticmethod + def sid_key(sid: str) -> str: + return f"{WS_SID_MAP_PREFIX}{sid}" + + @staticmethod + def _decode(value: str | bytes | None) -> str | None: + if value is None: + return None + if isinstance(value, bytes): + return value.decode("utf-8") + return value + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + workflow_key = self.workflow_key(workflow_id) + sid_key = self.sid_key(sid) + if self._redis.exists(workflow_key): + self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS) + if self._redis.exists(sid_key): + self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS) + + def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None: + workflow_key = self.workflow_key(workflow_id) + self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info)) + self._redis.set( + self.sid_key(session_info["sid"]), + json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}), + ex=SESSION_STATE_TTL_SECONDS, + ) + self.refresh_session_state(workflow_id, session_info["sid"]) + + def get_sid_mapping(self, sid: str) -> SidMapping | None: + raw = self._redis.get(self.sid_key(sid)) + if not raw: + return None + value = self._decode(raw) + if not value: + return None + try: + return json.loads(value) + except (TypeError, json.JSONDecodeError): + return None + + def delete_session(self, workflow_id: str, sid: str) -> None: + self._redis.hdel(self.workflow_key(workflow_id), sid) + self._redis.delete(self.sid_key(sid)) + + def session_exists(self, workflow_id: str, sid: str) -> bool: + return bool(self._redis.hexists(self.workflow_key(workflow_id), sid)) + + def sid_mapping_exists(self, sid: str) -> bool: + return bool(self._redis.exists(self.sid_key(sid))) + + def get_session_sids(self, workflow_id: str) -> list[str]: + raw_sids = self._redis.hkeys(self.workflow_key(workflow_id)) + decoded_sids: list[str] = [] + for sid in raw_sids: + decoded = self._decode(sid) + if decoded: + decoded_sids.append(decoded) + return decoded_sids + + def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + sessions_json = 
self._redis.hgetall(self.workflow_key(workflow_id)) + users: list[WorkflowSessionInfo] = [] + + for session_info_json in sessions_json.values(): + value = self._decode(session_info_json) + if not value: + continue + try: + session_info = json.loads(value) + except (TypeError, json.JSONDecodeError): + continue + + if not isinstance(session_info, dict): + continue + if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info: + continue + + users.append( + { + "user_id": str(session_info["user_id"]), + "username": str(session_info["username"]), + "avatar": session_info.get("avatar"), + "sid": str(session_info["sid"]), + "connected_at": int(session_info.get("connected_at") or 0), + } + ) + + return users + + def get_current_leader(self, workflow_id: str) -> str | None: + raw = self._redis.get(self.leader_key(workflow_id)) + return self._decode(raw) + + def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool: + return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS)) + + def set_leader(self, workflow_id: str, sid: str) -> None: + self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS) + + def delete_leader(self, workflow_id: str) -> None: + self._redis.delete(self.leader_key(workflow_id)) + + def expire_leader(self, workflow_id: str) -> None: + self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 6ceb3ef856..e242b0c667 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -1,11 +1,11 @@ import time import click +from dify_vdb_tidb_on_qdrant.tidb_service import TidbService from sqlalchemy import func, select import app from configs import dify_config -from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus @@ -57,6 +57,7 @@ def create_clusters(batch_size): cluster_name=new_cluster["cluster_name"], account=new_cluster["account"], password=new_cluster["password"], + qdrant_endpoint=new_cluster.get("qdrant_endpoint"), active=False, status=TidbAuthBindingStatus.CREATING, ) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index 8479cdfb0c..2cc0192a4a 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -7,8 +7,8 @@ from sqlalchemy import select import app from configs import dify_config +from core.db.session_factory import session_factory from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_i18n import EmailType, get_email_i18n_service from models import Account, Tenant, TenantAccountJoin @@ -33,67 +33,68 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = db.session.scalars( - select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False) - ).all() - # group by tenant_id - dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) - for dataset_auto_disable_log in dataset_auto_disable_logs: - if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: - dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] - 
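The leader keys in the repository above implement a small Redis-based leader election: `set(..., nx=True)` succeeds only for the first session, and the `ex` TTL lets leadership lapse if the leader stops refreshing. A sketch of the handshake a caller might build on these methods; the `elect_or_follow` wrapper is illustrative, not part of the repository:

```python
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository


def elect_or_follow(repo: WorkflowCollaborationRepository, workflow_id: str, sid: str) -> bool:
    """Return True if `sid` is (or becomes) the leader for this workflow."""
    if repo.set_leader_if_absent(workflow_id, sid):
        return True  # SET NX succeeded: this session won the election
    leader = repo.get_current_leader(workflow_id)
    if leader is None:
        # The leader key expired between the two calls; race for it again.
        return repo.set_leader_if_absent(workflow_id, sid)
    return leader == sid
```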
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) - url = f"{dify_config.CONSOLE_WEB_URL}/datasets" - for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): - features = FeatureService.get_features(tenant_id) - plan = features.billing.subscription.plan - if plan != CloudPlan.SANDBOX: - knowledge_details = [] - # check tenant - tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) - if not tenant: - continue - # check current owner - current_owner_join = db.session.scalar( - select(TenantAccountJoin) - .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") - .limit(1) - ) - if not current_owner_join: - continue - account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) - if not account: - continue + with session_factory.create_session() as session: + dataset_auto_disable_logs = session.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified.is_(False)) + ).all() + # group by tenant_id + dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) + for dataset_auto_disable_log in dataset_auto_disable_logs: + if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map: + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = [] + dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) + url = f"{dify_config.CONSOLE_WEB_URL}/datasets" + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + features = FeatureService.get_features(tenant_id) + plan = features.billing.subscription.plan + if plan != CloudPlan.SANDBOX: + knowledge_details = [] + # check tenant + tenant = session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + # check current owner + current_owner_join = session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) + ) + if not current_owner_join: + continue + account = session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) + if not account: + continue - dataset_auto_dataset_map = {} # type: ignore + dataset_auto_dataset_map = {} # type: ignore + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] + dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id)) + if dataset: + document_count = len(document_ids) + knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") + if knowledge_details: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, + language_code="en-US", + to=account.email, + template_context={ + "userName": account.email, + "knowledge_details": knowledge_details, + "url": url, + }, + ) + + # update notified to True for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map: - dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = [] - 
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( - dataset_auto_disable_log.document_id - ) - - for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id)) - if dataset: - document_count = len(document_ids) - knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") - if knowledge_details: - email_service = get_email_i18n_service() - email_service.send_email( - email_type=EmailType.DOCUMENT_CLEAN_NOTIFY, - language_code="en-US", - to=account.email, - template_context={ - "userName": account.email, - "knowledge_details": knowledge_details, - "url": url, - }, - ) - - # update notified to True - for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: - dataset_auto_disable_log.notified = True - db.session.commit() + dataset_auto_disable_log.notified = True + session.commit() end_at = time.perf_counter() logger.info(click.style(f"Send document clean notify mail succeeded: latency: {end_at - start_at}", fg="green")) except Exception: diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 10003b1b97..46d1b85aa0 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -2,11 +2,11 @@ import time from collections.abc import Sequence import click +from dify_vdb_tidb_on_qdrant.tidb_service import TidbService from sqlalchemy import select import app from configs import dify_config -from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from extensions.ext_database import db from models.dataset import TidbAuthBinding from models.enums import TidbAuthBindingStatus diff --git a/api/services/account_service.py b/api/services/account_service.py index 4b58b3b697..b6554a3de7 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -9,7 +9,8 @@ from typing import Any, TypedDict, cast from pydantic import BaseModel, TypeAdapter from sqlalchemy import delete, func, select, update -from sqlalchemy.orm import Session, sessionmaker + +from core.db.session_factory import session_factory class InvitationData(TypedDict): @@ -111,6 +112,14 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: + # Phase-bound token metadata for the change-email flow. 
Tokens carry the + # current phase so that downstream endpoints can enforce proper progression + CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase" + CHANGE_EMAIL_PHASE_OLD = "old_email" + CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified" + CHANGE_EMAIL_PHASE_NEW = "new_email" + CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified" + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1) email_code_login_rate_limiter = RateLimiter( @@ -575,13 +584,20 @@ class AccountService: raise ValueError("Email must be provided.") if not phase: raise ValueError("phase must be provided.") + if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW): + raise ValueError("phase must be one of old_email or new_email.") if cls.change_email_rate_limiter.is_rate_limited(account_email): from controllers.console.auth.error import EmailChangeRateLimitExceededError raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60)) - code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) + code, token = cls.generate_change_email_token( + account_email, + account, + old_email=old_email, + additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase}, + ) send_change_mail_task.delay( language=language, @@ -800,19 +816,19 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str, session: Session | None = None) -> Account | None: + def get_account_by_email_with_case_fallback(email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. 
""" - query_session = session or db.session - account = query_session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() - if account or email == email.lower(): - return account + with session_factory.create_session() as session: + account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return query_session.execute(select(Account).filter_by(email=email.lower())).scalar_one_or_none() + return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -1516,8 +1532,7 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email, session=session) + account = AccountService.get_account_by_email_with_case_fallback(email) if not account: TenantService.check_member_permission(tenant, inviter, None, "add") diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index a6e6b1bae7..5d136e7393 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,4 +1,5 @@ import copy +from typing import Any, TypedDict from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, @@ -15,9 +16,18 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( from models.model import AppMode +class AdvancedPromptTemplateArgs(TypedDict): + """Expected shape of the args dict passed to AdvancedPromptTemplateService.get_prompt.""" + + app_mode: str + model_mode: str + model_name: str + has_context: str + + class AdvancedPromptTemplateService: @classmethod - def get_prompt(cls, args: dict): + def get_prompt(cls, args: AdvancedPromptTemplateArgs) -> dict[str, Any]: app_mode = args["app_mode"] model_mode = args["model_mode"] model_name = args["model_name"] @@ -29,7 +39,7 @@ class AdvancedPromptTemplateService: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str): + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]: context_prompt = copy.deepcopy(CONTEXT) match app_mode: @@ -63,7 +73,7 @@ class AdvancedPromptTemplateService: return {} @classmethod - def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str): + def get_completion_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]: if has_context == "true": prompt_template["completion_prompt_config"]["prompt"]["text"] = ( context + prompt_template["completion_prompt_config"]["prompt"]["text"] @@ -72,7 +82,7 @@ class AdvancedPromptTemplateService: return prompt_template @classmethod - def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str): + def get_chat_prompt(cls, prompt_template: dict[str, Any], has_context: str, context: str) -> dict[str, Any]: if has_context == "true": prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] @@ -81,7 +91,7 @@ class AdvancedPromptTemplateService: return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: 
str): + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict[str, Any]: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) match app_mode: diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index ae5facbec0..0229a1f43a 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -1,11 +1,8 @@ import logging import uuid - -import pandas as pd - -logger = logging.getLogger(__name__) from typing import TypedDict +import pandas as pd from sqlalchemy import delete, or_, select, update from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -24,6 +21,8 @@ from tasks.annotation.disable_annotation_reply_task import disable_annotation_re from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task +logger = logging.getLogger(__name__) + class AnnotationJobStatusDict(TypedDict): job_id: str @@ -46,9 +45,50 @@ class AnnotationSettingDisabledDict(TypedDict): enabled: bool +class EnableAnnotationArgs(TypedDict): + """Expected shape of the args dict passed to enable_app_annotation.""" + + score_threshold: float + embedding_provider_name: str + embedding_model_name: str + + +class UpsertAnnotationArgs(TypedDict, total=False): + """Expected shape of the args dict passed to up_insert_app_annotation_from_message.""" + + answer: str + content: str + message_id: str + question: str + + +class InsertAnnotationArgs(TypedDict): + """Expected shape of the args dict passed to insert_app_annotation_directly.""" + + question: str + answer: str + + +class UpdateAnnotationArgs(TypedDict, total=False): + """Expected shape of the args dict passed to update_app_annotation_directly. + + Both fields are optional at the type level; the service validates at runtime + and raises ValueError if either is missing. 
+ """ + + answer: str + question: str + + +class UpdateAnnotationSettingArgs(TypedDict): + """Expected shape of the args dict passed to update_app_annotation_setting.""" + + score_threshold: float + + class AppAnnotationService: @classmethod - def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: + def up_insert_app_annotation_from_message(cls, args: UpsertAnnotationArgs, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -62,8 +102,9 @@ class AppAnnotationService: if answer is None: raise ValueError("Either 'answer' or 'content' must be provided") - if args.get("message_id"): - message_id = str(args["message_id"]) + raw_message_id = args.get("message_id") + if raw_message_id: + message_id = str(raw_message_id) message = db.session.scalar( select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1) ) @@ -87,11 +128,19 @@ class AppAnnotationService: account_id=current_user.id, ) else: - question = args.get("question") - if not question: + maybe_question = args.get("question") + if not maybe_question: raise ValueError("'question' is required when 'message_id' is not provided") + question = maybe_question - annotation = MessageAnnotation(app_id=app.id, content=answer, question=question, account_id=current_user.id) + annotation = MessageAnnotation( + app_id=app.id, + conversation_id=None, + message_id=None, + content=answer, + question=question, + account_id=current_user.id, + ) db.session.add(annotation) db.session.commit() @@ -110,7 +159,7 @@ class AppAnnotationService: return annotation @classmethod - def enable_app_annotation(cls, args: dict, app_id: str) -> AnnotationJobStatusDict: + def enable_app_annotation(cls, args: EnableAnnotationArgs, app_id: str) -> AnnotationJobStatusDict: enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}" cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: @@ -217,7 +266,7 @@ class AppAnnotationService: return annotations @classmethod - def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: + def insert_app_annotation_directly(cls, args: InsertAnnotationArgs, app_id: str) -> MessageAnnotation: # get app info current_user, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -251,7 +300,7 @@ class AppAnnotationService: return annotation @classmethod - def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): + def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str): # get app info _, current_tenant_id = current_account_with_tenant() app = db.session.scalar( @@ -270,7 +319,11 @@ class AppAnnotationService: if question is None: raise ValueError("'question' is required") - annotation.content = args["answer"] + answer = args.get("answer") + if answer is None: + raise ValueError("'answer' is required") + + annotation.content = answer annotation.question = question db.session.commit() @@ -613,7 +666,7 @@ class AppAnnotationService: @classmethod def update_app_annotation_setting( - cls, app_id: str, annotation_setting_id: str, args: dict + cls, app_id: str, annotation_setting_id: str, args: UpdateAnnotationSettingArgs ) -> AnnotationSettingDict: current_user, current_tenant_id = current_account_with_tenant() # get app info diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 40e1e5f8ab..97aaea3395 100644 --- 
a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -3,19 +3,13 @@ import hashlib import logging import uuid from collections.abc import Mapping -from typing import cast +from typing import Any, cast from urllib.parse import urlparse from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad -from graphon.enums import BuiltinNodeTypes -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes.llm.entities import LLMNodeData -from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData -from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel @@ -23,6 +17,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from core.helper import ssrf_proxy from core.plugin.entities.plugin import PluginDependency from core.trigger.constants import ( @@ -35,6 +30,12 @@ from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerSc from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode from models.model import AppModelConfig, AppModelConfigDict, IconType @@ -50,7 +51,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.6.0" +CURRENT_DSL_VERSION = CURRENT_APP_DSL_VERSION class Import(BaseModel): @@ -400,7 +401,7 @@ class AppDslService: self, *, app: App | None, - data: dict, + data: dict[str, Any], account: Account, name: str | None = None, description: str | None = None, @@ -455,7 +456,7 @@ class AppDslService: app.updated_by = account.id self._session.add(app) - self._session.commit() + self._session.flush() app_was_created.send(app, account=account) # save dependencies @@ -567,7 +568,7 @@ class AppDslService: @classmethod def _append_workflow_export_data( - cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None + cls, *, export_data: dict[str, Any], app_model: App, include_secret: bool, workflow_id: str | None = None ): """ Append workflow export data @@ -620,7 +621,7 @@ class AppDslService: ] @classmethod - def _append_model_config_export_data(cls, export_data: dict, app_model: App): + def _append_model_config_export_data(cls, export_data: dict[str, Any], app_model: App): """ Append model config export data :param export_data: export data diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 17ed98d301..d6c01e9dcc 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -4,7 +4,7 @@ import logging import threading 
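The `AppDslService` change above swaps `self._session.commit()` for `self._session.flush()`: the INSERT is still sent (so `app.id` is populated and `app_was_created` can fire), but the enclosing session now decides when the transaction actually commits. A minimal sketch of the distinction, assuming plain SQLAlchemy 2.0; the `App` model here is an illustrative stand-in, not the real `models.model.App`:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class App(Base):  # illustrative stand-in for the real model
    __tablename__ = "apps"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with sessionmaker(engine).begin() as session:  # the context manager owns the transaction
    app = App(name="demo")
    session.add(app)
    session.flush()   # SQL is emitted now: app.id is populated for downstream code,
    print(app.id)     # but nothing is durable yet; commit/rollback stays one level up.
# the .begin() block commits here, once, for the whole unit of work
```

Flushing instead of committing keeps the import atomic: if a later step such as saving dependencies raises, the whole import rolls back rather than leaving a half-created app behind.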
import uuid from collections.abc import Callable, Generator, Mapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator @@ -18,12 +18,13 @@ from core.app.features.rate_limiting import RateLimit from core.app.features.rate_limiting.rate_limit import rate_limit_context from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.db import session_factory -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from extensions.otel import AppGenerateHandler, trace_span from models.model import Account, App, AppMode, EndUser from models.workflow import Workflow, WorkflowRun from services.errors.app import QuotaExceededError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError +from services.quota_service import QuotaService, unlimited from services.workflow_service import WorkflowService from tasks.app_generate.workflow_execute_task import AppExecutionParams, workflow_based_app_execution_task @@ -88,7 +89,7 @@ class AppGenerateService: def generate( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, @@ -106,7 +107,7 @@ class AppGenerateService: quota_charge = unlimited() if dify_config.BILLING_ENABLED: try: - quota_charge = QuotaType.WORKFLOW.consume(app_model.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, app_model.tenant_id) except QuotaExceededError: raise InvokeRateLimitError(f"Workflow execution quota limit reached for tenant {app_model.tenant_id}") @@ -116,6 +117,7 @@ class AppGenerateService: request_id = RateLimit.gen_request_key() try: request_id = rate_limit.enter(request_id) + quota_charge.commit() effective_mode = ( AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode ) @@ -162,6 +164,7 @@ class AppGenerateService: invoke_from=invoke_from, streaming=True, call_depth=0, + workflow_run_id=str(uuid.uuid4()), ) payload_json = payload.model_dump_json() @@ -183,6 +186,10 @@ class AppGenerateService: else: # Blocking mode: run synchronously and return JSON instead of SSE # Keep behaviour consistent with WORKFLOW blocking branch. 
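The call sites above move from a fire-and-forget `QuotaType.WORKFLOW.consume(...)` to a reserve/commit pair: `QuotaService.reserve` freezes capacity up front (raising `QuotaExceededError` when none is left), and `quota_charge.commit()` settles the charge only once the request has actually been admitted past the rate limiter. A rough sketch of the contract this implies — the class below is hypothetical; the real implementation lives in `services.quota_service`:

```python
class QuotaCharge:
    """Hypothetical handle returned by a successful reservation."""

    def __init__(self, reservation_id: str | None) -> None:
        # None models the `unlimited()` no-op charge used when billing is off.
        self.reservation_id = reservation_id

    def commit(self) -> None:
        """Settle the reservation: the frozen amount is consumed for real."""
        if self.reservation_id is not None:
            ...  # e.g. a BillingService.quota_commit(...) round trip

    def refund(self) -> None:
        """Cancel the reservation: return the frozen amount to the pool."""
        if self.reservation_id is not None:
            ...  # e.g. a BillingService.quota_release(...) round trip


def admit(tenant_id: str, reserve, dispatch) -> None:
    charge = reserve(tenant_id)  # may raise QuotaExceededError
    try:
        dispatch()               # enqueue or start the workflow
        charge.commit()          # pay only after admission succeeded
    except Exception:
        charge.refund()          # nothing ran; give the quota back
        raise
```

The same pattern recurs in `AsyncWorkflowService` below, where the commit happens only after the Celery dispatch succeeds and a failed dispatch triggers `refund()`.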
+ pause_config = PauseStateLayerConfig( + session_factory=session_factory.get_session_maker(), + state_owner_user_id=workflow.created_by, + ) advanced_generator = AdvancedChatAppGenerator() return rate_limit.generate( advanced_generator.convert_to_event_stream( @@ -194,6 +201,7 @@ class AppGenerateService: invoke_from=invoke_from, workflow_run_id=str(uuid.uuid4()), streaming=False, + pause_state_config=pause_config, ) ), request_id=request_id, @@ -356,11 +364,11 @@ class AppGenerateService: def generate_more_like_this( cls, app_model: App, - user: Union[Account, EndUser], + user: Account | EndUser, message_id: str, invoke_from: InvokeFrom, streaming: bool = True, - ) -> Union[Mapping, Generator]: + ) -> Mapping | Generator: """ Generate more like this :param app_model: app model diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 2013c869af..8252de7753 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,3 +1,5 @@ +from typing import Any + from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager @@ -6,7 +8,7 @@ from models.model import AppMode, AppModelConfigDict class AppModelConfigService: @classmethod - def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict: + def validate_configuration(cls, tenant_id: str, config: dict[str, Any], app_mode: AppMode) -> AppModelConfigDict: match app_mode: case AppMode.CHAT: return ChatAppConfigManager.config_validate(tenant_id, config) diff --git a/api/services/app_service.py b/api/services/app_service.py index 87d52a3159..a046b909b3 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -4,8 +4,6 @@ from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination -from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from configs import dify_config @@ -17,6 +15,8 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) class AppService: - def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Pagination | None: + def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict[str, Any]) -> Pagination | None: """ Get app list with pagination :param user_id: user id @@ -78,7 +78,7 @@ class AppService: return app_models - def create_app(self, tenant_id: str, args: dict, account: Account) -> App: + def create_app(self, tenant_id: str, args: dict[str, Any], account: Account) -> App: """ Create app :param tenant_id: tenant id @@ -303,17 +303,22 @@ class AppService: return app - def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: + 
def update_app_icon( + self, app: App, icon: str, icon_background: str, icon_type: IconType | str | None = None + ) -> App: """ Update app icon :param app: App instance :param icon: new icon :param icon_background: new icon_background + :param icon_type: new icon type :return: App instance """ assert current_user is not None app.icon = icon app.icon_background = icon_background + if icon_type is not None: + app.icon_type = icon_type if isinstance(icon_type, IconType) else IconType(icon_type) app.updated_by = current_user.id app.updated_at = naive_utc_now() db.session.commit() @@ -389,7 +394,7 @@ class AppService: """ app_mode = AppMode.value_of(app_model.mode) - meta: dict = {"tool_icons": {}} + meta: dict[str, Any] = {"tool_icons": {}} if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 0842e9d3e7..6e9d6b1c73 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,11 +5,10 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. """ -from graphon.graph_engine.manager import GraphEngineManager - from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_redis import redis_client +from graphon.graph_engine.manager import GraphEngineManager from models.model import AppMode diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 55ae1e03b1..ceda30e950 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -7,7 +7,7 @@ with support for different subscription tiers, rate limiting, and execution trac import json from datetime import UTC, datetime -from typing import Any, Union +from typing import Any from celery.result import AsyncResult from sqlalchemy import select @@ -22,6 +22,7 @@ from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.quota_service import QuotaService, unlimited from services.workflow.entities import AsyncTriggerResponse, TriggerData, WorkflowTaskData from services.workflow.queue_dispatcher import QueueDispatcherManager, QueuePriority from services.workflow_service import WorkflowService @@ -50,7 +51,7 @@ class AsyncWorkflowService: @classmethod def trigger_workflow_async( - cls, session: Session, user: Union[Account, EndUser], trigger_data: TriggerData + cls, session: Session, user: Account | EndUser, trigger_data: TriggerData ) -> AsyncTriggerResponse: """ Universal entry point for async workflow execution - THIS METHOD WILL NOT BLOCK @@ -88,7 +89,10 @@ class AsyncWorkflowService: raise WorkflowNotFoundError(f"App not found: {trigger_data.app_id}") # 2. Get workflow - workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id) + workflow = cls._get_workflow(workflow_service, app_model, trigger_data.workflow_id, session=session) + + # commit the read-only session before starting the billing RPC call + session.commit() + +
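The `session.commit()` added just above is about connection hygiene: the steps that follow (subscription lookup, quota reservation) go over the network to the billing service, and holding an open database transaction across remote I/O pins a pooled connection for the whole round trip. A minimal sketch of the ordering, with hypothetical names and endpoint:

```python
import httpx
from sqlalchemy import text
from sqlalchemy.orm import Session


def reserve_quota(session: Session, tenant_id: str) -> None:
    # Read whatever the dispatch decision needs while the transaction is open.
    session.execute(text("SELECT 1")).scalar_one()  # stand-in for the read-only lookups

    # End the transaction *before* remote I/O so the pooled connection is
    # released instead of idling (and holding snapshots) during the HTTP call.
    session.commit()

    # The billing round trip can now be slow without pinning the database.
    httpx.post("https://billing.internal/quota/reserve", json={"tenant_id": tenant_id})
```

# 3.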
Get dispatcher based on tenant subscription dispatcher = dispatcher_manager.get_dispatcher(trigger_data.tenant_id) @@ -131,9 +135,10 @@ class AsyncWorkflowService: trigger_log = trigger_log_repo.create(trigger_log) session.commit() - # 7. Check and consume quota + # 7. Reserve quota (commit after successful dispatch) + quota_charge = unlimited() try: - QuotaType.WORKFLOW.consume(trigger_data.tenant_id) + quota_charge = QuotaService.reserve(QuotaType.WORKFLOW, trigger_data.tenant_id) except QuotaExceededError as e: # Update trigger log status trigger_log.status = WorkflowTriggerStatus.RATE_LIMITED @@ -153,13 +158,18 @@ class AsyncWorkflowService: # 9. Dispatch to appropriate queue task_data_dict = task_data.model_dump(mode="json") - task: AsyncResult[Any] | None = None - if queue_name == QueuePriority.PROFESSIONAL: - task = execute_workflow_professional.delay(task_data_dict) - elif queue_name == QueuePriority.TEAM: - task = execute_workflow_team.delay(task_data_dict) - else: # SANDBOX - task = execute_workflow_sandbox.delay(task_data_dict) + try: + task: AsyncResult[Any] | None = None + if queue_name == QueuePriority.PROFESSIONAL: + task = execute_workflow_professional.delay(task_data_dict) + elif queue_name == QueuePriority.TEAM: + task = execute_workflow_team.delay(task_data_dict) + else: # SANDBOX + task = execute_workflow_sandbox.delay(task_data_dict) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise # 10. Update trigger log with task info trigger_log.status = WorkflowTriggerStatus.QUEUED @@ -177,7 +187,7 @@ class AsyncWorkflowService: @classmethod def reinvoke_trigger( - cls, session: Session, user: Union[Account, EndUser], workflow_trigger_log_id: str + cls, session: Session, user: Account | EndUser, workflow_trigger_log_id: str ) -> AsyncTriggerResponse: """ Re-invoke a previously failed or rate-limited trigger - THIS METHOD WILL NOT BLOCK @@ -295,13 +305,21 @@ class AsyncWorkflowService: return [log.to_dict() for log in logs] @staticmethod - def _get_workflow(workflow_service: WorkflowService, app_model: App, workflow_id: str | None = None) -> Workflow: + def _get_workflow( + workflow_service: WorkflowService, + app_model: App, + workflow_id: str | None = None, + session: Session | None = None, + ) -> Workflow: """ Get workflow for the app Args: app_model: App model instance workflow_id: Optional specific workflow ID + session: Reuse this SQLAlchemy session for the lookup when provided, + so the caller's explicit session bears the connection cost + instead of Flask's request-scoped ``db.session``. 
Returns: Workflow instance @@ -311,12 +329,12 @@ class AsyncWorkflowService: """ if workflow_id: # Get specific published workflow - workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id) + workflow = workflow_service.get_published_workflow_by_id(app_model, workflow_id, session=session) if not workflow: raise WorkflowNotFoundError(f"Published workflow not found: {workflow_id}") else: # Get default published workflow - workflow = workflow_service.get_published_workflow(app_model) + workflow = workflow_service.get_published_workflow(app_model, session=session) if not workflow: raise WorkflowNotFoundError(f"No published workflow found for app: {app_model.id}") diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1c7027efb4..60948e652b 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,12 +5,12 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context -from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message from services.errors.audio import ( diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 3282dcfb11..36b1517056 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -1,4 +1,5 @@ import json +from typing import Any from sqlalchemy import select @@ -19,7 +20,7 @@ class ApiKeyAuthService: return data_source_api_key_bindings @staticmethod - def create_provider_auth(tenant_id: str, args: dict): + def create_provider_auth(tenant_id: str, args: dict[str, Any]): auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 735b22aa4c..c0e23cdc6f 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -2,7 +2,7 @@ import json import logging import os from collections.abc import Sequence -from typing import Literal, NotRequired, TypedDict +from typing import Any, Literal, NotRequired, TypedDict import httpx from pydantic import TypeAdapter @@ -32,6 +32,50 @@ class SubscriptionPlan(TypedDict): expiration_date: int +class QuotaReserveResult(TypedDict): + reservation_id: str + available: int + reserved: int + + +class QuotaCommitResult(TypedDict): + available: int + reserved: int + refunded: int + + +class QuotaReleaseResult(TypedDict): + available: int + reserved: int + released: int + + +_quota_reserve_adapter = TypeAdapter(QuotaReserveResult) +_quota_commit_adapter = TypeAdapter(QuotaCommitResult) +_quota_release_adapter = TypeAdapter(QuotaReleaseResult) + + +class _TenantFeatureQuota(TypedDict): + usage: int + limit: int + reset_date: NotRequired[int] + + +class TenantFeatureQuotaInfo(TypedDict): + """Response of /quota/info. + + NOTE (hj24): + - Same convention as BillingInfo: billing may return int fields as str, + always keep non-strict mode to auto-coerce. 
+ """ + + trigger_event: _TenantFeatureQuota + api_rate_limit: _TenantFeatureQuota + + +_tenant_feature_quota_info_adapter = TypeAdapter(TenantFeatureQuotaInfo) + + class _BillingQuota(TypedDict): size: int limit: int @@ -149,11 +193,63 @@ class BillingService: @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): + """Deprecated: Use get_quota_info instead.""" params = {"tenant_id": tenant_id} - usage_info = cls._send_request("GET", "/tenant-feature-usage/info", params=params) return usage_info + @classmethod + def get_quota_info(cls, tenant_id: str) -> TenantFeatureQuotaInfo: + params = {"tenant_id": tenant_id} + return _tenant_feature_quota_info_adapter.validate_python( + cls._send_request("GET", "/quota/info", params=params) + ) + + @classmethod + def quota_reserve( + cls, tenant_id: str, feature_key: str, request_id: str, amount: int = 1, meta: dict | None = None + ) -> QuotaReserveResult: + """Reserve quota before task execution.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "request_id": request_id, + "amount": amount, + } + if meta: + payload["meta"] = meta + return _quota_reserve_adapter.validate_python(cls._send_request("POST", "/quota/reserve", json=payload)) + + @classmethod + def quota_commit( + cls, tenant_id: str, feature_key: str, reservation_id: str, actual_amount: int, meta: dict | None = None + ) -> QuotaCommitResult: + """Commit a reservation with actual consumption.""" + payload: dict = { + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + "actual_amount": actual_amount, + } + if meta: + payload["meta"] = meta + return _quota_commit_adapter.validate_python(cls._send_request("POST", "/quota/commit", json=payload)) + + @classmethod + def quota_release(cls, tenant_id: str, feature_key: str, reservation_id: str) -> QuotaReleaseResult: + """Release a reservation (cancel, return frozen quota).""" + return _quota_release_adapter.validate_python( + cls._send_request( + "POST", + "/quota/release", + json={ + "tenant_id": tenant_id, + "feature_key": feature_key, + "reservation_id": reservation_id, + }, + ) + ) + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str) -> KnowledgeRateLimitDict: params = {"tenant_id": tenant_id} @@ -541,7 +637,7 @@ class BillingService: start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional. Returns {"notification_id": str}. 
""" - payload: dict = { + payload: dict[str, Any] = { "contents": contents, "frequency": frequency, "status": status, diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index b0f7efaccd..dcc93b4b0f 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,14 +6,14 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from graphon.model_runtime.utils.encoders import jsonable_encoder -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.account import Tenant from models.model import ( App, @@ -62,13 +62,11 @@ class ClearFreePlanTenantExpiredLogs: for model, table_name in related_tables: # Query records related to expired messages - records = ( - session.query(model) - .where( + records = session.scalars( + select(model).where( model.message_id.in_(batch_message_ids), # type: ignore ) - .all() - ) + ).all() if len(records) == 0: continue @@ -103,9 +101,13 @@ class ClearFreePlanTenantExpiredLogs: except Exception: logger.exception("Failed to save %s records", table_name) - session.query(model).where( - model.id.in_(record_ids), # type: ignore - ).delete(synchronize_session=False) + session.execute( + delete(model) + .where( + model.id.in_(record_ids), # type: ignore + ) + .execution_options(synchronize_session=False) + ) click.echo( click.style( @@ -121,15 +123,14 @@ class ClearFreePlanTenantExpiredLogs: app_ids = [app.id for app in apps] while True: with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - messages = ( - session.query(Message) + messages = session.scalars( + select(Message) .where( Message.app_id.in_(app_ids), Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(messages) == 0: break @@ -147,9 +148,9 @@ class ClearFreePlanTenantExpiredLogs: message_ids = [message.id for message in messages] # delete messages - session.query(Message).where( - Message.id.in_(message_ids), - ).delete(synchronize_session=False) + session.execute( + delete(Message).where(Message.id.in_(message_ids)).execution_options(synchronize_session=False) + ) cls._clear_message_related_tables(session, tenant_id, message_ids) @@ -161,15 +162,14 @@ class ClearFreePlanTenantExpiredLogs: while True: with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - conversations = ( - session.query(Conversation) + conversations = session.scalars( + select(Conversation) .where( Conversation.app_id.in_(app_ids), Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(conversations) == 0: break @@ -186,9 +186,11 @@ class ClearFreePlanTenantExpiredLogs: ) conversation_ids = [conversation.id for conversation in conversations] - session.query(Conversation).where( - Conversation.id.in_(conversation_ids), - ).delete(synchronize_session=False) + session.execute( + delete(Conversation) + .where(Conversation.id.in_(conversation_ids)) + .execution_options(synchronize_session=False) + ) click.echo( click.style( @@ -293,15 +295,14 @@ class ClearFreePlanTenantExpiredLogs: while True: 
with sessionmaker(bind=db.engine, autoflush=False).begin() as session: - workflow_app_logs = ( - session.query(WorkflowAppLog) + workflow_app_logs = session.scalars( + select(WorkflowAppLog) .where( WorkflowAppLog.tenant_id == tenant_id, WorkflowAppLog.created_at < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) - .all() - ) + ).all() if len(workflow_app_logs) == 0: break @@ -321,8 +322,10 @@ class ClearFreePlanTenantExpiredLogs: workflow_app_log_ids = [workflow_app_log.id for workflow_app_log in workflow_app_logs] # delete workflow app logs - session.query(WorkflowAppLog).where(WorkflowAppLog.id.in_(workflow_app_log_ids)).delete( - synchronize_session=False + session.execute( + delete(WorkflowAppLog) + .where(WorkflowAppLog.id.in_(workflow_app_log_ids)) + .execution_options(synchronize_session=False) ) click.echo( @@ -344,7 +347,7 @@ class ClearFreePlanTenantExpiredLogs: current_time = started_at with sessionmaker(db.engine).begin() as session: - total_tenant_count = session.query(Tenant.id).count() + total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0 click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) @@ -409,9 +412,12 @@ class ClearFreePlanTenantExpiredLogs: tenant_count = 0 for test_interval in test_intervals: tenant_count = ( - session.query(Tenant.id) - .where(Tenant.created_at.between(current_time, current_time + test_interval)) - .count() + session.scalar( + select(func.count(Tenant.id)).where( + Tenant.created_at.between(current_time, current_time + test_interval) + ) + ) + or 0 ) if tenant_count <= 100: interval = test_interval @@ -433,8 +439,8 @@ class ClearFreePlanTenantExpiredLogs: batch_end = min(current_time + interval, ended_at) - rs = ( - session.query(Tenant.id) + rs = session.execute( + select(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index f5085af59b..ee8a1c4edd 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,7 +3,6 @@ import logging from collections.abc import Callable, Sequence from typing import Any -from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -13,6 +12,7 @@ from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db from factories import variable_factory +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index 95a8951951..287d513f48 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ -from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from graphon.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 16788300d3..2d210db121 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -29,14 +29,15 @@ class CreditPoolService: @classmethod def 
get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None: """get tenant credit pool""" - return db.session.scalar( - select(TenantCreditPool) - .where( - TenantCreditPool.tenant_id == tenant_id, - TenantCreditPool.pool_type == pool_type, + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + return session.scalar( + select(TenantCreditPool) + .where( + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.pool_type == pool_type, + ) + .limit(1) ) - .limit(1) - ) @classmethod def check_credits_available( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3e952059ac..eef38f1ce2 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,9 +10,6 @@ from collections.abc import Sequence from typing import Any, Literal, TypedDict, cast import sqlalchemy as sa -from graphon.file import helpers as file_helpers -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError from sqlalchemy import delete, exists, func, select, update from sqlalchemy.orm import Session, sessionmaker @@ -31,6 +28,9 @@ from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel from libs import helper from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -233,7 +233,7 @@ class DatasetService: embedding_model_provider: str | None = None, embedding_model_name: str | None = None, retrieval_model: RetrievalModel | None = None, - summary_index_setting: dict | None = None, + summary_index_setting: dict[str, Any] | None = None, ): # check if dataset name already exists if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)): @@ -528,6 +528,8 @@ class DatasetService: raise ValueError("External knowledge id is required.") if not external_knowledge_api_id: raise ValueError("External knowledge api id is required.") + # Ensure the referenced external API template exists and belongs to the dataset tenant. 
+ ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id, dataset.tenant_id) # Update metadata fields dataset.updated_by = user.id if user else None dataset.updated_at = naive_utc_now() @@ -552,8 +554,8 @@ class DatasetService: external_knowledge_api_id: External knowledge API identifier """ with sessionmaker(db.engine).begin() as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + external_knowledge_binding = session.scalar( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1) ) if not external_knowledge_binding: @@ -1454,15 +1456,17 @@ class DocumentService: document_id_list: list[str] = [str(document_id) for document_id in document_ids] with session_factory.create_session() as session: - updated_count = ( - session.query(Document) - .filter( + result = session.execute( + update(Document) + .where( Document.id.in_(document_id_list), Document.dataset_id == dataset_id, Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) - .update({Document.need_summary: need_summary}, synchronize_session=False) + .values(need_summary=need_summary) + .execution_options(synchronize_session=False) ) + updated_count = result.rowcount # type: ignore[union-attr,attr-defined] session.commit() logger.info( "Updated need_summary to %s for %d documents in dataset %s", @@ -2489,7 +2493,7 @@ class DocumentService: data_source_type: str, document_form: str, document_language: str, - data_source_info: dict, + data_source_info: dict[str, Any], created_from: str, position: int, account: Account, @@ -2822,6 +2826,10 @@ class DocumentService: knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) + if knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL: + if not knowledge_config.process_rule.rules.parent_mode: + knowledge_config.process_rule.rules.parent_mode = "paragraph" + if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") @@ -2842,7 +2850,7 @@ class DocumentService: raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod - def estimate_args_validate(cls, args: dict): + def estimate_args_validate(cls, args: dict[str, Any]): if "info_list" not in args or not args["info_list"]: raise ValueError("Data source info is required") @@ -3124,7 +3132,7 @@ class DocumentService: class SegmentService: @classmethod - def segment_create_args_validate(cls, args: dict, document: Document): + def segment_create_args_validate(cls, args: dict[str, Any], document: Document): if document.doc_form == IndexStructureType.QA_INDEX: if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") @@ -3141,7 +3149,7 @@ class SegmentService: raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}") @classmethod - def create_segment(cls, args: dict, document: Document, dataset: Dataset): + def create_segment(cls, args: dict[str, Any], document: Document, dataset: Dataset): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -3740,6 +3748,7 @@ class SegmentService: ChildChunk.segment_id == segment.id, ) ) + assert current_user.current_tenant_id child_chunk = ChildChunk( tenant_id=current_user.current_tenant_id, dataset_id=dataset.id, @@ -3750,7 +3759,7 @@ class SegmentService: index_node_hash=index_node_hash, content=content, 
word_count=len(content), - type="customized", + type=SegmentType.CUSTOMIZED, created_by=current_user.id, ) db.session.add(child_chunk) @@ -3810,6 +3819,7 @@ class SegmentService: if new_child_chunks_args: child_chunk_count = len(child_chunks) for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): + assert current_user.current_tenant_id index_node_id = str(uuid.uuid4()) index_node_hash = helper.generate_text_hash(args.content) child_chunk = ChildChunk( @@ -3822,7 +3832,7 @@ class SegmentService: index_node_hash=index_node_hash, content=args.content, word_count=len(args.content), - type="customized", + type=SegmentType.CUSTOMIZED, created_by=current_user.id, ) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index d5f8cd30bd..416bc8cef9 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,8 +3,7 @@ import time from collections.abc import Mapping from typing import Any -from graphon.model_runtime.entities.provider_entities import FormType -from sqlalchemy import func, select +from sqlalchemy import delete, func, select, update from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -18,6 +17,7 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.provider_entities import FormType from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService @@ -54,11 +54,13 @@ class DatasourceProviderService: remove oauth custom client params """ with sessionmaker(bind=db.engine).begin() as session: - session.query(DatasourceOauthTenantParamConfig).filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, - ).delete() + session.execute( + delete(DatasourceOauthTenantParamConfig).where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, + ) + ) def decrypt_datasource_provider_credentials( self, @@ -110,15 +112,21 @@ class DatasourceProviderService: """ with sessionmaker(bind=db.engine).begin() as session: if credential_id: - datasource_provider = ( - session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first() + datasource_provider = session.scalar( + select(DatasourceProvider) + .where(DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.id == credential_id) + .limit(1) ) else: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + datasource_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .first() + .limit(1) ) if not datasource_provider: return {} @@ -173,12 +181,15 @@ class DatasourceProviderService: get all datasource credentials by provider """ with 
sessionmaker(bind=db.engine).begin() as session: - datasource_providers = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + datasource_providers = session.scalars( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc()) - .all() - ) + ).all() if not datasource_providers: return [] current_user = get_current_user() @@ -232,15 +243,15 @@ class DatasourceProviderService: update datasource provider name """ with sessionmaker(bind=db.engine).begin() as session: - target_provider = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - id=credential_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + target_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == credential_id, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if target_provider is None: raise ValueError("provider not found") @@ -250,16 +261,16 @@ class DatasourceProviderService: # check name is exist if ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=name, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == name, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, + ) ) - .count() - > 0 - ): + or 0 + ) > 0: raise ValueError("Authorization name is already exists") target_provider.name = name @@ -273,26 +284,31 @@ class DatasourceProviderService: """ with sessionmaker(bind=db.engine).begin() as session: # get provider - target_provider = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - id=credential_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + target_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == credential_id, + DatasourceProvider.provider == datasource_provider_id.provider_name, + DatasourceProvider.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if target_provider is None: raise ValueError("provider not found") # clear default provider - session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=target_provider.provider, - plugin_id=target_provider.plugin_id, - is_default=True, - ).update({"is_default": False}) + session.execute( + update(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == target_provider.provider, + DatasourceProvider.plugin_id == target_provider.plugin_id, + DatasourceProvider.is_default.is_(True), + ) + .values(is_default=False) + .execution_options(synchronize_session=False) + ) # set new default provider target_provider.is_default = True @@ -302,7 +318,7 @@ class DatasourceProviderService: self, tenant_id: str, datasource_provider_id: DatasourceProviderID, - client_params: dict | None, + client_params: 
dict[str, Any] | None, enabled: bool | None, ): """ @@ -311,14 +327,14 @@ class DatasourceProviderService: if client_params is None and enabled is None: return with sessionmaker(bind=db.engine).begin() as session: - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=datasource_provider_id.provider_name, - plugin_id=datasource_provider_id.plugin_id, + tenant_oauth_client_params = session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthTenantParamConfig.plugin_id == datasource_provider_id.plugin_id, ) - .first() + .limit(1) ) if not tenant_oauth_client_params: @@ -336,7 +352,7 @@ class DatasourceProviderService: original_params = ( encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {} ) - new_params: dict = { + new_params: dict[str, Any] = { key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE) for key, value in client_params.items() } @@ -351,9 +367,14 @@ class DatasourceProviderService: """ with Session(db.engine).no_autoflush as session: return ( - session.query(DatasourceOauthParamConfig) - .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id) - .first() + session.scalar( + select(DatasourceOauthParamConfig) + .where( + DatasourceOauthParamConfig.provider == datasource_provider_id.provider_name, + DatasourceOauthParamConfig.plugin_id == datasource_provider_id.plugin_id, + ) + .limit(1) + ) is not None ) @@ -423,15 +444,15 @@ class DatasourceProviderService: plugin_id = datasource_provider_id.plugin_id with Session(db.engine).no_autoflush as session: # get tenant oauth client params - tenant_oauth_client_params = ( - session.query(DatasourceOauthTenantParamConfig) - .filter_by( - tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - enabled=True, + tenant_oauth_client_params = session.scalar( + select(DatasourceOauthTenantParamConfig) + .where( + DatasourceOauthTenantParamConfig.tenant_id == tenant_id, + DatasourceOauthTenantParamConfig.provider == provider, + DatasourceOauthTenantParamConfig.plugin_id == plugin_id, + DatasourceOauthTenantParamConfig.enabled.is_(True), ) - .first() + .limit(1) ) if tenant_oauth_client_params: encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id) @@ -443,8 +464,13 @@ class DatasourceProviderService: is_verified = PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier) if is_verified: # fallback to system oauth client params - oauth_client_params = ( - session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() + oauth_client_params = session.scalar( + select(DatasourceOauthParamConfig) + .where( + DatasourceOauthParamConfig.provider == provider, + DatasourceOauthParamConfig.plugin_id == plugin_id, + ) + .limit(1) ) if oauth_client_params: return oauth_client_params.system_credentials @@ -455,15 +481,13 @@ class DatasourceProviderService: def generate_next_datasource_provider_name( session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType ) -> str: - db_providers = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, + db_providers = session.scalars( + 
select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, ) - .all() - ) + ).all() return generate_incremental_name( [provider.name for provider in db_providers], f"{credential_type.get_name()}", @@ -476,7 +500,7 @@ class DatasourceProviderService: provider_id: DatasourceProviderID, avatar_url: str | None, expire_at: int, - credentials: dict, + credentials: dict[str, Any], credential_id: str, ) -> None: """ @@ -485,8 +509,10 @@ class DatasourceProviderService: with sessionmaker(bind=db.engine).begin() as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.OAUTH2.value}" with redis_client.lock(lock, timeout=20): - target_provider = ( - session.query(DatasourceProvider).filter_by(id=credential_id, tenant_id=tenant_id).first() + target_provider = session.scalar( + select(DatasourceProvider) + .where(DatasourceProvider.id == credential_id, DatasourceProvider.tenant_id == tenant_id) + .limit(1) ) if target_provider is None: raise ValueError("provider not found") @@ -496,25 +522,28 @@ class DatasourceProviderService: db_provider_name = target_provider.name else: name_conflict = ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=CredentialType.OAUTH2.value, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == db_provider_name, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + DatasourceProvider.auth_type == CredentialType.OAUTH2.value, + ) ) - .count() + or 0 ) if name_conflict > 0: db_provider_name = generate_incremental_name( [ provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) + for provider in session.scalars( + select(DatasourceProvider).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + ) + ).all() ], db_provider_name, ) @@ -537,7 +566,7 @@ class DatasourceProviderService: provider_id: DatasourceProviderID, avatar_url: str | None, expire_at: int, - credentials: dict, + credentials: dict[str, Any], ) -> None: """ add datasource oauth provider @@ -556,25 +585,27 @@ class DatasourceProviderService: ) else: if ( - session.query(DatasourceProvider) - .filter_by( - tenant_id=tenant_id, - name=db_provider_name, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - auth_type=credential_type.value, + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == db_provider_name, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + DatasourceProvider.auth_type == credential_type.value, + ) ) - .count() - > 0 - ): + or 0 + ) > 0: db_provider_name = generate_incremental_name( [ provider.name - for provider in session.query(DatasourceProvider).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ) + for provider in session.scalars( + select(DatasourceProvider).where( + 
DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.provider == provider_id.provider_name, + DatasourceProvider.plugin_id == provider_id.plugin_id, + ) + ).all() ], db_provider_name, ) @@ -603,7 +634,7 @@ class DatasourceProviderService: name: str | None, tenant_id: str, provider_id: DatasourceProviderID, - credentials: dict, + credentials: dict[str, Any], ) -> None: """ validate datasource provider credentials. @@ -627,11 +658,16 @@ class DatasourceProviderService: # check name is exist if ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name) - .count() - > 0 - ): + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.plugin_id == plugin_id, + DatasourceProvider.provider == provider_name, + DatasourceProvider.name == db_provider_name, + ) + ) + or 0 + ) > 0: raise ValueError("Authorization name is already exists") try: @@ -911,28 +947,44 @@ class DatasourceProviderService: return copy_credentials_list def update_datasource_credentials( - self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None + self, + tenant_id: str, + auth_id: str, + provider: str, + plugin_id: str, + credentials: dict[str, Any] | None, + name: str | None, ) -> None: """ update datasource credentials. """ with sessionmaker(bind=db.engine).begin() as session: - datasource_provider = ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) - .first() + datasource_provider = session.scalar( + select(DatasourceProvider) + .where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.id == auth_id, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + .limit(1) ) if not datasource_provider: raise ValueError("Datasource provider not found") # update name if name and name != datasource_provider.name: if ( - session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id) - .count() - > 0 - ): + session.scalar( + select(func.count(DatasourceProvider.id)).where( + DatasourceProvider.tenant_id == tenant_id, + DatasourceProvider.name == name, + DatasourceProvider.provider == provider, + DatasourceProvider.plugin_id == plugin_id, + ) + ) + or 0 + ) > 0: raise ValueError("Authorization name is already exists") datasource_provider.name = name diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 5040fcc7e3..bd7758f1c0 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime from typing import TYPE_CHECKING +from cachetools.func import ttl_cache from pydantic import BaseModel, ConfigDict, Field, model_validator from configs import dify_config @@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None: class EnterpriseService: @classmethod + @ttl_cache(ttl=5) def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") diff --git a/api/services/entities/auth_entities.py b/api/services/entities/auth_entities.py index 6b720a4607..e3fb249692 100644 --- a/api/services/entities/auth_entities.py +++ b/api/services/entities/auth_entities.py @@ -1,9 +1,25 @@ +from enum import StrEnum, auto + from pydantic import BaseModel, Field, field_validator from 
libs.helper import EmailStr from libs.password import valid_password +class LoginFailureReason(StrEnum): + """Bounded reason codes for failed login audit logs.""" + + ACCOUNT_BANNED = auto() + ACCOUNT_IN_FREEZE = auto() + ACCOUNT_NOT_FOUND = auto() + EMAIL_CODE_EMAIL_MISMATCH = auto() + INVALID_CREDENTIALS = auto() + INVALID_EMAIL_CODE = auto() + INVALID_EMAIL_CODE_TOKEN = auto() + INVALID_INVITATION_EMAIL = auto() + LOGIN_RATE_LIMITED = auto() + + class LoginPayloadBase(BaseModel): email: EmailStr password: str diff --git a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py index c9fb1c9e21..110dbe5a5e 100644 --- a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py +++ b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Any, Literal, Union from pydantic import BaseModel @@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel): class ExternalKnowledgeApiSetting(BaseModel): url: str request_method: str - headers: dict | None = None - params: dict | None = None + headers: dict[str, Any] | None = None + params: dict[str, Any] | None = None diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index cb38104e8c..910f54bebc 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,12 +1,18 @@ -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, field_validator from core.rag.entities import Rule +from core.rag.entities.metadata_entities import MetadataFilteringCondition from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod +class RerankingModel(BaseModel): + reranking_provider_name: str | None = None + reranking_model_name: str | None = None + + class NotionIcon(BaseModel): type: str url: str | None = None @@ -53,11 +59,6 @@ class ProcessRule(BaseModel): rules: Rule | None = None -class RerankingModel(BaseModel): - reranking_provider_name: str | None = None - reranking_model_name: str | None = None - - class WeightVectorSetting(BaseModel): vector_weight: float embedding_provider_name: str @@ -83,11 +84,12 @@ class RetrievalModel(BaseModel): score_threshold_enabled: bool score_threshold: float | None = None weights: WeightModel | None = None + metadata_filtering_conditions: MetadataFilteringCondition | None = None class MetaDataConfig(BaseModel): doc_type: str - doc_metadata: dict + doc_metadata: dict[str, Any] class KnowledgeConfig(BaseModel): @@ -97,7 +99,7 @@ class KnowledgeConfig(BaseModel): data_source: DataSource | None = None process_rule: ProcessRule | None = None retrieval_model: RetrievalModel | None = None - summary_index_setting: dict | None = None + summary_index_setting: dict[str, Any] | None = None doc_form: str = "text_model" doc_language: str = "English" embedding_model: str | None = None diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index a360fd2854..7fb7ed12bf 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -1,4 +1,4 @@ -from typing 
import Literal +from typing import Any, Literal from pydantic import BaseModel, field_validator @@ -6,6 +6,24 @@ from core.rag.entities import KeywordSetting, VectorSetting from core.rag.retrieval.retrieval_methods import RetrievalMethod +class RerankingModelConfig(BaseModel): + """ + Reranking Model Config. + """ + + reranking_provider_name: str | None = "" + reranking_model_name: str | None = "" + + +class WeightedScoreConfig(BaseModel): + """ + Weighted score Config. + """ + + vector_setting: VectorSetting | None + keyword_setting: KeywordSetting | None + + class IconInfo(BaseModel): icon: str icon_background: str | None = None @@ -28,24 +46,6 @@ class RagPipelineDatasetCreateEntity(BaseModel): yaml_content: str | None = None -class RerankingModelConfig(BaseModel): - """ - Reranking Model Config. - """ - - reranking_provider_name: str | None = "" - reranking_model_name: str | None = "" - - -class WeightedScoreConfig(BaseModel): - """ - Weighted score Config. - """ - - vector_setting: VectorSetting | None - keyword_setting: KeywordSetting | None - - class RetrievalSetting(BaseModel): """ Retrieval Setting. @@ -73,7 +73,7 @@ class KnowledgeConfiguration(BaseModel): keyword_number: int | None = 10 retrieval_model: RetrievalSetting # add summary index setting - summary_index_setting: dict | None = None + summary_index_setting: dict[str, Any] | None = None @field_validator("embedding_model_provider", mode="before") @classmethod diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index a944ef6acd..6679c08ebd 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,15 +1,6 @@ from collections.abc import Sequence from enum import StrEnum -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -24,6 +15,15 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from models.provider import ProviderType diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 2bf1afba3e..60b457ecd0 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,16 +1,16 @@ import json from copy import deepcopy -from typing import Any, Union, cast +from typing import Any, cast from urllib.parse import urlparse import httpx -from graphon.nodes.http_request.exc import InvalidHttpMethodError from sqlalchemy import func, select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities import MetadataFilteringCondition from extensions.ext_database import db +from graphon.nodes.http_request.exc import InvalidHttpMethodError from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, @@ -47,7 +47,7 @@ class 
ExternalDatasetService: return external_knowledge_apis.items, external_knowledge_apis.total @classmethod - def validate_api_list(cls, api_settings: dict): + def validate_api_list(cls, api_settings: dict[str, Any]): if not api_settings: raise ValueError("api list is empty") if not api_settings.get("endpoint"): @@ -56,7 +56,7 @@ class ExternalDatasetService: raise ValueError("api_key is required") @staticmethod - def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: + def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict[str, Any]) -> ExternalKnowledgeApis: settings = args.get("settings") if settings is None: raise ValueError("settings is required") @@ -75,7 +75,7 @@ class ExternalDatasetService: return external_knowledge_api @staticmethod - def check_endpoint_and_api_key(settings: dict): + def check_endpoint_and_api_key(settings: dict[str, Any]): if "endpoint" not in settings or not settings["endpoint"]: raise ValueError("endpoint is required") if "api_key" not in settings or not settings["api_key"]: @@ -178,7 +178,9 @@ class ExternalDatasetService: return external_knowledge_binding @staticmethod - def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict): + def document_create_args_validate( + tenant_id: str, external_knowledge_api_id: str, process_parameter: dict[str, Any] + ): external_knowledge_api = db.session.scalar( select(ExternalKnowledgeApis) .where(ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == tenant_id) @@ -195,9 +197,7 @@ class ExternalDatasetService: raise ValueError(f"{parameter.get('name')} is required") @staticmethod - def process_external_api( - settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]] - ) -> httpx.Response: + def process_external_api(settings: ExternalKnowledgeApiSetting, files: dict[str, Any] | None) -> httpx.Response: """ do http request depending on api bundle """ @@ -224,7 +224,7 @@ class ExternalDatasetService: return response @staticmethod - def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]: + def assembling_headers(authorization: Authorization, headers: dict[str, Any] | None = None) -> dict[str, Any]: authorization = deepcopy(authorization) if headers: headers = deepcopy(headers) @@ -250,11 +250,11 @@ class ExternalDatasetService: return headers @staticmethod - def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting: + def get_external_knowledge_api_settings(settings: dict[str, Any]) -> ExternalKnowledgeApiSetting: return ExternalKnowledgeApiSetting.model_validate(settings) @staticmethod - def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: + def create_external_dataset(tenant_id: str, user_id: str, args: dict[str, Any]) -> Dataset: # check if dataset name already exists if db.session.scalar( select(Dataset).where(Dataset.name == args.get("name"), Dataset.tenant_id == tenant_id).limit(1) @@ -306,7 +306,7 @@ class ExternalDatasetService: tenant_id: str, dataset_id: str, query: str, - external_retrieval_parameters: dict, + external_retrieval_parameters: dict[str, Any], metadata_condition: MetadataFilteringCondition | None = None, ): external_knowledge_binding = db.session.scalar( @@ -319,7 +319,10 @@ class ExternalDatasetService: external_knowledge_api = db.session.scalar( select(ExternalKnowledgeApis) - .where(ExternalKnowledgeApis.id == 
external_knowledge_binding.external_knowledge_api_id) + .where( + ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id, + ExternalKnowledgeApis.tenant_id == tenant_id, + ) .limit(1) ) if external_knowledge_api is None or external_knowledge_api.settings is None: diff --git a/api/services/feature_service.py b/api/services/feature_service.py index df653e0ba7..9477c28bf3 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -3,6 +3,7 @@ from enum import StrEnum from pydantic import BaseModel, ConfigDict, Field from configs import dify_config +from constants.dsl_version import CURRENT_APP_DSL_VERSION from enums.cloud_plan import CloudPlan from enums.hosted_provider import HostedTrialProvider from services.billing_service import BillingService @@ -157,6 +158,7 @@ class PluginManagerModel(BaseModel): class SystemFeatureModel(BaseModel): + app_dsl_version: str = "" sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" enable_marketplace: bool = False @@ -164,6 +166,7 @@ class SystemFeatureModel(BaseModel): enable_email_code_login: bool = False enable_email_password_login: bool = True enable_social_oauth_login: bool = False + enable_collaboration_mode: bool = False is_allow_register: bool = False is_allow_create_workspace: bool = False is_email_setup: bool = False @@ -174,6 +177,7 @@ class SystemFeatureModel(BaseModel): enable_change_email: bool = True plugin_manager: PluginManagerModel = PluginManagerModel() trial_models: list[str] = [] + enable_creators_platform: bool = False enable_trial_app: bool = False enable_explore_banner: bool = False @@ -224,6 +228,7 @@ class FeatureService: @classmethod def get_system_features(cls, is_authenticated: bool = False) -> SystemFeatureModel: system_features = SystemFeatureModel() + system_features.app_dsl_version = CURRENT_APP_DSL_VERSION cls._fulfill_system_params_from_env(system_features) @@ -237,6 +242,9 @@ class FeatureService: if dify_config.MARKETPLACE_ENABLED: system_features.enable_marketplace = True + if dify_config.CREATORS_PLATFORM_FEATURES_ENABLED: + system_features.enable_creators_platform = True + return system_features @classmethod @@ -244,6 +252,7 @@ class FeatureService: system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN + system_features.enable_collaboration_mode = dify_config.ENABLE_COLLABORATION_MODE system_features.is_allow_register = dify_config.ALLOW_REGISTER system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" @@ -281,7 +290,7 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features_usage_info = BillingService.get_tenant_feature_plan_usage_info(tenant_id) + features_usage_info = BillingService.get_quota_info(tenant_id) features.billing.enabled = billing_info["enabled"] features.billing.subscription.plan = billing_info["subscription"]["plan"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 7443ca3271..f60afe2f19 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -2,13 +2,12 @@ import base64 import hashlib import os import uuid -from collections.abc import Iterator, 
Sequence
+from collections.abc import Generator, Sequence
 from contextlib import contextmanager, suppress
 from tempfile import NamedTemporaryFile
-from typing import Literal, Union
+from typing import Literal
 from zipfile import ZIP_DEFLATED, ZipFile
-from graphon.file import helpers as file_helpers
 from sqlalchemy import Engine, select
 from sqlalchemy.orm import Session, sessionmaker
 from werkzeug.exceptions import NotFound
@@ -24,6 +23,7 @@
 from core.rag.extractor.extract_processor import ExtractProcessor
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from extensions.storage.storage_type import StorageType
+from graphon.file import helpers as file_helpers
 from libs.datetime_utils import naive_utc_now
 from libs.helper import extract_tenant_id
 from models import Account
@@ -52,7 +52,7 @@ class FileService:
         filename: str,
         content: bytes,
         mimetype: str,
-        user: Union[Account, EndUser],
+        user: Account | EndUser,
         source: Literal["datasets"] | None = None,
         source_url: str = "",
     ) -> UploadFile:
@@ -324,7 +324,7 @@
     def build_upload_files_zip_tempfile(
         *,
         upload_files: Sequence[UploadFile],
-    ) -> Iterator[str]:
+    ) -> Generator[str, None, None]:
         """
         Build a ZIP from `UploadFile`s and yield a tempfile path.
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 7e0100212a..2e5987dd28 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -1,17 +1,16 @@
 import json
 import logging
 import time
-from typing import Any, TypedDict
-
-from graphon.model_runtime.entities import LLMMode
+from typing import Any, TypedDict, cast
 
 from core.app.app_config.entities import ModelConfig
-from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService
 from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
+from graphon.model_runtime.entities import LLMMode
 from models import Account
 from models.dataset import Dataset, DatasetQuery
 from models.enums import CreatorUserRole, DatasetQuerySource
@@ -37,6 +36,10 @@ default_retrieval_model = {
 }
 
 
+class HitTestingRetrievalModelDict(DefaultRetrievalModelDict, total=False):
+    metadata_filtering_conditions: dict[str, Any]
+
+
 class HitTestingService:
     @classmethod
     def retrieve(
@@ -44,25 +47,26 @@ class HitTestingService:
         dataset: Dataset,
         query: str,
         account: Account,
-        retrieval_model: dict | None,
-        external_retrieval_model: dict,
+        retrieval_model: dict[str, Any] | None,
+        external_retrieval_model: dict[str, Any],
         attachment_ids: list | None = None,
         limit: int = 10,
     ):
         start = time.perf_counter()
 
         # get retrieval model , if the model is not setting , using default
-        if not retrieval_model:
-            retrieval_model = dataset.retrieval_model or default_retrieval_model
-        assert isinstance(retrieval_model, dict)
+        resolved_retrieval_model = cast(
+            HitTestingRetrievalModelDict,
+            retrieval_model or dataset.retrieval_model or default_retrieval_model,
+        )
 
         document_ids_filter = None
-        metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {})
-        if metadata_filtering_conditions and query:
+        metadata_filtering_conditions_raw = 
resolved_retrieval_model.get("metadata_filtering_conditions", {}) + if metadata_filtering_conditions_raw and query: dataset_retrieval = DatasetRetrieval() from core.rag.entities import MetadataFilteringCondition - metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions) + metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions_raw) metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( dataset_ids=[dataset.id], @@ -79,19 +83,21 @@ class HitTestingService: if metadata_condition and not document_ids_filter: return cls.compact_retrieve_response(query, []) all_documents = RetrievalService.retrieve( - retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), + retrieval_method=RetrievalMethod( + resolved_retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH) + ), dataset_id=dataset.id, query=query, attachment_ids=attachment_ids, - top_k=retrieval_model.get("top_k", 4), - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] + top_k=resolved_retrieval_model.get("top_k", 4), + score_threshold=resolved_retrieval_model.get("score_threshold", 0.0) + if resolved_retrieval_model["score_threshold_enabled"] else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] + reranking_model=resolved_retrieval_model.get("reranking_model", None) + if resolved_retrieval_model["reranking_enable"] else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + reranking_mode=resolved_retrieval_model.get("reranking_mode") or "reranking_model", + weights=resolved_retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, ) @@ -125,8 +131,8 @@ class HitTestingService: dataset: Dataset, query: str, account: Account, - external_retrieval_model: dict | None = None, - metadata_filtering_conditions: dict | None = None, + external_retrieval_model: dict[str, Any] | None = None, + metadata_filtering_conditions: dict[str, Any] | None = None, ): if dataset.provider != "external": return { diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 77576fa4c0..8b4983e5f7 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,12 +4,11 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol -from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -18,6 +17,7 @@ from core.workflow.human_input_compat import ( ) from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import VariablePool from libs.email_template_renderer import render_email_template from models import Account, TenantAccountJoin from services.feature_service import FeatureService diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 02a6620fc7..76598d31ac 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,12 +3,6 @@ from 
collections.abc import Mapping from datetime import datetime, timedelta from typing import Any -from graphon.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -17,6 +11,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/message_service.py b/api/services/message_service.py index 5b133b0c04..8f5e028d4d 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -1,6 +1,7 @@ +import logging from collections.abc import Sequence +from typing import cast -from graphon.model_runtime.entities.model_entities import ModelType from pydantic import TypeAdapter from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -14,10 +15,20 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account from models.enums import FeedbackFromSource, FeedbackRating -from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, EndUser, Message, MessageFeedback +from models.model import ( + App, + AppMode, + AppModelConfig, + AppModelConfigDict, + EndUser, + Message, + MessageFeedback, + SuggestedQuestionsAfterAnswerConfig, +) from repositories.execution_extra_content_repository import ExecutionExtraContentRepository from repositories.sqlalchemy_execution_extra_content_repository import ( SQLAlchemyExecutionExtraContentRepository, @@ -32,6 +43,7 @@ from services.errors.message import ( from services.workflow_service import WorkflowService _app_model_config_adapter: TypeAdapter[AppModelConfigDict] = TypeAdapter(AppModelConfigDict) +logger = logging.getLogger(__name__) def _create_execution_extra_content_repository() -> ExecutionExtraContentRepository: @@ -252,6 +264,7 @@ class MessageService: ) model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id) + suggested_questions_after_answer_config: SuggestedQuestionsAfterAnswerConfig = {"enabled": False} if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() @@ -271,9 +284,11 @@ class MessageService: if not app_config.additional_features.suggested_questions_after_answer: raise SuggestedQuestionsAfterAnswerDisabledError() - model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, model_type=ModelType.LLM - ) + suggested_questions_after_answer = workflow.features_dict.get("suggested_questions_after_answer") + if isinstance(suggested_questions_after_answer, dict): + suggested_questions_after_answer_config = cast( + SuggestedQuestionsAfterAnswerConfig, suggested_questions_after_answer + ) else: if not 
conversation.override_model_configs: app_model_config = db.session.scalar( @@ -293,16 +308,14 @@ class MessageService: if not app_model_config: raise ValueError("did not find app model config") - suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict - if suggested_questions_after_answer.get("enabled", False) is False: + suggested_questions_after_answer_config = app_model_config.suggested_questions_after_answer_dict + if suggested_questions_after_answer_config.get("enabled", False) is False: raise SuggestedQuestionsAfterAnswerDisabledError() - model_instance = model_manager.get_model_instance( - tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict["provider"], - model_type=ModelType.LLM, - model=app_model_config.model_dict["name"], - ) + model_instance = model_manager.get_default_model_instance( + tenant_id=app_model.tenant_id, + model_type=ModelType.LLM, + ) # get memory of conversation (read-only) memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) @@ -312,9 +325,17 @@ class MessageService: message_limit=3, ) + instruction_prompt = suggested_questions_after_answer_config.get("prompt") + if not isinstance(instruction_prompt, str) or not instruction_prompt.strip(): + instruction_prompt = None + + configured_model = suggested_questions_after_answer_config.get("model") with measure_time() as timer: questions_sequence = LLMGenerator.generate_suggested_questions_after_answer( - tenant_id=app_model.tenant_id, histories=histories + tenant_id=app_model.tenant_id, + histories=histories, + instruction_prompt=instruction_prompt, + model_config=configured_model, ) questions: list[str] = list(questions_sequence) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 3cce83a975..c269346f5f 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -1,13 +1,7 @@ import json import logging -from typing import Any, TypedDict, Union +from typing import Any, TypedDict -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -18,6 +12,12 @@ from core.model_manager import LBModelManager from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential @@ -502,7 +502,7 @@ class ModelLoadBalancingService: provider: str, model: str, model_type: str, - credentials: dict, + credentials: dict[str, Any], config_id: str | None = None, ): """ @@ -561,7 +561,7 @@ class ModelLoadBalancingService: provider_configuration: ProviderConfiguration, model_type: ModelType, model: str, - credentials: dict, + credentials: dict[str, Any], 
load_balancing_model_config: LoadBalancingModelConfig | None = None, model_provider_factory: ModelProviderFactory | None = None, validate: bool = True, @@ -626,7 +626,7 @@ class ModelLoadBalancingService: def _get_credential_schema( self, provider_configuration: ProviderConfiguration - ) -> Union[ModelCredentialSchema, ProviderCredentialSchema]: + ) -> ModelCredentialSchema | ProviderCredentialSchema: """Get form schemas.""" if provider_configuration.provider.model_credential_schema: return provider_configuration.provider.model_credential_schema diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 3f37c9b176..51cda79661 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,10 +1,10 @@ import logging - -from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule +from typing import Any from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -168,7 +168,9 @@ class ModelProviderService: model_name=model, ) - def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: + def get_provider_credential( + self, tenant_id: str, provider: str, credential_id: str | None = None + ) -> dict[str, Any] | None: """ get provider credentials. @@ -180,7 +182,7 @@ class ModelProviderService: provider_configuration = self._get_provider_configuration(tenant_id, provider) return provider_configuration.get_provider_credential(credential_id=credential_id) - def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict): + def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict[str, Any]): """ validate provider credentials before saving. @@ -192,7 +194,7 @@ class ModelProviderService: provider_configuration.validate_provider_credentials(credentials) def create_provider_credential( - self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None + self, tenant_id: str, provider: str, credentials: dict[str, Any], credential_name: str | None ) -> None: """ Create and save new provider credentials. @@ -210,7 +212,7 @@ class ModelProviderService: self, tenant_id: str, provider: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ) -> None: @@ -254,7 +256,7 @@ class ModelProviderService: def get_model_credential( self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None - ) -> dict | None: + ) -> dict[str, Any] | None: """ Retrieve model-specific credentials. @@ -270,7 +272,9 @@ class ModelProviderService: model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id ) - def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict): + def validate_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict[str, Any] + ): """ validate model credentials. 
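
The hunks above and below form a validate-then-persist pair for model credentials. A minimal call-site sketch under assumed values (every literal below is a placeholder, not taken from this diff):

```python
from services.model_provider_service import ModelProviderService

# Sketch only: validate the credential payload first, then persist it
# under a display name. All literal values are placeholder assumptions.
service = ModelProviderService()
credentials = {"api_key": "sk-placeholder"}

service.validate_model_credentials(
    tenant_id="tenant-1",
    provider="openai",
    model_type="llm",  # plain string per the signature above
    model="gpt-4o",
    credentials=credentials,
)
service.create_model_credential(
    tenant_id="tenant-1",
    provider="openai",
    model_type="llm",
    model="gpt-4o",
    credentials=credentials,
    credential_name="primary",
)
```
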
@@ -287,7 +291,13 @@ class ModelProviderService: ) def create_model_credential( - self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + credentials: dict[str, Any], + credential_name: str | None, ) -> None: """ create and save model credentials. @@ -314,7 +324,7 @@ class ModelProviderService: provider: str, model_type: str, model: str, - credentials: dict, + credentials: dict[str, Any], credential_id: str, credential_name: str | None, ) -> None: diff --git a/api/services/operation_service.py b/api/services/operation_service.py index c05e9d555c..903efd26ae 100644 --- a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -1,8 +1,22 @@ import os +from typing import TypedDict import httpx +class UtmInfo(TypedDict, total=False): + """Expected shape of the utm_info dict passed to record_utm. + + All fields are optional; missing keys default to an empty string. + """ + + utm_source: str + utm_medium: str + utm_campaign: str + utm_content: str + utm_term: str + + class OperationService: base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @@ -17,7 +31,7 @@ class OperationService: return response.json() @classmethod - def record_utm(cls, tenant_id: str, utm_info: dict): + def record_utm(cls, tenant_id: str, utm_info: UtmInfo): params = { "tenant_id": tenant_id, "utm_source": utm_info.get("utm_source", ""), diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 0db3d3efec..3ad42faf24 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,3 +1,5 @@ +from typing import Any + from sqlalchemy import select from core.ops.entities.config_entity import BaseTracingConfig @@ -135,7 +137,7 @@ class OpsService: return trace_config_data.to_dict() @classmethod - def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict): + def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict[str, Any]): """ Create tracing app config :param app_id: app id @@ -210,7 +212,7 @@ class OpsService: return {"result": "success"} @classmethod - def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict): + def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict[str, Any]): """ Update tracing app config :param app_id: app id diff --git a/api/services/plugin/endpoint_service.py b/api/services/plugin/endpoint_service.py index 11b8e0a3d9..1727cd7abd 100644 --- a/api/services/plugin/endpoint_service.py +++ b/api/services/plugin/endpoint_service.py @@ -1,9 +1,13 @@ +from typing import Any + from core.plugin.impl.endpoint import PluginEndpointClient class EndpointService: @classmethod - def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict): + def create_endpoint( + cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict[str, Any] + ): return PluginEndpointClient().create_endpoint( tenant_id=tenant_id, user_id=user_id, @@ -32,7 +36,7 @@ class EndpointService: ) @classmethod - def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict): + def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]): return 
PluginEndpointClient().update_endpoint( tenant_id=tenant_id, user_id=user_id, diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 88dec062a0..789b5fa5b7 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,5 +1,6 @@ import json import uuid +from typing import Any from core.plugin.impl.base import BasePluginClient from extensions.ext_redis import redis_client @@ -16,7 +17,7 @@ class OAuthProxyService(BasePluginClient): tenant_id: str, plugin_id: str, provider: str, - extra_data: dict = {}, + extra_data: dict[str, Any] = {}, credential_id: str | None = None, ): """ diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index 9bb0ab6ae2..b96b8140ac 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -23,7 +23,7 @@ class PluginAutoUpgradeService: exclude_plugins: list[str], include_plugins: list[str], ) -> bool: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -50,7 +50,7 @@ class PluginAutoUpgradeService: @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index d6f6ee8086..43a726b100 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -13,6 +13,7 @@ import sqlalchemy as sa import tqdm from flask import Flask, current_app from pydantic import TypeAdapter +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity @@ -66,7 +67,7 @@ class PluginMigration: current_time = started_at with Session(db.engine) as session: - total_tenant_count = session.query(Tenant.id).count() + total_tenant_count = session.scalar(select(func.count(Tenant.id))) or 0 click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white")) @@ -123,9 +124,12 @@ class PluginMigration: tenant_count = 0 for test_interval in test_intervals: tenant_count = ( - session.query(Tenant.id) - .where(Tenant.created_at.between(current_time, current_time + test_interval)) - .count() + session.scalar( + select(func.count(Tenant.id)).where( + Tenant.created_at.between(current_time, current_time + test_interval) + ) + ) + or 0 ) if tenant_count <= 100: interval = test_interval @@ -147,8 +151,8 @@ class PluginMigration: batch_end = min(current_time + interval, ended_at) - rs = ( - session.query(Tenant.id) + rs = session.execute( + select(Tenant.id) .where(Tenant.created_at.between(current_time, batch_end)) .order_by(Tenant.created_at) ) @@ -235,7 +239,7 @@ class PluginMigration: Extract tool tables. 
""" with Session(db.engine) as session: - rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all() + rs = session.scalars(select(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id)).all() result = [] for row in rs: result.append(ToolProviderID(row.provider).plugin_id) @@ -249,7 +253,7 @@ class PluginMigration: """ with Session(db.engine) as session: - rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all() + rs = session.scalars(select(Workflow).where(Workflow.tenant_id == tenant_id)).all() result = [] for row in rs: graph = row.graph_dict @@ -272,7 +276,7 @@ class PluginMigration: Extract app tables. """ with Session(db.engine) as session: - apps = session.query(App).where(App.tenant_id == tenant_id).all() + apps = session.scalars(select(App).where(App.tenant_id == tenant_id)).all() if not apps: return [] @@ -280,7 +284,7 @@ class PluginMigration: app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT ] - rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all() + rs = session.scalars(select(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids))).all() result = [] for row in rs: agent_config = row.agent_mode_dict diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py index 0d2a70acbd..3cca4268d0 100644 --- a/api/services/plugin/plugin_permission_service.py +++ b/api/services/plugin/plugin_permission_service.py @@ -1,14 +1,13 @@ from sqlalchemy import select -from sqlalchemy.orm import sessionmaker -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginPermission class PluginPermissionService: @staticmethod def get_permission(tenant_id: str) -> TenantPluginPermission | None: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: return session.scalar( select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) @@ -19,7 +18,7 @@ class PluginPermissionService: install_permission: TenantPluginPermission.InstallPermission, debug_permission: TenantPluginPermission.DebugPermission, ): - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session, session.begin(): permission = session.scalar( select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).limit(1) ) diff --git a/api/services/quota_service.py b/api/services/quota_service.py new file mode 100644 index 0000000000..4c784315c7 --- /dev/null +++ b/api/services/quota_service.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from enums.quota_type import QuotaType + +logger = logging.getLogger(__name__) + + +@dataclass +class QuotaCharge: + """ + Result of a quota reservation (Reserve phase). + + Lifecycle: + charge = QuotaService.consume(QuotaType.TRIGGER, tenant_id) + try: + do_work() + charge.commit() # Confirm consumption + except: + charge.refund() # Release frozen quota + + If neither commit() nor refund() is called, the billing system's + cleanup CronJob will auto-release the reservation within ~75 seconds. 
+ """ + + success: bool + charge_id: str | None # reservation_id + _quota_type: QuotaType + _tenant_id: str | None = None + _feature_key: str | None = None + _amount: int = 0 + _committed: bool = field(default=False, repr=False) + + def commit(self, actual_amount: int | None = None) -> None: + """ + Confirm the consumption with actual amount. + + Args: + actual_amount: Actual amount consumed. Defaults to the reserved amount. + If less than reserved, the difference is refunded automatically. + """ + if self._committed or not self.charge_id or not self._tenant_id or not self._feature_key: + return + + try: + from services.billing_service import BillingService + + amount = actual_amount if actual_amount is not None else self._amount + BillingService.quota_commit( + tenant_id=self._tenant_id, + feature_key=self._feature_key, + reservation_id=self.charge_id, + actual_amount=amount, + ) + self._committed = True + logger.debug( + "Committed %s quota for tenant %s, reservation_id: %s, amount: %d", + self._quota_type, + self._tenant_id, + self.charge_id, + amount, + ) + except Exception: + logger.exception("Failed to commit quota, reservation_id: %s", self.charge_id) + + def refund(self) -> None: + """ + Release the reserved quota (cancel the charge). + + Safe to call even if: + - charge failed or was disabled (charge_id is None) + - already committed (Release after Commit is a no-op) + - already refunded (idempotent) + + This method guarantees no exceptions will be raised. + """ + if not self.charge_id or not self._tenant_id or not self._feature_key: + return + + QuotaService.release(self._quota_type, self.charge_id, self._tenant_id, self._feature_key) + + +def unlimited() -> QuotaCharge: + from enums.quota_type import QuotaType + + return QuotaCharge(success=True, charge_id=None, _quota_type=QuotaType.UNLIMITED) + + +class QuotaService: + """Orchestrates quota reserve / commit / release lifecycle via BillingService.""" + + @staticmethod + def consume(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve + immediate Commit (one-shot mode). + + The returned QuotaCharge supports .refund() which calls Release. + For two-phase usage (e.g. streaming), use reserve() directly. + """ + charge = QuotaService.reserve(quota_type, tenant_id, amount) + if charge.success and charge.charge_id: + charge.commit() + return charge + + @staticmethod + def reserve(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> QuotaCharge: + """ + Reserve quota before task execution (Reserve phase only). + + The caller MUST call charge.commit() after the task succeeds, + or charge.refund() if the task fails. 
+ + Raises: + QuotaExceededError: When quota is insufficient + """ + from services.billing_service import BillingService + from services.errors.app import QuotaExceededError + + if not dify_config.BILLING_ENABLED: + logger.debug("Billing disabled, allowing request for %s", tenant_id) + return QuotaCharge(success=True, charge_id=None, _quota_type=quota_type) + + logger.info("Reserving %d %s quota for tenant %s", amount, quota_type.value, tenant_id) + + if amount <= 0: + raise ValueError("Amount to reserve must be greater than 0") + + request_id = str(uuid.uuid4()) + feature_key = quota_type.billing_key + + try: + reserve_resp = BillingService.quota_reserve( + tenant_id=tenant_id, + feature_key=feature_key, + request_id=request_id, + amount=amount, + ) + + reservation_id = reserve_resp.get("reservation_id") + if not reservation_id: + logger.warning( + "Reserve returned no reservation_id for %s, feature %s, response: %s", + tenant_id, + quota_type.value, + reserve_resp, + ) + raise QuotaExceededError(feature=quota_type.value, tenant_id=tenant_id, required=amount) + + logger.debug( + "Reserved %d %s quota for tenant %s, reservation_id: %s", + amount, + quota_type.value, + tenant_id, + reservation_id, + ) + return QuotaCharge( + success=True, + charge_id=reservation_id, + _quota_type=quota_type, + _tenant_id=tenant_id, + _feature_key=feature_key, + _amount=amount, + ) + + except QuotaExceededError: + raise + except ValueError: + raise + except Exception: + logger.exception("Failed to reserve quota for %s, feature %s", tenant_id, quota_type.value) + return unlimited() + + @staticmethod + def check(quota_type: QuotaType, tenant_id: str, amount: int = 1) -> bool: + if not dify_config.BILLING_ENABLED: + return True + + if amount <= 0: + raise ValueError("Amount to check must be greater than 0") + + try: + remaining = QuotaService.get_remaining(quota_type, tenant_id) + return remaining >= amount if remaining != -1 else True + except Exception: + logger.exception("Failed to check quota for %s, feature %s", tenant_id, quota_type.value) + return True + + @staticmethod + def release(quota_type: QuotaType, reservation_id: str, tenant_id: str, feature_key: str) -> None: + """Release a reservation. 
Guarantees no exceptions.""" + try: + from services.billing_service import BillingService + + if not dify_config.BILLING_ENABLED: + return + + if not reservation_id: + return + + logger.info("Releasing %s quota, reservation_id: %s", quota_type.value, reservation_id) + BillingService.quota_release( + tenant_id=tenant_id, + feature_key=feature_key, + reservation_id=reservation_id, + ) + except Exception: + logger.exception("Failed to release quota, reservation_id: %s", reservation_id) + + @staticmethod + def get_remaining(quota_type: QuotaType, tenant_id: str) -> int: + from services.billing_service import BillingService + + try: + usage_info = BillingService.get_quota_info(tenant_id) + if isinstance(usage_info, dict): + feature_info = usage_info.get(quota_type.billing_key, {}) + if isinstance(feature_info, dict): + limit = feature_info.get("limit", 0) + usage = feature_info.get("usage", 0) + if limit == -1: + return -1 + return max(0, limit - usage) + return 0 + except Exception: + logger.exception("Failed to get remaining quota for %s, feature %s", tenant_id, quota_type.value) + return -1 diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 10e89b1dba..56bc785958 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Union +from typing import Any from configs import dify_config from core.app.apps.pipeline.pipeline_generator import PipelineGenerator @@ -17,7 +17,7 @@ class PipelineGenerateService: def generate( cls, pipeline: Pipeline, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, streaming: bool = True, diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index 24baeb73b5..8c9a81af87 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -1,6 +1,7 @@ import json from os import path from pathlib import Path +from typing import Any from flask import current_app @@ -13,21 +14,21 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json """ - builtin_data: dict | None = None + builtin_data: dict[str, Any] | None = None def get_type(self) -> str: return PipelineTemplateType.BUILTIN - def get_pipeline_templates(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict[str, Any]: result = self.fetch_pipeline_templates_from_builtin(language) return result - def get_pipeline_template_detail(self, template_id: str): + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: result = self.fetch_pipeline_template_detail_from_builtin(template_id) return result @classmethod - def _get_builtin_data(cls) -> dict: + def _get_builtin_data(cls) -> dict[str, Any]: """ Get builtin data. :return: @@ -43,21 +44,21 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return cls.builtin_data or {} @classmethod - def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict: + def fetch_pipeline_templates_from_builtin(cls, language: str) -> dict[str, Any]: """ Fetch pipeline templates from builtin. 
:param language: language :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod - def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict[str, Any] | None: """ Fetch pipeline template detail from builtin. :param template_id: Template ID :return: """ - builtin_data: dict[str, dict[str, dict]] = cls._get_builtin_data() + builtin_data: dict[str, dict[str, dict[str, Any]]] = cls._get_builtin_data() return builtin_data.get("pipeline_templates", {}).get(template_id) diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 2ee871a266..9d446f6d4b 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,3 +1,5 @@ +from typing import Any, TypedDict + import yaml from sqlalchemy import select @@ -8,25 +10,47 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class CustomizedTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + position: int + chunk_structure: str + + +class CustomizedTemplatesResultDict(TypedDict): + pipeline_templates: list[CustomizedTemplateItemDict] + + +class CustomizedTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + created_by: str + + class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval recommended app from database """ - def get_pipeline_templates(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict[str, Any]: _, current_tenant_id = current_account_with_tenant() - result = self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) - return result + return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) - def get_pipeline_template_detail(self, template_id: str): - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.CUSTOMIZED @classmethod - def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict: + def fetch_pipeline_templates_from_customized(cls, tenant_id: str, language: str) -> dict[str, Any]: """ Fetch pipeline templates from db. 
:param tenant_id: tenant id @@ -38,9 +62,9 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) ).all() - recommended_pipelines_results = [] + recommended_pipelines_results: list[CustomizedTemplateItemDict] = [] for pipeline_customized_template in pipeline_customized_templates: - recommended_pipeline_result = { + recommended_pipeline_result: CustomizedTemplateItemDict = { "id": pipeline_customized_template.id, "name": pipeline_customized_template.name, "description": pipeline_customized_template.description, @@ -53,7 +77,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} @classmethod - def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict[str, Any] | None: """ Fetch pipeline template detail from db. :param template_id: Template ID diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 43b21a7b32..2964537c35 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,3 +1,5 @@ +from typing import Any, TypedDict + import yaml from sqlalchemy import select @@ -7,24 +9,47 @@ from services.rag_pipeline.pipeline_template.pipeline_template_base import Pipel from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType +class PipelineTemplateItemDict(TypedDict): + id: str + name: str + description: str + icon: dict[str, Any] + copyright: str + privacy_policy: str + position: int + chunk_structure: str + + +class PipelineTemplatesResultDict(TypedDict): + pipeline_templates: list[PipelineTemplateItemDict] + + +class PipelineTemplateDetailDict(TypedDict): + id: str + name: str + icon_info: dict[str, Any] + description: str + chunk_structure: str + export_data: str + graph: dict[str, Any] + + class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ Retrieval pipeline template from database """ - def get_pipeline_templates(self, language: str) -> dict: - result = self.fetch_pipeline_templates_from_db(language) - return result + def get_pipeline_templates(self, language: str) -> dict[str, Any]: + return self.fetch_pipeline_templates_from_db(language) - def get_pipeline_template_detail(self, template_id: str): - result = self.fetch_pipeline_template_detail_from_db(template_id) - return result + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: + return self.fetch_pipeline_template_detail_from_db(template_id) def get_type(self) -> str: return PipelineTemplateType.DATABASE @classmethod - def fetch_pipeline_templates_from_db(cls, language: str) -> dict: + def fetch_pipeline_templates_from_db(cls, language: str) -> dict[str, Any]: """ Fetch pipeline templates from db. 
:param language: language @@ -37,9 +62,9 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): ).all() ) - recommended_pipelines_results = [] + recommended_pipelines_results: list[PipelineTemplateItemDict] = [] for pipeline_built_in_template in pipeline_built_in_templates: - recommended_pipeline_result = { + recommended_pipeline_result: PipelineTemplateItemDict = { "id": pipeline_built_in_template.id, "name": pipeline_built_in_template.name, "description": pipeline_built_in_template.description, @@ -54,7 +79,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} @classmethod - def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict[str, Any] | None: """ Fetch pipeline template detail from db. :param pipeline_id: Pipeline ID diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py index 21c30a4986..0ed2a4b8f2 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -1,15 +1,16 @@ from abc import ABC, abstractmethod +from typing import Any class PipelineTemplateRetrievalBase(ABC): """Interface for pipeline template retrieval.""" @abstractmethod - def get_pipeline_templates(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict[str, Any]: raise NotImplementedError @abstractmethod - def get_pipeline_template_detail(self, template_id: str) -> dict | None: + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: raise NotImplementedError @abstractmethod diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index f996db11dc..9565ac46cc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -1,4 +1,5 @@ import logging +from typing import Any import httpx @@ -15,28 +16,25 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval recommended app from dify official """ - def get_pipeline_template_detail(self, template_id: str) -> dict | None: - result: dict | None + def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: try: - result = self.fetch_pipeline_template_detail_from_dify_official(template_id) + return self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) - return result + return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) - def get_pipeline_templates(self, language: str) -> dict: + def get_pipeline_templates(self, language: str) -> dict[str, Any]: try: - result = self.fetch_pipeline_templates_from_dify_official(language) + return self.fetch_pipeline_templates_from_dify_official(language) except Exception as e: logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e) - result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language) - 
diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
index f996db11dc..9565ac46cc 100644
--- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Any
 
 import httpx
 
@@ -15,28 +16,25 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
     Retrieval recommended app from dify official
     """
 
-    def get_pipeline_template_detail(self, template_id: str) -> dict | None:
-        result: dict | None
+    def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None:
         try:
-            result = self.fetch_pipeline_template_detail_from_dify_official(template_id)
+            return self.fetch_pipeline_template_detail_from_dify_official(template_id)
         except Exception as e:
             logger.warning("fetch recommended app detail from dify official failed: %r, switch to database.", e)
-            result = DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
-        return result
+            return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id)
 
-    def get_pipeline_templates(self, language: str) -> dict:
+    def get_pipeline_templates(self, language: str) -> dict[str, Any]:
         try:
-            result = self.fetch_pipeline_templates_from_dify_official(language)
+            return self.fetch_pipeline_templates_from_dify_official(language)
         except Exception as e:
             logger.warning("fetch pipeline templates from dify official failed: %r, switch to database.", e)
-            result = DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
-
-        return result
+            return DatabasePipelineTemplateRetrieval.fetch_pipeline_templates_from_db(language)
 
     def get_type(self) -> str:
         return PipelineTemplateType.REMOTE
 
     @classmethod
-    def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict:
+    def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict[str, Any]:
         """
         Fetch pipeline template detail from dify official.
@@ -53,11 +51,11 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
                 + f" status_code: {response.status_code},"
                 + f" response: {response.text[:1000]}"
             )
-        data: dict = response.json()
+        data: dict[str, Any] = response.json()
         return data
 
     @classmethod
-    def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict:
+    def fetch_pipeline_templates_from_dify_official(cls, language: str) -> dict[str, Any]:
         """
         Fetch pipeline templates from dify official.
         :param language: language
@@ -69,6 +67,6 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         if response.status_code != 200:
             raise ValueError(f"fetch pipeline templates failed, status code: {response.status_code}")
 
-        result: dict = response.json()
+        result: dict[str, Any] = response.json()
         return result
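The remote retrieval above follows a fetch-with-fallback shape: try the hosted endpoint, log the failure, and fall back to the database copy, returning directly from each branch rather than threading a shared `result` variable. A minimal sketch of that shape, with a hypothetical URL and fallback function standing in for the real classmethods:

```python
import logging
from typing import Any

import httpx

logger = logging.getLogger(__name__)


def fetch_templates(language: str) -> dict[str, Any]:
    """Try the remote source first; fall back to a local copy on any failure."""
    try:
        response = httpx.get(f"https://example.com/templates?language={language}", timeout=10.0)
        response.raise_for_status()
        result: dict[str, Any] = response.json()
        return result
    except Exception as e:
        # Returning from both branches (instead of assigning to a shared
        # variable) keeps the narrowed return type visible to the checker.
        logger.warning("remote fetch failed: %r, switching to local fallback.", e)
        return fetch_templates_from_local(language)  # hypothetical fallback


def fetch_templates_from_local(language: str) -> dict[str, Any]:
    return {"pipeline_templates": []}
```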
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index f6d80f9a6e..9db6682e10 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -5,19 +5,10 @@ import threading
 import time
 from collections.abc import Callable, Generator, Mapping, Sequence
 from datetime import UTC, datetime
-from typing import Any, Union, cast
+from typing import Any, cast
 from uuid import uuid4
 
 from flask_login import current_user
-from graphon.entities import WorkflowNodeExecution
-from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
-from graphon.errors import WorkflowNodeRunFailedError
-from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
-from graphon.node_events import NodeRunResult
-from graphon.nodes.base.node import Node
-from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config
-from graphon.runtime import VariablePool
-from graphon.variables.variables import Variable, VariableBase
 from sqlalchemy import func, select
 from sqlalchemy.orm import Session, sessionmaker
 
@@ -53,6 +44,15 @@ from core.workflow.variable_pool_initializer import add_variables_to_pool
 from core.workflow.workflow_entry import WorkflowEntry
 from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
 from extensions.ext_database import db
+from graphon.entities import WorkflowNodeExecution
+from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
+from graphon.errors import WorkflowNodeRunFailedError
+from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent
+from graphon.node_events import NodeRunResult
+from graphon.nodes.base.node import Node
+from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config
+from graphon.runtime import VariablePool
+from graphon.variables.variables import Variable, VariableBase
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from models import Account
 from models.dataset import (  # type: ignore
@@ -104,7 +104,7 @@ class RagPipelineService:
         self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
 
     @classmethod
-    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict:
+    def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]:
         if type == "built-in":
             mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
             retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
@@ -120,7 +120,7 @@ class RagPipelineService:
         return result
 
     @classmethod
-    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
+    def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict[str, Any] | None:
         """
         Get pipeline template detail.
@@ -131,7 +131,7 @@ class RagPipelineService:
         if type == "built-in":
             mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
             retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
-            built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
+            built_in_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id)
             if built_in_result is None:
                 logger.warning(
                     "pipeline template retrieval returned empty result, template_id: %s, mode: %s",
@@ -142,7 +142,7 @@ class RagPipelineService:
         else:
             mode = "customized"
             retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
-            customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
+            customized_result: dict[str, Any] | None = retrieval_instance.get_pipeline_template_detail(template_id)
             return customized_result
 
     @classmethod
@@ -297,7 +297,7 @@ class RagPipelineService:
         self,
         *,
         pipeline: Pipeline,
-        graph: dict,
+        graph: dict[str, Any],
         unique_hash: str | None,
         account: Account,
         environment_variables: Sequence[VariableBase],
@@ -467,14 +467,16 @@ class RagPipelineService:
 
         return default_block_configs
 
-    def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
+    def get_default_block_config(
+        self, node_type: str, filters: dict[str, Any] | None = None
+    ) -> Mapping[str, object] | None:
         """
         Get default config of node.
         :param node_type: node type
         :param filters: filter by node config parameters.
        :return:
         """
-        node_type_enum = NodeType(node_type)
+        node_type_enum: NodeType = NodeType(node_type)
         node_mapping = get_node_type_classes_mapping()
 
         # return default block config
@@ -500,7 +502,7 @@ class RagPipelineService:
         return default_config
 
     def run_draft_workflow_node(
-        self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account
+        self, pipeline: Pipeline, node_id: str, user_inputs: dict[str, Any], account: Account
     ) -> WorkflowNodeExecutionModel | None:
         """
         Run draft workflow node
@@ -582,7 +584,7 @@ class RagPipelineService:
         self,
         pipeline: Pipeline,
         node_id: str,
-        user_inputs: dict,
+        user_inputs: dict[str, Any],
         account: Account,
         datasource_type: str,
         is_published: bool,
@@ -749,7 +751,7 @@ class RagPipelineService:
         self,
         pipeline: Pipeline,
         node_id: str,
-        user_inputs: dict,
+        user_inputs: dict[str, Any],
         account: Account,
         datasource_type: str,
         is_published: bool,
@@ -979,7 +981,7 @@ class RagPipelineService:
         return workflow_node_execution
 
     def update_workflow(
-        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
+        self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any]
     ) -> Workflow | None:
         """
         Update workflow attributes
@@ -1099,7 +1101,9 @@ class RagPipelineService:
         ]
         return datasource_provider_variables
 
-    def get_rag_pipeline_paginate_workflow_runs(self, pipeline: Pipeline, args: dict) -> InfiniteScrollPagination:
+    def get_rag_pipeline_paginate_workflow_runs(
+        self, pipeline: Pipeline, args: dict[str, Any]
+    ) -> InfiniteScrollPagination:
         """
         Get debug workflow run list
         Only return triggered_from == debugging
@@ -1169,7 +1173,7 @@ class RagPipelineService:
         return list(node_executions)
 
     @classmethod
-    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict):
+    def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]):
         """
         Publish customized pipeline template
         """
@@ -1259,7 +1263,7 @@ class RagPipelineService:
         )
         return node_exec
 
-    def set_datasource_variables(self, pipeline: Pipeline, args: dict, current_user: Account):
+    def set_datasource_variables(self, pipeline: Pipeline, args: dict[str, Any], current_user: Account):
         """
         Set datasource variables
         """
@@ -1346,7 +1350,7 @@ class RagPipelineService:
         )
         return workflow_node_execution_db_model
 
-    def get_recommended_plugins(self, type: str) -> dict:
+    def get_recommended_plugins(self, type: str) -> dict[str, Any]:
         # Query active recommended plugins
         stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True)
         if type and type != "all":
@@ -1387,7 +1391,7 @@ class RagPipelineService:
             "uninstalled_recommended_plugins": uninstalled_plugin_list,
         }
 
-    def retry_error_document(self, dataset: Dataset, document: Document, user: Union[Account, EndUser]):
+    def retry_error_document(self, dataset: Dataset, document: Document, user: Account | EndUser):
         """
         Retry error document
        """
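The `get_default_block_config` hunk above annotates `node_type_enum`; keeping the explicit `NodeType(...)` call matters because the parameter arrives as a plain `str`, and a bare annotated assignment would silently store an unvalidated string. A small sketch of the conversion pattern, using a stand-in enum (Python 3.11's `StrEnum`) rather than graphon's real `NodeType`:

```python
from enum import StrEnum


class NodeKind(StrEnum):
    """Stand-in for an enum like graphon's NodeType."""

    LLM = "llm"
    HTTP_REQUEST = "http-request"


def resolve_kind(raw: str) -> NodeKind:
    # NodeKind(raw) both converts and validates: an unknown string raises
    # ValueError here instead of surfacing later as a bad dictionary key.
    return NodeKind(raw)


print(resolve_kind("llm"))  # NodeKind.LLM
```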
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index c24bf3d649..f315d053cb 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -5,7 +5,7 @@ import logging
 import uuid
 from collections.abc import Mapping
 from datetime import UTC, datetime
-from typing import cast
+from typing import Any, cast
 from urllib.parse import urlparse
 from uuid import uuid4
 
@@ -13,12 +13,6 @@ import yaml  # type: ignore
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
 from flask_login import current_user
-from graphon.enums import BuiltinNodeTypes
-from graphon.model_runtime.utils.encoders import jsonable_encoder
-from graphon.nodes.llm.entities import LLMNodeData
-from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
-from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
-from graphon.nodes.tool.entities import ToolNodeData
 from packaging import version
 from pydantic import BaseModel
 from sqlalchemy import select
@@ -33,6 +27,12 @@ from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from extensions.ext_redis import redis_client
 from factories import variable_factory
+from graphon.enums import BuiltinNodeTypes
+from graphon.model_runtime.utils.encoders import jsonable_encoder
+from graphon.nodes.llm.entities import LLMNodeData
+from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData
+from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData
+from graphon.nodes.tool.entities import ToolNodeData
 from models import Account
 from models.dataset import Dataset, DatasetCollectionBinding, Pipeline
 from models.enums import CollectionBindingType, DatasetRuntimeMode
@@ -283,7 +283,9 @@ class RagPipelineDslService:
         ):
             raise ValueError("Chunk structure is not compatible with the published pipeline")
         if not dataset:
-            datasets = self._session.query(Dataset).filter_by(tenant_id=account.current_tenant_id).all()
+            datasets = self._session.scalars(
+                select(Dataset).where(Dataset.tenant_id == account.current_tenant_id)
+            ).all()
             names = [dataset.name for dataset in datasets]
             generate_name = generate_incremental_name(names, name)
             dataset = Dataset(
@@ -303,8 +305,8 @@ class RagPipelineDslService:
                 chunk_structure=knowledge_configuration.chunk_structure,
             )
             if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
-                dataset_collection_binding = (
-                    self._session.query(DatasetCollectionBinding)
+                dataset_collection_binding = self._session.scalar(
+                    select(DatasetCollectionBinding)
                     .where(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
@@ -312,7 +314,7 @@ class RagPipelineDslService:
                         DatasetCollectionBinding.type == CollectionBindingType.DATASET,
                     )
                     .order_by(DatasetCollectionBinding.created_at)
-                    .first()
+                    .limit(1)
                 )
 
                 if not dataset_collection_binding:
@@ -440,8 +442,8 @@ class RagPipelineDslService:
             dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE
             dataset.chunk_structure = knowledge_configuration.chunk_structure
         if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
-            dataset_collection_binding = (
-                self._session.query(DatasetCollectionBinding)
+            dataset_collection_binding = self._session.scalar(
+                select(DatasetCollectionBinding)
                 .where(
                     DatasetCollectionBinding.provider_name
                     == knowledge_configuration.embedding_model_provider,
@@ -449,7 +451,7 @@ class RagPipelineDslService:
                     DatasetCollectionBinding.type == CollectionBindingType.DATASET,
                 )
                 .order_by(DatasetCollectionBinding.created_at)
-                .first()
+                .limit(1)
             )
 
             if not dataset_collection_binding:
@@ -524,7 +526,7 @@ class RagPipelineDslService:
         self,
         *,
         pipeline: Pipeline | None,
-        data: dict,
+        data: dict[str, Any],
         account: Account,
         dependencies: list[PluginDependency] | None = None,
     ) -> Pipeline:
@@ -591,14 +593,14 @@ class RagPipelineDslService:
                 IMPORT_INFO_REDIS_EXPIRY,
                CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
             )
-        workflow = (
-            self._session.query(Workflow)
+        workflow = self._session.scalar(
+            select(Workflow)
             .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
             )
-            .first()
+            .limit(1)
         )
 
         # create draft workflow if not found
@@ -658,21 +660,21 @@ class RagPipelineDslService:
 
         return yaml.dump(export_data, allow_unicode=True)  # type: ignore
 
-    def _append_workflow_export_data(self, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
+    def _append_workflow_export_data(
+        self, *, export_data: dict[str, Any], pipeline: Pipeline, include_secret: bool
+    ) -> None:
         """
         Append workflow export data
         :param export_data: export data
         :param pipeline: Pipeline instance
         """
-        workflow = (
-            self._session.query(Workflow)
-            .where(
+        workflow = self._session.scalar(
+            select(Workflow).where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
             )
-            .first()
         )
         if not workflow:
             raise ValueError("Missing draft workflow configuration, please check.")
@@ -904,15 +906,16 @@ class RagPipelineDslService:
     ):
         if rag_pipeline_dataset_create_entity.name:
             # check if dataset name already exists
-            if (
-                self._session.query(Dataset)
-                .filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
-                .first()
+            if self._session.scalar(
+                select(Dataset).where(
+                    Dataset.name == rag_pipeline_dataset_create_entity.name,
+                    Dataset.tenant_id == tenant_id,
+                )
             ):
                 raise ValueError(f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists.")
         else:
             # generate a random name as Untitled 1 2 3 ...
-            datasets = self._session.query(Dataset).filter_by(tenant_id=tenant_id).all()
+            datasets = self._session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all()
             names = [dataset.name for dataset in datasets]
             rag_pipeline_dataset_create_entity.name = generate_incremental_name(
                 names,
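The hunks in this file repeatedly swap legacy `Query` chains ending in `.first()` for 2.0-style `session.scalar(select(...).limit(1))`. A minimal sketch of the equivalence against a hypothetical `Dataset` model (not the repository's real one), runnable with in-memory SQLite:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    """Hypothetical stand-in for the real models.dataset.Dataset."""

    __tablename__ = "datasets"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str]
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Dataset(tenant_id="t1", name="Untitled 1"))
    session.commit()

    # Legacy 1.x style: session.query(Dataset).filter_by(tenant_id="t1").first()
    # 2.0 style: scalar() returns the first column of the first row, or None;
    # limit(1) makes the "at most one row" intent explicit in the SQL.
    dataset = session.scalar(select(Dataset).where(Dataset.tenant_id == "t1").limit(1))
    print(dataset is not None)  # True
```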
diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py
index c3b00fe109..f08ec7474b 100644
--- a/api/services/rag_pipeline/rag_pipeline_transform_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py
@@ -2,6 +2,7 @@ import json
 import logging
 from datetime import UTC, datetime
 from pathlib import Path
+from typing import Any
 from uuid import uuid4
 
 import yaml
@@ -154,7 +155,7 @@ class RagPipelineTransformService:
                 raise ValueError("Unsupported doc form")
         return pipeline_yaml
 
-    def _deal_file_extensions(self, node: dict):
+    def _deal_file_extensions(self, node: dict[str, Any]):
         file_extensions = node.get("data", {}).get("fileExtensions", [])
         if not file_extensions:
             return node
@@ -167,7 +168,7 @@ class RagPipelineTransformService:
         dataset: Dataset,
         indexing_technique: str | None,
         retrieval_model: RetrievalSetting | None,
-        node: dict,
+        node: dict[str, Any],
     ):
         knowledge_configuration_dict = node.get("data", {})
 
@@ -191,7 +192,7 @@ class RagPipelineTransformService:
 
     def _create_pipeline(
         self,
-        data: dict,
+        data: dict[str, Any],
     ) -> Pipeline:
         """Create a new app or update an existing one."""
         pipeline_data = data.get("rag_pipeline", {})
@@ -258,7 +259,7 @@ class RagPipelineTransformService:
         db.session.add(pipeline)
         return pipeline
 
-    def _deal_dependencies(self, pipeline_yaml: dict, tenant_id: str):
+    def _deal_dependencies(self, pipeline_yaml: dict[str, Any], tenant_id: str):
         installer_manager = PluginInstaller()
         installed_plugins = installer_manager.list_plugins(tenant_id)
diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py
index 64751d186c..16dc66cd76 100644
--- a/api/services/recommend_app/buildin/buildin_retrieval.py
+++ b/api/services/recommend_app/buildin/buildin_retrieval.py
@@ -1,6 +1,7 @@
 import json
 from os import path
 from pathlib import Path
+from typing import Any
 
 from flask import current_app
 
@@ -13,7 +14,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
     Retrieval recommended app from buildin, the location is constants/recommended_apps.json
     """
 
-    builtin_data: dict | None = None
+    builtin_data: dict[str, Any] | None = None
 
     def get_type(self) -> str:
         return RecommendAppType.BUILDIN
@@ -53,7 +54,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
         return builtin_data.get("recommended_apps", {}).get(language, {})
 
     @classmethod
-    def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
+    def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict[str, Any] | None:
         """
         Fetch recommended app detail from builtin.
         :param app_id: App ID
diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py
index b217c9026a..5818be0480 100644
--- a/api/services/recommend_app/remote/remote_retrieval.py
+++ b/api/services/recommend_app/remote/remote_retrieval.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Any
 
 import httpx
 
@@ -35,7 +36,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         return RecommendAppType.REMOTE
 
     @classmethod
-    def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
+    def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict[str, Any] | None:
         """
         Fetch recommended app detail from dify official.
         :param app_id: App ID
@@ -46,7 +47,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
         if response.status_code != 200:
             return None
-        data: dict = response.json()
+        data: dict[str, Any] = response.json()
         return data
 
     @classmethod
@@ -62,7 +63,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
         if response.status_code != 200:
             raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
 
-        result: dict = response.json()
+        result: dict[str, Any] = response.json()
 
         if "categories" in result:
             result["categories"] = sorted(result["categories"])
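Several hunks above annotate `response.json()` as `dict[str, Any]` rather than bare `dict`. A short sketch of why, using httpx (the URL is illustrative):

```python
from typing import Any

import httpx


def fetch_detail(url: str) -> dict[str, Any] | None:
    response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
    if response.status_code != 200:
        return None
    # response.json() is typed as Any, so the annotation pins down the shape
    # callers may rely on; bare `dict` means dict[Any, Any] implicitly and is
    # rejected by stricter type-checker configurations.
    data: dict[str, Any] = response.json()
    return data
```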
diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py
index 9819822103..134dd37a3e 100644
--- a/api/services/recommended_app_service.py
+++ b/api/services/recommended_app_service.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from sqlalchemy import select
 
 from configs import dify_config
@@ -37,7 +39,7 @@ class RecommendedAppService:
         return result
 
     @classmethod
-    def get_recommend_app_detail(cls, app_id: str) -> dict | None:
+    def get_recommend_app_detail(cls, app_id: str) -> dict[str, Any] | None:
         """
         Get recommend app detail.
         :param app_id: app id
@@ -45,7 +47,7 @@ class RecommendedAppService:
         """
         mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
         retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
-        result: dict = retrieval_instance.get_recommend_app_detail(app_id)
+        result: dict[str, Any] = retrieval_instance.get_recommend_app_detail(app_id)
         if FeatureService.get_system_features().enable_trial_app:
             app_id = result["id"]
             trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
index ab60986bfe..21be411bea 100644
--- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
+++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py
@@ -27,13 +27,13 @@ from dataclasses import dataclass, field
 from typing import Any, TypedDict
 
 import click
-from graphon.enums import WorkflowType
 from sqlalchemy import inspect
 from sqlalchemy.orm import Session, sessionmaker
 
 from configs import dify_config
 from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
+from graphon.enums import WorkflowType
 from libs.archive_storage import (
     ArchiveStorage,
     ArchiveStorageNotConfiguredError,
diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py
index 8760d60de0..cf39469be8 100644
--- a/api/services/summary_index_service.py
+++ b/api/services/summary_index_service.py
@@ -6,8 +6,7 @@ import uuid
 from datetime import UTC, datetime
 from typing import TypedDict, cast
 
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.model_runtime.entities.model_entities import ModelType
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.db.session_factory import session_factory
@@ -17,6 +16,8 @@ from core.rag.index_processor.constant.doc_type import DocType
 from core.rag.index_processor.constant.index_type import IndexTechniqueType
 from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
 from core.rag.models.document import Document
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.model_runtime.entities.model_entities import ModelType
 from libs import helper
 from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
 from models.dataset import Document as DatasetDocument
@@ -109,8 +110,13 @@ class SummaryIndexService:
         """
         with session_factory.create_session() as session:
             # Check if summary record already exists
-            existing_summary = (
-                session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+            existing_summary = session.scalar(
+                select(DocumentSegmentSummary)
+                .where(
+                    DocumentSegmentSummary.chunk_id == segment.id,
+                    DocumentSegmentSummary.dataset_id == dataset.id,
+                )
+                .limit(1)
             )
 
             if existing_summary:
@@ -309,8 +315,10 @@ class SummaryIndexService:
                         summary_record_id,
                         segment.id,
                     )
-                    summary_record_in_session = (
-                        session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
+                    summary_record_in_session = session.scalar(
+                        select(DocumentSegmentSummary)
+                        .where(DocumentSegmentSummary.id == summary_record_id)
+                        .limit(1)
                     )
 
                     if not summary_record_in_session:
@@ -323,10 +331,13 @@ class SummaryIndexService:
                             dataset.id,
                             segment.id,
                         )
-                        summary_record_in_session = (
-                            session.query(DocumentSegmentSummary)
-                            .filter_by(chunk_id=segment.id, dataset_id=dataset.id)
-                            .first()
+                        summary_record_in_session = session.scalar(
+                            select(DocumentSegmentSummary)
+                            .where(
+                                DocumentSegmentSummary.chunk_id == segment.id,
+                                DocumentSegmentSummary.dataset_id == dataset.id,
+                            )
+                            .limit(1)
                         )
 
                         if not summary_record_in_session:
@@ -338,7 +349,6 @@ class SummaryIndexService:
                                 summary_record_id,
                             )
                             summary_record_in_session = DocumentSegmentSummary(
-                                id=summary_record_id,  # Use the same ID if available
                                 dataset_id=dataset.id,
                                 document_id=segment.document_id,
                                 chunk_id=segment.id,
@@ -349,6 +359,9 @@ class SummaryIndexService:
                                 status=SummaryStatus.COMPLETED,
                                 enabled=True,
                             )
+                            if summary_record_in_session is None:
+                                raise RuntimeError("summary_record_in_session should not be None at this point")
+                            summary_record_in_session.id = summary_record_id
                             session.add(summary_record_in_session)
                             logger.info(
                                 "Created new summary record (id=%s) for segment %s after vectorization",
@@ -487,8 +500,10 @@ class SummaryIndexService:
             with session_factory.create_session() as error_session:
                 # Try to find the record by id first
                 # Note: Using assignment only (no type annotation) to avoid redeclaration error
-                summary_record_in_session = (
-                    error_session.query(DocumentSegmentSummary).filter_by(id=summary_record_id).first()
+                summary_record_in_session = error_session.scalar(
+                    select(DocumentSegmentSummary)
+                    .where(DocumentSegmentSummary.id == summary_record_id)
+                    .limit(1)
                 )
                 if not summary_record_in_session:
                     # Try to find by chunk_id and dataset_id
@@ -500,10 +515,13 @@ class SummaryIndexService:
                         dataset.id,
                         segment.id,
                     )
-                    summary_record_in_session = (
-                        error_session.query(DocumentSegmentSummary)
-                        .filter_by(chunk_id=segment.id, dataset_id=dataset.id)
-                        .first()
+                    summary_record_in_session = error_session.scalar(
+                        select(DocumentSegmentSummary)
+                        .where(
+                            DocumentSegmentSummary.chunk_id == segment.id,
+                            DocumentSegmentSummary.dataset_id == dataset.id,
+                        )
+                        .limit(1)
                     )
 
                 if summary_record_in_session:
@@ -551,14 +569,12 @@ class SummaryIndexService:
 
         with session_factory.create_session() as session:
             # Query existing summary records
-            existing_summaries = (
-                session.query(DocumentSegmentSummary)
-                .filter(
+            existing_summaries = session.scalars(
+                select(DocumentSegmentSummary).where(
                     DocumentSegmentSummary.chunk_id.in_(segment_ids),
                     DocumentSegmentSummary.dataset_id == dataset.id,
                 )
-                .all()
-            )
+            ).all()
             existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}
 
             # Create or update records
@@ -603,8 +619,13 @@ class SummaryIndexService:
             error: Error message
         """
         with session_factory.create_session() as session:
-            summary_record = (
-                session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+            summary_record = session.scalar(
+                select(DocumentSegmentSummary)
+                .where(
+                    DocumentSegmentSummary.chunk_id == segment.id,
+                    DocumentSegmentSummary.dataset_id == dataset.id,
+                )
+                .limit(1)
             )
 
             if summary_record:
@@ -639,8 +660,13 @@ class SummaryIndexService:
         with session_factory.create_session() as session:
             try:
                 # Get or refresh summary record in this session
-                summary_record_in_session = (
-                    session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+                summary_record_in_session = session.scalar(
+                    select(DocumentSegmentSummary)
+                    .where(
+                        DocumentSegmentSummary.chunk_id == segment.id,
+                        DocumentSegmentSummary.dataset_id == dataset.id,
+                    )
+                    .limit(1)
                 )
 
                 if not summary_record_in_session:
@@ -710,8 +736,13 @@ class SummaryIndexService:
             except Exception as e:
                 logger.exception("Failed to generate summary for segment %s", segment.id)
                 # Update summary record with error status
-                summary_record_in_session = (
-                    session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
+                summary_record_in_session = session.scalar(
+                    select(DocumentSegmentSummary)
+                    .where(
+                        DocumentSegmentSummary.chunk_id == segment.id,
+                        DocumentSegmentSummary.dataset_id == dataset.id,
+                    )
+                    .limit(1)
                )
                 if summary_record_in_session:
                     summary_record_in_session.status = SummaryStatus.ERROR
@@ -769,17 +800,17 @@ class SummaryIndexService:
 
         with session_factory.create_session() as session:
             # Query segments (only enabled segments)
-            query = session.query(DocumentSegment).filter_by(
-                dataset_id=dataset.id,
-                document_id=document.id,
-                status="completed",
-                enabled=True,  # Only generate summaries for enabled segments
+            stmt = select(DocumentSegment).where(
+                DocumentSegment.dataset_id == dataset.id,
+                DocumentSegment.document_id == document.id,
+                DocumentSegment.status == "completed",
+                DocumentSegment.enabled.is_(True),  # Only generate summaries for enabled segments
             )
 
             if segment_ids:
-                query = query.filter(DocumentSegment.id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegment.id.in_(segment_ids))
 
-            segments = query.all()
+            segments = list(session.scalars(stmt).all())
 
             if not segments:
                 logger.info("No segments found for document %s", document.id)
@@ -848,15 +879,15 @@ class SummaryIndexService:
         from libs.datetime_utils import naive_utc_now
 
         with session_factory.create_session() as session:
-            query = session.query(DocumentSegmentSummary).filter_by(
-                dataset_id=dataset.id,
-                enabled=True,  # Only disable enabled summaries
+            stmt = select(DocumentSegmentSummary).where(
+                DocumentSegmentSummary.dataset_id == dataset.id,
+                DocumentSegmentSummary.enabled.is_(True),  # Only disable enabled summaries
             )
 
             if segment_ids:
-                query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
 
-            summaries = query.all()
+            summaries = session.scalars(stmt).all()
 
             if not summaries:
                 return
@@ -911,15 +942,15 @@ class SummaryIndexService:
             return
 
         with session_factory.create_session() as session:
-            query = session.query(DocumentSegmentSummary).filter_by(
-                dataset_id=dataset.id,
-                enabled=False,  # Only enable disabled summaries
+            stmt = select(DocumentSegmentSummary).where(
+                DocumentSegmentSummary.dataset_id == dataset.id,
+                DocumentSegmentSummary.enabled.is_(False),  # Only enable disabled summaries
             )
 
             if segment_ids:
-                query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
 
-            summaries = query.all()
+            summaries = session.scalars(stmt).all()
 
             if not summaries:
                 return
@@ -935,13 +966,13 @@ class SummaryIndexService:
             enabled_count = 0
             for summary in summaries:
                 # Get the original segment
-                segment = (
-                    session.query(DocumentSegment)
-                    .filter_by(
-                        id=summary.chunk_id,
-                        dataset_id=dataset.id,
+                segment = session.scalar(
+                    select(DocumentSegment)
+                    .where(
+                        DocumentSegment.id == summary.chunk_id,
+                        DocumentSegment.dataset_id == dataset.id,
                     )
-                    .first()
+                    .limit(1)
                 )
 
                 # Summary.enabled stays in sync with chunk.enabled,
@@ -988,12 +1019,12 @@ class SummaryIndexService:
             segment_ids: List of segment IDs to delete summaries for. If None, delete all.
""" with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id) + stmt = select(DocumentSegmentSummary).where(DocumentSegmentSummary.dataset_id == dataset.id) if segment_ids: - query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids)) + stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids)) - summaries = query.all() + summaries = session.scalars(stmt).all() if not summaries: return @@ -1046,10 +1077,13 @@ class SummaryIndexService: # Check if summary_content is empty (whitespace-only strings are considered empty) if not summary_content or not summary_content.strip(): # If summary is empty, only delete existing summary vector and record - summary_record = ( - session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id) - .first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -1077,8 +1111,13 @@ class SummaryIndexService: return None # Find existing summary record - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: @@ -1162,8 +1201,13 @@ class SummaryIndexService: except Exception as e: logger.exception("Failed to update summary for segment %s", segment.id) # Update summary record with error status if it exists - summary_record = ( - session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset.id, + ) + .limit(1) ) if summary_record: summary_record.status = SummaryStatus.ERROR @@ -1185,14 +1229,14 @@ class SummaryIndexService: DocumentSegmentSummary instance if found, None otherwise """ with session_factory.create_session() as session: - return ( - session.query(DocumentSegmentSummary) + return session.scalar( + select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment_id, DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) - .first() + .limit(1) ) @staticmethod @@ -1211,15 +1255,13 @@ class SummaryIndexService: return {} with session_factory.create_session() as session: - summary_records = ( - session.query(DocumentSegmentSummary) - .where( + summary_records = session.scalars( + select(DocumentSegmentSummary).where( DocumentSegmentSummary.chunk_id.in_(segment_ids), DocumentSegmentSummary.dataset_id == dataset_id, - DocumentSegmentSummary.enabled == True, # Only return enabled summaries + DocumentSegmentSummary.enabled.is_(True), # Only return enabled summaries ) - .all() - ) + ).all() return {summary.chunk_id: summary for summary in summary_records} @@ -1239,16 +1281,16 @@ class SummaryIndexService: List of DocumentSegmentSummary instances (only enabled summaries) """ with session_factory.create_session() as session: - query = session.query(DocumentSegmentSummary).filter( + stmt = select(DocumentSegmentSummary).where( 
@@ -1239,16 +1281,16 @@ class SummaryIndexService:
             List of DocumentSegmentSummary instances (only enabled summaries)
         """
         with session_factory.create_session() as session:
-            query = session.query(DocumentSegmentSummary).filter(
+            stmt = select(DocumentSegmentSummary).where(
                 DocumentSegmentSummary.document_id == document_id,
                 DocumentSegmentSummary.dataset_id == dataset_id,
-                DocumentSegmentSummary.enabled == True,  # Only return enabled summaries
+                DocumentSegmentSummary.enabled.is_(True),  # Only return enabled summaries
             )
 
             if segment_ids:
-                query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
+                stmt = stmt.where(DocumentSegmentSummary.chunk_id.in_(segment_ids))
 
-            return query.all()
+            return list(session.scalars(stmt).all())
 
     @staticmethod
     def get_document_summary_index_status(document_id: str, dataset_id: str, tenant_id: str) -> str | None:
@@ -1265,16 +1307,15 @@ class SummaryIndexService:
         """
         # Get all segments for this document (excluding qa_model and re_segment)
         with session_factory.create_session() as session:
-            segments = (
-                session.query(DocumentSegment.id)
-                .where(
-                    DocumentSegment.document_id == document_id,
-                    DocumentSegment.status != "re_segment",
-                    DocumentSegment.tenant_id == tenant_id,
-                )
-                .all()
+            segment_ids = list(
+                session.scalars(
+                    select(DocumentSegment.id).where(
+                        DocumentSegment.document_id == document_id,
+                        DocumentSegment.status != "re_segment",
+                        DocumentSegment.tenant_id == tenant_id,
+                    )
+                ).all()
             )
-            segment_ids = [seg.id for seg in segments]
 
         if not segment_ids:
             return None
@@ -1312,15 +1353,13 @@ class SummaryIndexService:
 
         # Get all segments for these documents (excluding qa_model and re_segment)
         with session_factory.create_session() as session:
-            segments = (
-                session.query(DocumentSegment.id, DocumentSegment.document_id)
-                .where(
+            segments = session.execute(
+                select(DocumentSegment.id, DocumentSegment.document_id).where(
                     DocumentSegment.document_id.in_(document_ids),
                     DocumentSegment.status != "re_segment",
                     DocumentSegment.tenant_id == tenant_id,
                 )
-                .all()
-            )
+            ).all()
 
             # Group segments by document_id
             document_segments_map: dict[str, list[str]] = {}
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index 1882c855ea..8043a99be1 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -1,9 +1,11 @@
 import uuid
+from typing import cast
 
 import sqlalchemy as sa
 from flask_login import current_user
 from pydantic import BaseModel, Field
-from sqlalchemy import func, select
+from sqlalchemy import delete, func, select
+from sqlalchemy.engine import CursorResult
 from werkzeug.exceptions import NotFound
 
 from extensions.ext_database import db
@@ -29,7 +31,7 @@ class TagBindingCreatePayload(BaseModel):
 
 
 class TagBindingDeletePayload(BaseModel):
-    tag_id: str
+    tag_ids: list[str] = Field(min_length=1)
     target_id: str
     type: TagType
 
@@ -178,13 +180,18 @@ class TagService:
     @staticmethod
     def delete_tag_binding(payload: TagBindingDeletePayload):
         TagService.check_target_exists(payload.type, payload.target_id)
-        tag_binding = db.session.scalar(
-            select(TagBinding)
-            .where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id)
-            .limit(1)
+        result = cast(
+            CursorResult,
+            db.session.execute(
+                delete(TagBinding).where(
+                    TagBinding.target_id == payload.target_id,
+                    TagBinding.tag_id.in_(payload.tag_ids),
+                    TagBinding.tenant_id == current_user.current_tenant_id,
+                )
+            ),
         )
-        if tag_binding:
-            db.session.delete(tag_binding)
+
+        if result.rowcount:
             db.session.commit()
 
     @staticmethod
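The `delete_tag_binding` rewrite above replaces a fetch-then-delete round trip with a single bulk `DELETE`, committing only when rows were actually removed. A compact sketch of that pattern against a hypothetical `TagBinding` model (the `cast` in the diff only narrows the result type for the checker; at runtime `Session.execute` on a DML statement already returns a `CursorResult`):

```python
from sqlalchemy import create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class TagBinding(Base):
    """Hypothetical stand-in for the real TagBinding model."""

    __tablename__ = "tag_bindings"
    id: Mapped[int] = mapped_column(primary_key=True)
    target_id: Mapped[str]
    tag_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            TagBinding(target_id="app-1", tag_id="t1"),
            TagBinding(target_id="app-1", tag_id="t2"),
        ]
    )
    session.commit()

    # One statement removes every matching row; rowcount reports how many.
    result = session.execute(
        delete(TagBinding).where(
            TagBinding.target_id == "app-1",
            TagBinding.tag_id.in_(["t1", "t2", "t3"]),
        )
    )
    if result.rowcount:
        session.commit()
    print(result.rowcount)  # 2
```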
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index dfc0c2c63f..5ff2c21749 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -2,9 +2,9 @@ import json
 import logging
 from typing import Any, TypedDict, cast
 
-from graphon.model_runtime.utils.encoders import jsonable_encoder
 from httpx import get
 from sqlalchemy import select
+from sqlalchemy.orm import sessionmaker
 
 from core.entities.provider_entities import ProviderConfig
 from core.tools.__base.tool_runtime import ToolRuntime
@@ -16,11 +16,13 @@ from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ApiProviderSchemaType,
 )
+from core.tools.errors import ApiToolProviderNotFoundError
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_manager import ToolManager
 from core.tools.utils.encryption import create_tool_provider_encrypter
 from core.tools.utils.parser import ApiBasedToolSchemaParser
 from extensions.ext_database import db
+from graphon.model_runtime.utils.encoders import jsonable_encoder
 from models.tools import ApiToolProvider
 from services.tools.tools_transform_service import ToolTransformService
 
@@ -92,7 +94,7 @@ class ApiToolManageService:
 
     @staticmethod
     def convert_schema_to_tool_bundles(
-        schema: str, extra_info: dict | None = None
+        schema: str, extra_info: dict[str, Any] | None = None
     ) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
         """
         convert schema to tool bundles
@@ -109,78 +111,92 @@ class ApiToolManageService:
         user_id: str,
         tenant_id: str,
         provider_name: str,
-        icon: dict,
-        credentials: dict,
+        icon: dict[str, Any],
+        credentials: dict[str, Any],
         schema_type: ApiProviderSchemaType,
         schema: str,
         privacy_policy: str,
         custom_disclaimer: str,
         labels: list[str],
-    ):
+    ) -> dict[str, Any]:
         """
-        create api tool provider
+        Create a new API tool provider.
+
+        :param user_id: The ID of the user creating the provider.
+        :param tenant_id: The ID of the workspace/tenant.
+        :param provider_name: The name of the API tool provider.
+        :param icon: The icon configuration for the provider.
+        :param credentials: The credentials for the provider.
+        :param schema_type: The type of schema (e.g., OpenAPI).
+        :param schema: The raw schema string.
+        :param privacy_policy: The privacy policy URL or text.
+        :param custom_disclaimer: Custom disclaimer text.
+        :param labels: A list of labels for the provider.
+        :return: A dictionary indicating the result status.
""" + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # Create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if provider is not None: - raise ValueError(f"provider {provider_name} already exists") + if provider is not None: + raise ValueError(f"provider {provider_name} already exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if len(tool_bundles) > 100: - raise ValueError("the number of apis should be less than 100") + if len(tool_bundles) > 100: + raise ValueError("the number of apis should be less than 100") - # create db provider - db_provider = ApiToolProvider( - tenant_id=tenant_id, - user_id=user_id, - name=provider_name, - icon=json.dumps(icon), - schema=schema, - description=extra_info.get("description", ""), - schema_type_str=schema_type, - tools_str=json.dumps(jsonable_encoder(tool_bundles)), - credentials_str="{}", - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - ) + # create API tool provider + api_tool_provider = ApiToolProvider( + tenant_id=tenant_id, + user_id=user_id, + name=provider_name, + icon=json.dumps(icon), + schema=schema, + description=extra_info.get("description", ""), + schema_type_str=schema_type, + tools_str=json.dumps(jsonable_encoder(tool_bundles)), + credentials_str="{}", + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + ) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # create provider entity + provider_controller = ApiToolProviderController.from_db(api_tool_provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - # encrypt credentials - encrypter, _ = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) - db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) + # encrypt credentials + encrypter, _ = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) + api_tool_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) - db.session.add(db_provider) - db.session.commit() + _session.add(api_tool_provider) - # update labels - ToolLabelManager.update_tool_labels(provider_controller, labels) 
@@ -212,16 +228,25 @@ class ApiToolManageService:
     @staticmethod
     def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
         """
-        list api tool provider tools
+        List tools provided by a specific API tool provider.
+
+        :param user_id: The ID of the user requesting the list.
+        :param tenant_id: The ID of the workspace/tenant.
+        :param provider_name: The name of the API tool provider.
+        :return: A list of ToolApiEntity objects.
         """
-        provider: ApiToolProvider | None = db.session.scalar(
-            select(ApiToolProvider)
-            .where(
-                ApiToolProvider.tenant_id == tenant_id,
-                ApiToolProvider.name == provider_name,
+
+        # create new session with automatic transaction management
+        provider: ApiToolProvider | None = None
+        with sessionmaker(db.engine, expire_on_commit=False).begin() as _session:
+            provider = _session.scalar(
+                select(ApiToolProvider)
+                .where(
+                    ApiToolProvider.tenant_id == tenant_id,
+                    ApiToolProvider.name == provider_name,
+                )
+                .limit(1)
             )
-            .limit(1)
-        )
 
         if provider is None:
             raise ValueError(f"you have not added provider {provider_name}")
@@ -244,110 +269,140 @@ class ApiToolManageService:
         tenant_id: str,
         provider_name: str,
         original_provider: str,
-        icon: dict,
-        credentials: dict,
+        icon: dict[str, Any],
+        credentials: dict[str, Any],
         _schema_type: ApiProviderSchemaType,
         schema: str,
         privacy_policy: str | None,
         custom_disclaimer: str,
         labels: list[str],
-    ):
+    ) -> dict[str, Any]:
         """
-        update api tool provider
+        Update an existing API tool provider.
+
+        :param user_id: The ID of the user updating the provider.
+        :param tenant_id: The ID of the workspace/tenant.
+        :param provider_name: The new name of the API tool provider.
+        :param original_provider: The original name of the API tool provider.
+        :param icon: The icon configuration for the provider.
+        :param credentials: The credentials for the provider.
+        :param _schema_type: The type of schema (e.g., OpenAPI).
+        :param schema: The raw schema string.
+        :param privacy_policy: The privacy policy URL or text.
+        :param custom_disclaimer: Custom disclaimer text.
+        :param labels: A list of labels for the provider.
+        :return: A dictionary indicating the result status.
""" + provider_name = provider_name.strip() # check if the provider exists - provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, + # create new session with automatic transaction management + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider: ApiToolProvider | None = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .limit(1) ) - .limit(1) - ) - if provider is None: - raise ValueError(f"api provider {provider_name} does not exists") - # parse openapi to tool bundle - extra_info: dict[str, str] = {} - # extra info like description will be set here - tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) + if provider is None: + raise ApiToolProviderNotFoundError(provider_name=original_provider, tenant_id=tenant_id) - # update db provider - provider.name = provider_name - provider.icon = json.dumps(icon) - provider.schema = schema - provider.description = extra_info.get("description", "") - provider.schema_type_str = schema_type - provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) - provider.privacy_policy = privacy_policy - provider.custom_disclaimer = custom_disclaimer + # parse openapi to tool bundle + extra_info: dict[str, str] = {} + # extra info like description will be set here + tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - if "auth_type" not in credentials: - raise ValueError("auth_type is required") + # update db provider + provider.name = provider_name + provider.icon = json.dumps(icon) + provider.schema = schema + provider.description = extra_info.get("description", "") + provider.schema_type_str = schema_type + provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) + provider.privacy_policy = privacy_policy + provider.custom_disclaimer = custom_disclaimer - # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) + if "auth_type" not in credentials: + raise ValueError("auth_type is required") - # create provider entity - provider_controller = ApiToolProviderController.from_db(provider, auth_type) - # load tools into provider entity - provider_controller.load_bundled_tools(tool_bundles) + # get auth type, none or api key + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) - # get original credentials if exists - encrypter, cache = create_tool_provider_encrypter( - tenant_id=tenant_id, - controller=provider_controller, - ) + # create provider entity + provider_controller = ApiToolProviderController.from_db(provider, auth_type) + # load tools into provider entity + provider_controller.load_bundled_tools(tool_bundles) - original_credentials = encrypter.decrypt(provider.credentials) - masked_credentials = encrypter.mask_plugin_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] + # get original credentials if exists + encrypter, cache = create_tool_provider_encrypter( + tenant_id=tenant_id, + controller=provider_controller, + ) - credentials = dict(encrypter.encrypt(credentials)) - provider.credentials_str = json.dumps(credentials) + original_credentials = 
+            original_credentials = encrypter.decrypt(provider.credentials)
+            masked_credentials = encrypter.mask_plugin_credentials(original_credentials)
 
-        db.session.add(provider)
-        db.session.commit()
+            # check if the credential has changed, save the original credential
+            for name, value in credentials.items():
+                if name in masked_credentials and value == masked_credentials[name]:
+                    credentials[name] = original_credentials[name]
+
+            credentials = dict(encrypter.encrypt(credentials))
+            provider.credentials_str = json.dumps(credentials)
+
+            _session.add(provider)
+
+            # update labels
+            ToolLabelManager.update_tool_labels(provider_controller, labels, _session)
 
         # delete cache
         cache.delete()
 
-        # update labels
-        ToolLabelManager.update_tool_labels(provider_controller, labels)
-
         return {"result": "success"}
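The update path above compares submitted values against masked originals, so that an untouched masked credential echoed back by the client is never re-encrypted as literal asterisks. A small sketch of that merge step, with a toy `mask` function standing in for the encrypter's masking:

```python
def mask(value: str) -> str:
    """Toy stand-in for the encrypter's masking of stored secrets."""
    return value[:2] + "*" * max(len(value) - 2, 0)


def merge_credentials(submitted: dict[str, str], original: dict[str, str]) -> dict[str, str]:
    masked = {k: mask(v) for k, v in original.items()}
    merged = dict(submitted)
    for name, value in submitted.items():
        # If the client sent back the masked placeholder unchanged,
        # keep the real stored secret instead of the placeholder.
        if name in masked and value == masked[name]:
            merged[name] = original[name]
    return merged


original = {"api_key": "sk-realsecret"}
submitted = {"api_key": mask("sk-realsecret"), "endpoint": "https://example.com"}
print(merge_credentials(submitted, original))
# {'api_key': 'sk-realsecret', 'endpoint': 'https://example.com'}
```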
""" + if schema_type not in [member.value for member in ApiProviderSchemaType]: raise ValueError(f"invalid schema type {schema_type}") @@ -377,18 +442,21 @@ class ApiToolManageService: if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider = db.session.scalar( - select(ApiToolProvider) - .where( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, + # create new session with automatic transaction management to get the provider + provider: ApiToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(ApiToolProvider) + .where( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .limit(1) ) - .limit(1) - ) - if not db_provider: + if provider is None: # create a fake db provider - db_provider = ApiToolProvider( + provider = ApiToolProvider( tenant_id="", user_id="", name="", @@ -407,12 +475,12 @@ class ApiToolManageService: auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity - provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) + provider_controller = ApiToolProviderController.from_db(provider, auth_type) # load tools into provider entity provider_controller.load_bundled_tools(tool_bundles) # decrypt credentials - if db_provider.id: + if provider.id: encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, controller=provider_controller, @@ -443,14 +511,21 @@ class ApiToolManageService: @staticmethod def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ - list api tools + List all API tools for a specific tenant. + + :param tenant_id: The ID of the workspace/tenant. + :return: A list of ToolProviderApiEntity objects. 
""" # get all api providers - db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + # create new session with automatic transaction management + providers: list[ApiToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all() + ) result: list[ToolProviderApiEntity] = [] - - for provider in db_providers: + for provider in providers: # convert provider controller to user provider provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 3daaf9a263..b8242ab3a5 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from pathlib import Path from typing import Any -from sqlalchemy import exists, select +from sqlalchemy import delete, exists, func, select, update from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -26,7 +26,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_provider_encrypter -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from extensions.ext_database import db from extensions.ext_redis import redis_client from models.provider_ids import ToolProviderID @@ -47,11 +47,15 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider) with sessionmaker(bind=db.engine).begin() as session: - session.query(ToolOAuthTenantClient).filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - ).delete() + session.execute( + delete(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ) + .execution_options(synchronize_session=False) + ) return {"result": "success"} @staticmethod @@ -143,7 +147,7 @@ class BuiltinToolManageService: tenant_id: str, provider: str, credential_id: str, - credentials: dict | None = None, + credentials: dict[str, Any] | None = None, name: str | None = None, ): """ @@ -151,13 +155,13 @@ class BuiltinToolManageService: """ with sessionmaker(bind=db.engine).begin() as session: # get if the provider exists - db_provider = ( - session.query(BuiltinToolProvider) + db_provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) - .first() + .limit(1) ) if db_provider is None: raise ValueError(f"you have not added provider {provider}") @@ -173,7 +177,7 @@ class BuiltinToolManageService: ) original_credentials = encrypter.decrypt(db_provider.credentials) - new_credentials: dict = { + new_credentials: dict[str, Any] = { key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE) for key, value in credentials.items() } @@ -212,7 +216,7 @@ class BuiltinToolManageService: api_type: 
CredentialType, tenant_id: str, provider: str, - credentials: dict, + credentials: dict[str, Any], expires_at: int = -1, name: str | None = None, ): @@ -228,7 +232,13 @@ class BuiltinToolManageService: raise ValueError(f"provider {provider} does not need credentials") provider_count = ( - session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() + session.scalar( + select(func.count(BuiltinToolProvider.id)).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + ) + or 0 ) # check if the provider count is reached the limit @@ -304,16 +314,15 @@ class BuiltinToolManageService: def generate_builtin_tool_provider_name( session: Session, tenant_id: str, provider: str, credential_type: CredentialType ) -> str: - db_providers = ( - session.query(BuiltinToolProvider) - .filter_by( - tenant_id=tenant_id, - provider=provider, - credential_type=credential_type, + db_providers = session.scalars( + select(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.credential_type == credential_type, ) .order_by(BuiltinToolProvider.created_at.desc()) - .all() - ) + ).all() return generate_incremental_name( [provider.name for provider in db_providers], f"{credential_type.get_name()}", @@ -375,13 +384,13 @@ class BuiltinToolManageService: delete tool provider """ with sessionmaker(bind=db.engine).begin() as session: - db_provider = ( - session.query(BuiltinToolProvider) + db_provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) - .first() + .limit(1) ) if db_provider is None: @@ -405,14 +414,26 @@ class BuiltinToolManageService: """ with sessionmaker(bind=db.engine).begin() as session: # get provider - target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first() + target_provider = session.scalar( + select(BuiltinToolProvider) + .where(BuiltinToolProvider.id == id, BuiltinToolProvider.tenant_id == tenant_id) + .limit(1) + ) if target_provider is None: raise ValueError("provider not found") # clear default provider - session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True - ).update({"is_default": False}) + session.execute( + update(BuiltinToolProvider) + .where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.user_id == user_id, + BuiltinToolProvider.provider == provider, + BuiltinToolProvider.is_default.is_(True), + ) + .values(is_default=False) + .execution_options(synchronize_session=False) + ) # set new default provider target_provider.is_default = True @@ -426,10 +447,13 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider_name) with Session(db.engine, autoflush=False) as session: - system_client: ToolOAuthSystemClient | None = ( - session.query(ToolOAuthSystemClient) - .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) - .first() + system_client = session.scalar( + select(ToolOAuthSystemClient) + .where( + ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id, + ToolOAuthSystemClient.provider == tool_provider.provider_name, + ) + .limit(1) ) return system_client is not None @@ -440,15 +464,15 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider) with Session(db.engine, autoflush=False) as session: - user_client: 
ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - enabled=True, + user_client = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) return user_client is not None and user_client.enabled @@ -465,15 +489,15 @@ class BuiltinToolManageService: cache=NoOpProviderCredentialCache(), ) with Session(db.engine, autoflush=False) as session: - user_client: ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=tool_provider.provider_name, - plugin_id=tool_provider.plugin_id, - enabled=True, + user_client = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) oauth_params: Mapping[str, Any] | None = None if user_client: @@ -487,14 +511,17 @@ class BuiltinToolManageService: if not is_verified: return oauth_params - system_client: ToolOAuthSystemClient | None = ( - session.query(ToolOAuthSystemClient) - .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) - .first() + system_client = session.scalar( + select(ToolOAuthSystemClient) + .where( + ToolOAuthSystemClient.plugin_id == tool_provider.plugin_id, + ToolOAuthSystemClient.provider == tool_provider.provider_name, + ) + .limit(1) ) if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") @@ -582,8 +609,8 @@ class BuiltinToolManageService: provider_name = provider_id_entity.provider_name if provider_id_entity.organization != "langgenius": - provider = ( - session.query(BuiltinToolProvider) + provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == full_provider_name, @@ -592,11 +619,11 @@ class BuiltinToolManageService: BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + .limit(1) ) else: - provider = ( - session.query(BuiltinToolProvider) + provider = session.scalar( + select(BuiltinToolProvider) .where( BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_name) @@ -606,7 +633,7 @@ class BuiltinToolManageService: BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + .limit(1) ) if provider is None: @@ -616,21 +643,21 @@ class BuiltinToolManageService: return provider except Exception: # it's an old provider without organization - return ( - session.query(BuiltinToolProvider) + return session.scalar( + select(BuiltinToolProvider) .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) .order_by( BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) - .first() + 
.limit(1) ) @staticmethod def save_custom_oauth_client_params( tenant_id: str, provider: str, - client_params: dict | None = None, + client_params: dict[str, Any] | None = None, enable_oauth_custom_client: bool | None = None, ): """ @@ -648,14 +675,14 @@ class BuiltinToolManageService: raise ValueError(f"Provider {provider} is not a builtin or plugin provider") with sessionmaker(bind=db.engine).begin() as session: - custom_client_params = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=tool_provider.plugin_id, - provider=tool_provider.provider_name, + custom_client_params = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, ) - .first() + .limit(1) ) # if the record does not exist, create a basic record @@ -692,14 +719,14 @@ class BuiltinToolManageService: """ with Session(db.engine) as session: tool_provider = ToolProviderID(provider) - custom_oauth_client_params: ToolOAuthTenantClient | None = ( - session.query(ToolOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=tool_provider.plugin_id, - provider=tool_provider.provider_name, + custom_oauth_client_params = session.scalar( + select(ToolOAuthTenantClient) + .where( + ToolOAuthTenantClient.tenant_id == tenant_id, + ToolOAuthTenantClient.plugin_id == tool_provider.plugin_id, + ToolOAuthTenantClient.provider == tool_provider.provider_name, ) - .first() + .limit(1) ) if custom_oauth_client_params is None: return {} diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 690b06ea7d..89762d6772 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -17,6 +17,7 @@ from core.helper import encrypter from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.entities import AuthActionType, AuthResult from core.mcp.error import MCPAuthError, MCPError from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolProviderApiEntity @@ -496,7 +497,13 @@ class MCPToolManageService: ) as mcp_client: return mcp_client.list_tools() - def execute_auth_actions(self, auth_result: Any) -> dict[str, str]: + _ACTION_TO_OAUTH: dict[AuthActionType, OAuthDataType] = { + AuthActionType.SAVE_CLIENT_INFO: OAuthDataType.CLIENT_INFO, + AuthActionType.SAVE_TOKENS: OAuthDataType.TOKENS, + AuthActionType.SAVE_CODE_VERIFIER: OAuthDataType.CODE_VERIFIER, + } + + def execute_auth_actions(self, auth_result: AuthResult) -> dict[str, str]: """ Execute the actions returned by the auth function. 
@@ -508,19 +515,13 @@ class MCPToolManageService: Returns: The response from the auth result """ - from core.mcp.entities import AuthAction, AuthActionType - - action: AuthAction for action in auth_result.actions: if action.provider_id is None or action.tenant_id is None: continue - if action.action_type == AuthActionType.SAVE_CLIENT_INFO: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CLIENT_INFO) - elif action.action_type == AuthActionType.SAVE_TOKENS: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.TOKENS) - elif action.action_type == AuthActionType.SAVE_CODE_VERIFIER: - self.save_oauth_data(action.provider_id, action.tenant_id, action.data, OAuthDataType.CODE_VERIFIER) + oauth_type = self._ACTION_TO_OAUTH.get(action.action_type) + if oauth_type is not None: + self.save_oauth_data(action.provider_id, action.tenant_id, action.data, oauth_type) return auth_result.response diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 4fd2ea1628..47aca9b0af 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from typing import Any, Union +from typing import Any from pydantic import TypeAdapter, ValidationError from yarl import URL @@ -69,7 +69,9 @@ class ToolTransformService: return "" @staticmethod - def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]): + def repack_provider( + tenant_id: str, provider: dict[str, Any] | ToolProviderApiEntity | PluginDatasourceProviderEntity + ): """ repack provider @@ -426,7 +428,7 @@ class ToolTransformService: @staticmethod def convert_builtin_provider_to_credential_entity( - provider: BuiltinToolProvider, credentials: dict + provider: BuiltinToolProvider, credentials: dict[str, Any] ) -> ToolProviderCredentialApiEntity: return ToolProviderCredentialApiEntity( id=provider.id, diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 8f5144c866..8f6600af03 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,10 +1,10 @@ import json import logging from datetime import datetime +from typing import Any -from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import delete, or_, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -14,6 +14,7 @@ from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurati from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db +from graphon.model_runtime.utils.encoders import jsonable_encoder from models.model import App from models.tools import WorkflowToolProvider from models.workflow import Workflow @@ -35,39 +36,50 @@ class WorkflowToolManageService: workflow_app_id: str, name: str, label: str, - icon: dict, + icon: dict[str, Any], description: str, parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", labels: list[str] | None = None, ): # check if the name is unique - 
existing_workflow_tool_provider = db.session.scalar( - select(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + existing_workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + # query if the name or app_id exists + existing_workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .limit(1) ) - .limit(1) - ) + # if the name or app_id exists raise error if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") - app: App | None = db.session.scalar( - select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1) - ) + # query the app + app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + app = _session.scalar(select(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).limit(1)) + # if not found raise error if app is None: raise ValueError(f"App {workflow_app_id} not found") + # query the workflow workflow: Workflow | None = app.workflow + + # if not found raise error if workflow is None: raise ValueError(f"Workflow not found for app {workflow_app_id}") + # check if workflow configuration is synced WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) + # create workflow tool provider workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -84,15 +96,18 @@ class WorkflowToolManageService: try: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: + logger.warning(e, exc_info=True) raise ValueError(str(e)) - with Session(db.engine, expire_on_commit=False) as session, session.begin(): - session.add(workflow_tool_provider) + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + _session.add(workflow_tool_provider) + # keep the session open to make orm instances in the same session if labels is not None: ToolLabelManager.update_tool_labels( ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) + return {"result": "success"} @classmethod @@ -103,7 +118,7 @@ class WorkflowToolManageService: workflow_tool_id: str, name: str, label: str, - icon: dict, + icon: dict[str, Any], description: str, parameters: list[WorkflowToolParameterConfiguration], privacy_policy: str = "", @@ -111,6 +126,7 @@ class WorkflowToolManageService: ): """ Update a workflow tool. 
+ :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: workflow tool id @@ -123,62 +139,82 @@ class WorkflowToolManageService: :param labels: labels :return: the updated tool """ - # check if the name is unique - existing_workflow_tool_provider = db.session.scalar( - select(WorkflowToolProvider) - .where( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id, - ) - .limit(1) - ) + existing_workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + # query if the name exists for other tools + existing_workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .limit(1) + ) + + # if the name exists raise error if existing_workflow_tool_provider is not None: raise ValueError(f"Tool with name {name} already exists") - workflow_tool_provider: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) + # query the workflow tool provider + workflow_tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + workflow_tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + # if not found raise error if workflow_tool_provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - app: App | None = db.session.scalar( - select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1) - ) + # query the app + app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + app = _session.scalar( + select(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).limit(1) + ) + # if not found raise error if app is None: raise ValueError(f"App {workflow_tool_provider.app_id} not found") + # query the workflow workflow: Workflow | None = app.workflow + + # if not found raise error if workflow is None: raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + # check if workflow configuration is synced WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(workflow.graph_dict) - workflow_tool_provider.name = name - workflow_tool_provider.label = label - workflow_tool_provider.icon = json.dumps(icon) - workflow_tool_provider.description = description - workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in parameters]) - workflow_tool_provider.privacy_policy = privacy_policy - workflow_tool_provider.version = workflow.version - workflow_tool_provider.updated_at = datetime.now() + with sessionmaker(db.engine).begin() as _session: + _session.add(workflow_tool_provider) - try: - WorkflowToolProviderController.from_db(workflow_tool_provider) - except Exception as e: - raise ValueError(str(e)) + # update workflow tool provider + workflow_tool_provider.name = name + workflow_tool_provider.label = label + workflow_tool_provider.icon = json.dumps(icon) + workflow_tool_provider.description = description + workflow_tool_provider.parameter_configuration = json.dumps([p.model_dump() for p in 
parameters]) + workflow_tool_provider.privacy_policy = privacy_policy + workflow_tool_provider.version = workflow.version + workflow_tool_provider.updated_at = datetime.now() - db.session.commit() + try: + WorkflowToolProviderController.from_db(workflow_tool_provider) + except Exception as e: + raise ValueError(str(e)) - if labels is not None: - ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels - ) + if labels is not None: + ToolLabelManager.update_tool_labels( + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), + labels, + session=_session, + ) return {"result": "success"} @@ -186,28 +222,32 @@ class WorkflowToolManageService: def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: """ List workflow tools. + :param user_id: the user id :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.scalars( - select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id) - ).all() + + providers: list[WorkflowToolProvider] = [] + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + providers = list( + _session.scalars(select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)).all() + ) # Create a mapping from provider_id to app_id - provider_id_to_app_id = {provider.id: provider.app_id for provider in db_tools} + provider_id_to_app_id = {provider.id: provider.app_id for provider in providers} tools: list[WorkflowToolProviderController] = [] - for provider in db_tools: + for provider in providers: try: tools.append(ToolTransformService.workflow_provider_to_controller(provider)) except Exception: # skip deleted tools logger.exception("Failed to load workflow tool provider %s", provider.id) - labels = ToolLabelManager.get_tools_labels([t for t in tools if isinstance(t, ToolProviderController)]) + labels = ToolLabelManager.get_tools_labels([tool for tool in tools if isinstance(tool, ToolProviderController)]) - result = [] + result: list[ToolProviderApiEntity] = [] for tool in tools: workflow_app_id = provider_id_to_app_id.get(tool.provider_id) @@ -232,17 +272,18 @@ class WorkflowToolManageService: def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Delete a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id """ - db.session.execute( - delete(WorkflowToolProvider).where( - WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id - ) - ) - db.session.commit() + with sessionmaker(db.engine).begin() as _session: + _ = _session.execute( + delete(WorkflowToolProvider).where( + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id + ) + ) return {"result": "success"} @@ -250,47 +291,59 @@ class WorkflowToolManageService: def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str): """ Get a workflow tool. 
+ :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id :return: the tool """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) - return cls._get_workflow_tool(tenant_id, db_tool) + + tool_provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + tool_provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + + return cls._get_workflow_tool(tenant_id, tool_provider) @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str): """ Get a workflow tool. + :param user_id: the user id :param tenant_id: the tenant id :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) - .limit(1) - ) - return cls._get_workflow_tool(tenant_id, db_tool) + + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + tool_provider: WorkflowToolProvider | None = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .limit(1) + ) + + return cls._get_workflow_tool(tenant_id, tool_provider) @classmethod def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None): """ Get a workflow tool. + :db_tool: the database tool :return: the tool """ if db_tool is None: raise ValueError("Tool not found") - workflow_app: App | None = db.session.scalar( - select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1) - ) + workflow_app: App | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + workflow_app = _session.scalar( + select(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).limit(1) + ) if workflow_app is None: raise ValueError(f"App {db_tool.app_id} not found") @@ -330,28 +383,32 @@ class WorkflowToolManageService: def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]: """ List workflow tool provider tools. 
+ :param user_id: the user id :param tenant_id: the tenant id :param workflow_tool_id: the workflow tool id :return: the list of tools """ - db_tool: WorkflowToolProvider | None = db.session.scalar( - select(WorkflowToolProvider) - .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) - .limit(1) - ) - if db_tool is None: + provider: WorkflowToolProvider | None = None + with sessionmaker(db.engine, expire_on_commit=False).begin() as _session: + provider = _session.scalar( + select(WorkflowToolProvider) + .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .limit(1) + ) + + if provider is None: raise ValueError(f"Tool {workflow_tool_id} not found") - tool = ToolTransformService.workflow_provider_to_controller(db_tool) + tool = ToolTransformService.workflow_provider_to_controller(provider) workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id) if len(workflow_tools) == 0: raise ValueError(f"Tool {workflow_tool_id} not found") return [ ToolTransformService.convert_tool_entity_to_api_entity( - tool=tool.get_tools(db_tool.tenant_id)[0], + tool=tool.get_tools(provider.tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool), tenant_id=tenant_id, ) diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 25e80770b8..a827222c1d 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -2,7 +2,6 @@ import json import logging from datetime import datetime -from graphon.entities.graph_config import NodeConfigDict from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,6 +13,7 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError +from graphon.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index ae74f7a8cd..b8a76e4945 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -3,9 +3,9 @@ import logging import time as _time import uuid from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict -from sqlalchemy import desc, func +from sqlalchemy import delete, desc, func, select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config @@ -14,7 +14,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler -from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params +from core.tools.utils.system_encryption import decrypt_system_params from core.trigger.entities.api_entities import ( TriggerProviderApiEntity, TriggerProviderSubscriptionApiEntity, @@ -42,6 +42,10 @@ from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) +class VerifyCredentialsResult(TypedDict): + verified: bool + + class TriggerProviderService: """Service for managing trigger providers and credentials""" @@ -69,27 +73,28 @@ class 
TriggerProviderService: workflows_in_use_map: dict[str, int] = {} with Session(db.engine, expire_on_commit=False) as session: # Get all subscriptions - subscriptions_db = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) + subscriptions_db = session.scalars( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + ) .order_by(desc(TriggerSubscription.created_at)) - .all() - ) + ).all() subscriptions = [subscription.to_api_entity() for subscription in subscriptions_db] if not subscriptions: return [] - usage_counts = ( - session.query( + usage_counts = session.execute( + select( WorkflowPluginTrigger.subscription_id, func.count(func.distinct(WorkflowPluginTrigger.app_id)).label("app_count"), ) - .filter( + .where( WorkflowPluginTrigger.tenant_id == tenant_id, WorkflowPluginTrigger.subscription_id.in_([s.id for s in subscriptions]), ) .group_by(WorkflowPluginTrigger.subscription_id) - .all() - ) + ).all() workflows_in_use_map = {str(row.subscription_id): int(row.app_count) for row in usage_counts} provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) @@ -152,9 +157,13 @@ class TriggerProviderService: with redis_client.lock(lock_key, timeout=20): # Check provider count limit provider_count = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id)) - .count() + session.scalar( + select(func.count(TriggerSubscription.id)).where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + ) + ) + or 0 ) if provider_count >= cls.__MAX_TRIGGER_PROVIDER_COUNT__: @@ -164,10 +173,14 @@ class TriggerProviderService: ) # Check if name already exists - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() + existing = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + TriggerSubscription.name == name, + ) + .limit(1) ) if existing: raise ValueError(f"Credential name '{name}' already exists for this provider") @@ -244,8 +257,13 @@ class TriggerProviderService: # Use distributed lock to prevent race conditions on the same subscription lock_key = f"trigger_subscription_update_lock:{tenant_id}_{subscription_id}" with redis_client.lock(lock_key, timeout=20): - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if not subscription: raise ValueError(f"Trigger subscription {subscription_id} not found") @@ -255,10 +273,14 @@ class TriggerProviderService: # Check for name uniqueness if name is being updated if name is not None and name != subscription.name: - existing = ( - session.query(TriggerSubscription) - .filter_by(tenant_id=tenant_id, provider_id=str(provider_id), name=name) - .first() + existing = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.provider_id == str(provider_id), + TriggerSubscription.name == name, + ) + .limit(1) ) if existing: raise ValueError(f"Subscription name '{name}' already exists for this provider") @@ 
-316,11 +338,18 @@ class TriggerProviderService: with Session(db.engine, expire_on_commit=False) as session: subscription: TriggerSubscription | None = None if subscription_id: - subscription = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) else: - subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id).first() + subscription = session.scalar( + select(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant_id).limit(1) + ) if subscription: provider_controller = TriggerManager.get_trigger_provider( tenant_id, TriggerProviderID(subscription.provider_id) @@ -349,8 +378,13 @@ class TriggerProviderService: :param subscription_id: Subscription instance ID :return: Success response """ - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -402,7 +436,14 @@ class TriggerProviderService: :return: New token info """ with sessionmaker(bind=db.engine).begin() as session: - subscription = session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) + ) if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -475,8 +516,13 @@ class TriggerProviderService: now_ts: int = int(now if now is not None else _time.time()) with sessionmaker(bind=db.engine).begin() as session: - subscription: TriggerSubscription | None = ( - session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + subscription = session.scalar( + select(TriggerSubscription) + .where( + TriggerSubscription.tenant_id == tenant_id, + TriggerSubscription.id == subscription_id, + ) + .limit(1) ) if subscription is None: raise ValueError(f"Trigger provider subscription {subscription_id} not found") @@ -552,15 +598,15 @@ class TriggerProviderService: tenant_id=tenant_id, provider_id=provider_id ) with Session(db.engine, expire_on_commit=False) as session: - tenant_client: TriggerOAuthTenantClient | None = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - enabled=True, + tenant_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) oauth_params: Mapping[str, Any] | None = None @@ -578,15 +624,18 @@ class TriggerProviderService: return None # Check for system-level OAuth client - system_client: TriggerOAuthSystemClient | None = ( - session.query(TriggerOAuthSystemClient) - .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) - .first() + 
system_client = session.scalar( + select(TriggerOAuthSystemClient) + .where( + TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id, + TriggerOAuthSystemClient.provider == provider_id.provider_name, + ) + .limit(1) ) if system_client: try: - oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + oauth_params = decrypt_system_params(system_client.encrypted_oauth_params) except Exception as e: raise ValueError(f"Error decrypting system oauth params: {e}") @@ -602,10 +651,13 @@ class TriggerProviderService: if not is_verified: return False with Session(db.engine, expire_on_commit=False) as session: - system_client: TriggerOAuthSystemClient | None = ( - session.query(TriggerOAuthSystemClient) - .filter_by(plugin_id=provider_id.plugin_id, provider=provider_id.provider_name) - .first() + system_client = session.scalar( + select(TriggerOAuthSystemClient) + .where( + TriggerOAuthSystemClient.plugin_id == provider_id.plugin_id, + TriggerOAuthSystemClient.provider == provider_id.provider_name, + ) + .limit(1) ) return system_client is not None @@ -636,14 +688,14 @@ class TriggerProviderService: with sessionmaker(bind=db.engine).begin() as session: # Find existing custom client params - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, ) - .first() + .limit(1) ) # Create new record if doesn't exist @@ -690,14 +742,14 @@ class TriggerProviderService: :return: Masked OAuth client parameters """ with Session(db.engine) as session: - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, ) - .first() + .limit(1) ) if custom_client is None: @@ -727,11 +779,15 @@ class TriggerProviderService: :return: Success response """ with sessionmaker(bind=db.engine).begin() as session: - session.query(TriggerOAuthTenantClient).filter_by( - tenant_id=tenant_id, - provider=provider_id.provider_name, - plugin_id=provider_id.plugin_id, - ).delete() + session.execute( + delete(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.provider == provider_id.provider_name, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + ) + .execution_options(synchronize_session=False) + ) return {"result": "success"} @@ -745,15 +801,15 @@ class TriggerProviderService: :return: True if enabled, False otherwise """ with Session(db.engine, expire_on_commit=False) as session: - custom_client = ( - session.query(TriggerOAuthTenantClient) - .filter_by( - tenant_id=tenant_id, - plugin_id=provider_id.plugin_id, - provider=provider_id.provider_name, - enabled=True, + custom_client = session.scalar( + select(TriggerOAuthTenantClient) + .where( + TriggerOAuthTenantClient.tenant_id == tenant_id, + TriggerOAuthTenantClient.plugin_id == provider_id.plugin_id, + TriggerOAuthTenantClient.provider == 
provider_id.provider_name, + TriggerOAuthTenantClient.enabled.is_(True), ) - .first() + .limit(1) ) return custom_client is not None @@ -763,7 +819,9 @@ class TriggerProviderService: Get a trigger subscription by the endpoint ID. """ with Session(db.engine, expire_on_commit=False) as session: - subscription = session.query(TriggerSubscription).filter_by(endpoint_id=endpoint_id).first() + subscription = session.scalar( + select(TriggerSubscription).where(TriggerSubscription.endpoint_id == endpoint_id).limit(1) + ) if not subscription: return None provider_controller: PluginTriggerProviderController = TriggerManager.get_trigger_provider( @@ -792,7 +850,7 @@ class TriggerProviderService: provider_id: TriggerProviderID, subscription_id: str, credentials: dict[str, Any], - ) -> dict[str, Any]: + ) -> VerifyCredentialsResult: """ Verify credentials for an existing subscription without updating it. diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 5a5d13b96d..911331e357 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,7 +5,6 @@ from collections.abc import Mapping from typing import Any from flask import Request, Response -from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -21,6 +20,7 @@ from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_ from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from extensions.ext_database import db from extensions.ext_redis import redis_client +from graphon.entities.graph_config import NodeConfigDict from models.model import App from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription, WorkflowPluginTrigger diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index bb767a6759..5d99900a04 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -7,9 +7,6 @@ from typing import Any, NotRequired, TypedDict import orjson from flask import request -from graphon.entities.graph_config import NodeConfigDict -from graphon.file import FileTransferMethod -from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -31,6 +28,9 @@ from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory +from graphon.entities.graph_config import NodeConfigDict +from graphon.file import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from models.enums import AppTriggerStatus, AppTriggerType from models.model import App from models.trigger import AppTrigger, WorkflowWebhookTrigger @@ -38,6 +38,7 @@ from models.workflow import Workflow from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService from services.trigger.app_trigger_service import AppTriggerService from services.workflow.entities import WebhookTriggerData @@ -798,45 +799,47 @@ class WebhookService: Exception: If workflow execution fails """ try: - with Session(db.engine) as session: - # Prepare inputs for the webhook node - # The webhook node 
expects webhook_data in the inputs - workflow_inputs = cls.build_workflow_inputs(webhook_data) + workflow_inputs = cls.build_workflow_inputs(webhook_data) - # Create trigger data - trigger_data = WebhookTriggerData( - app_id=webhook_trigger.app_id, - workflow_id=workflow.id, - root_node_id=webhook_trigger.node_id, # Start from the webhook node - inputs=workflow_inputs, - tenant_id=webhook_trigger.tenant_id, + trigger_data = WebhookTriggerData( + app_id=webhook_trigger.app_id, + workflow_id=workflow.id, + root_node_id=webhook_trigger.node_id, + inputs=workflow_inputs, + tenant_id=webhook_trigger.tenant_id, + ) + + end_user = EndUserService.get_or_create_end_user_by_type( + type=InvokeFrom.TRIGGER, + tenant_id=webhook_trigger.tenant_id, + app_id=webhook_trigger.app_id, + user_id=None, + ) + + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, webhook_trigger.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) + logger.info( + "Tenant %s rate limited, skipping webhook trigger %s", + webhook_trigger.tenant_id, + webhook_trigger.webhook_id, ) + raise - end_user = EndUserService.get_or_create_end_user_by_type( - type=InvokeFrom.TRIGGER, - tenant_id=webhook_trigger.tenant_id, - app_id=webhook_trigger.app_id, - user_id=None, - ) - - # consume quota before triggering workflow execution - try: - QuotaType.TRIGGER.consume(webhook_trigger.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(webhook_trigger.tenant_id) - logger.info( - "Tenant %s rate limited, skipping webhook trigger %s", - webhook_trigger.tenant_id, - webhook_trigger.webhook_id, + try: + # NOTE: do not use `with sessionmaker(bind=db.engine, expire_on_commit=False).begin()` + # trigger_workflow_async needs to handle multiple session commits internally + with Session(db.engine, expire_on_commit=False) as session: + AsyncWorkflowService.trigger_workflow_async( + session, + end_user, + trigger_data, ) - raise - - # Trigger workflow execution asynchronously - AsyncWorkflowService.trigger_workflow_async( - session, - end_user, - trigger_data, - ) + quota_charge.commit() + except Exception: + quota_charge.refund() + raise except Exception: logger.exception("Failed to trigger workflow for webhook %s", webhook_trigger.webhook_id) diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 4d58a9cf12..1529c2b98f 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, overload +from configs import dify_config from graphon.file import File from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable from graphon.variables.segments import ( @@ -21,8 +22,6 @@ from graphon.variables.segments import ( ) from graphon.variables.utils import dumps_with_segments -from configs import dify_config - _MAX_DEPTH = 100 @@ -170,7 +169,7 @@ class VariableTruncator(BaseTruncator): return TruncationResult(StringSegment(value=fallback_result.value), True) # Apply final fallback - convert to JSON string and truncate - json_str = dumps_with_segments(result.value, ensure_ascii=False) + json_str = dumps_with_segments(result.value) if len(json_str) > self._max_size_bytes: json_str = json_str[: self._max_size_bytes] + "..." 
return TruncationResult(result=StringSegment(value=json_str), truncated=True) diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 9827c8dfbc..7e689af35d 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,6 +1,5 @@ import logging -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import delete, select from core.model_manager import ModelInstance, ModelManager @@ -13,9 +12,11 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db +from graphon.model_runtime.entities.model_entities import ModelType from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument +from models.enums import SegmentType logger = logging.getLogger(__name__) @@ -178,7 +179,7 @@ class VectorService: index_node_hash=child_chunk.metadata["doc_hash"], content=child_chunk.page_content, word_count=len(child_chunk.page_content), - type="automatic", + type=SegmentType.AUTOMATIC, created_by=dataset_document.created_by, ) db.session.add(child_segment) @@ -222,6 +223,7 @@ class VectorService: ) documents.append(new_child_document) for update_child_chunk in update_child_chunks: + assert update_child_chunk.index_node_id child_document = Document( page_content=update_child_chunk.content, metadata={ @@ -234,6 +236,7 @@ class VectorService: documents.append(child_document) delete_node_ids.append(update_child_chunk.index_node_id) for delete_child_chunk in delete_child_chunks: + assert delete_child_chunk.index_node_id delete_node_ids.append(delete_child_chunk.index_node_id) if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # update vector index @@ -246,6 +249,7 @@ class VectorService: @classmethod def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): vector = Vector(dataset=dataset) + assert child_chunk.index_node_id vector.delete_by_ids([child_chunk.index_node_id]) @classmethod diff --git a/api/services/website_service.py b/api/services/website_service.py index 2471c2cee8..ea584088bb 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -91,7 +91,7 @@ class WebsiteCrawlApiRequest: return CrawlRequest(url=self.url, provider=self.provider, options=options) @classmethod - def from_args(cls, args: dict) -> WebsiteCrawlApiRequest: + def from_args(cls, args: dict[str, Any]) -> WebsiteCrawlApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") url = args.get("url") @@ -115,7 +115,7 @@ class WebsiteCrawlStatusApiRequest: job_id: str @classmethod - def from_args(cls, args: dict, job_id: str) -> WebsiteCrawlStatusApiRequest: + def from_args(cls, args: dict[str, Any], job_id: str) -> WebsiteCrawlStatusApiRequest: """Create from Flask-RESTful parsed arguments.""" provider = args.get("provider") if not provider: @@ -163,7 +163,7 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str: + def _get_decrypted_api_key(cls, tenant_id: str, config: dict[str, Any]) -> str: """Decrypt and return the API key from config.""" api_key = config.get("api_key") if not api_key: @@ -171,7 +171,7 @@ class WebsiteService: return 
encrypter.decrypt_token(tenant_id=tenant_id, token=api_key) @classmethod - def document_create_args_validate(cls, args: dict): + def document_create_args_validate(cls, args: dict[str, Any]): """Validate arguments for document creation.""" try: WebsiteCrawlApiRequest.from_args(args) @@ -195,7 +195,7 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) params: dict[str, Any] @@ -225,7 +225,7 @@ class WebsiteService: return {"status": "active", "job_id": job_id} @classmethod - def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]: # Convert CrawlOptions back to dict format for WaterCrawlProvider options = { "limit": request.options.limit, @@ -290,7 +290,7 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> CrawlStatusDict: + def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> CrawlStatusDict: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id) crawl_status_data: CrawlStatusDict = { @@ -364,7 +364,9 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: + def _get_firecrawl_url_data( + cls, job_id: str, url: str, api_key: str, config: dict[str, Any] + ) -> dict[str, Any] | None: crawl_data: list[FirecrawlDocumentData] | None = None file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): @@ -438,7 +440,7 @@ class WebsiteService: raise ValueError("Invalid provider") @classmethod - def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: + def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) params = {"onlyMainContent": request.only_main_content} return dict(firecrawl_app.scrape_url(url=request.url, params=params)) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1582bcd46c..5dedb9e372 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,11 +1,6 @@ import json from typing import Any, TypedDict -from graphon.file import FileUploadConfig -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.utils.encoders import jsonable_encoder -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.input_entities import VariableEntity from sqlalchemy import select from core.app.app_config.entities import ( @@ -24,6 +19,11 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from events.app_event import app_was_created from extensions.ext_database import db +from graphon.file import FileUploadConfig +from 
graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig, IconType diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index b5ab176ad2..59e02ec9b9 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -3,10 +3,10 @@ import uuid from datetime import datetime from typing import Any, TypedDict -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session +from graphon.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_collaboration_service.py b/api/services/workflow_collaboration_service.py new file mode 100644 index 0000000000..cf2f509052 --- /dev/null +++ b/api/services/workflow_collaboration_service.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import logging +import time +from collections.abc import Mapping + +from sqlalchemy import select + +from core.db.session_factory import session_factory +from models.account import Account +from models.model import App +from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo + +logger = logging.getLogger(__name__) + + +class WorkflowCollaborationService: + def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None: + self._repository = repository + self._socketio = socketio + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(repository={self._repository})" + + def save_socket_identity(self, sid: str, user: Account) -> None: + """Persist the authenticated console user on the raw socket session.""" + self._socketio.save_session( + sid, + { + "user_id": user.id, + "username": user.name, + "avatar": user.avatar, + "tenant_id": user.current_tenant_id, + }, + ) + + def authorize_and_join_workflow_room(self, workflow_id: str, sid: str) -> tuple[str, bool] | None: + """ + Join a collaboration room only after validating the socket session and tenant-scoped app access. + + The Socket.IO payload still calls the room key `workflow_id`, but the identifier is the workflow app's + `App.id`. Returning `None` lets the controller reject the join before any Redis or room state is created. 
+ """ + session = self._socketio.get_session(sid) + user_id = session.get("user_id") + tenant_id = session.get("tenant_id") + if not user_id or not tenant_id: + return None + + if not self._can_access_workflow(workflow_id, str(tenant_id)): + logger.warning( + "Workflow collaboration join rejected: workflow_id=%s tenant_id=%s user_id=%s sid=%s", + workflow_id, + tenant_id, + user_id, + sid, + ) + return None + + session_info: WorkflowSessionInfo = { + "user_id": str(user_id), + "username": str(session.get("username", "Unknown")), + "avatar": session.get("avatar"), + "sid": sid, + "connected_at": int(time.time()), + } + + self._repository.set_session_info(workflow_id, session_info) + + leader_sid = self.get_or_set_leader(workflow_id, sid) + is_leader = leader_sid == sid + + self._socketio.enter_room(sid, workflow_id) + self.broadcast_online_users(workflow_id) + + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + + return str(user_id), is_leader + + def _can_access_workflow(self, workflow_id: str, tenant_id: str) -> bool: + """Check room access without relying on Flask's app-context-bound scoped session.""" + with session_factory.create_session() as session: + app_id = session.scalar(select(App.id).where(App.id == workflow_id, App.tenant_id == tenant_id).limit(1)) + return app_id is not None + + def disconnect_session(self, sid: str) -> None: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return + + workflow_id = mapping["workflow_id"] + self._repository.delete_session(workflow_id, sid) + + self.handle_leader_disconnect(workflow_id, sid) + self.broadcast_online_users(workflow_id) + + def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + user_id = mapping["user_id"] + self.refresh_session_state(workflow_id, sid) + + event_type = data.get("type") + event_data = data.get("data") + timestamp = data.get("timestamp", int(time.time())) + + if not event_type: + return {"msg": "invalid event type"}, 400 + + if event_type == "sync_request": + leader_sid = self._repository.get_current_leader(workflow_id) + target_sid: str | None + if leader_sid and self.is_session_active(workflow_id, leader_sid): + target_sid = leader_sid + else: + if leader_sid: + self._repository.delete_leader(workflow_id) + target_sid = self._select_graph_leader(workflow_id, preferred_sid=sid) + if target_sid: + self._repository.set_leader(workflow_id, target_sid) + self.broadcast_leader_change(workflow_id, target_sid) + + if not target_sid: + return {"msg": "no_active_leader"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=target_sid, + ) + + return {"msg": "sync_request_forwarded"}, 200 + + self._socketio.emit( + "collaboration_update", + {"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp}, + room=workflow_id, + skip_sid=sid, + ) + + return {"msg": "event_broadcasted"}, 200 + + def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]: + mapping = self._repository.get_sid_mapping(sid) + if not mapping: + return {"msg": "unauthorized"}, 401 + + workflow_id = mapping["workflow_id"] + self.refresh_session_state(workflow_id, sid) + + self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid) + + return {"msg": 
"graph_update_broadcasted"}, 200 + + def get_or_set_leader(self, workflow_id: str, sid: str) -> str: + current_leader = self._repository.get_current_leader(workflow_id) + + if current_leader: + if self.is_session_active(workflow_id, current_leader): + return current_leader + self._repository.delete_session(workflow_id, current_leader) + self._repository.delete_leader(workflow_id) + + was_set = self._repository.set_leader_if_absent(workflow_id, sid) + + if was_set: + if current_leader: + self.broadcast_leader_change(workflow_id, sid) + return sid + + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader: + return current_leader + + return sid + + def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if not current_leader: + return + + if current_leader != disconnected_sid: + return + + new_leader_sid = self._select_graph_leader(workflow_id) + if new_leader_sid: + self._repository.set_leader(workflow_id, new_leader_sid) + self.broadcast_leader_change(workflow_id, new_leader_sid) + else: + self._repository.delete_leader(workflow_id) + + def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str | None) -> None: + for sid in self._repository.get_session_sids(workflow_id): + try: + is_leader = new_leader_sid is not None and sid == new_leader_sid + self._socketio.emit("status", {"isLeader": is_leader}, room=sid) + except Exception: + logging.exception("Failed to emit leader status to session %s", sid) + + def get_current_leader(self, workflow_id: str) -> str | None: + return self._repository.get_current_leader(workflow_id) + + def _prune_inactive_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]: + """Remove inactive sessions from storage and return active sessions only.""" + sessions = self._repository.list_sessions(workflow_id) + if not sessions: + return [] + + active_sessions: list[WorkflowSessionInfo] = [] + stale_sids: list[str] = [] + for session in sessions: + sid = session["sid"] + if self.is_session_active(workflow_id, sid): + active_sessions.append(session) + else: + stale_sids.append(sid) + + for sid in stale_sids: + self._repository.delete_session(workflow_id, sid) + + return active_sessions + + def broadcast_online_users(self, workflow_id: str) -> None: + users = self._prune_inactive_sessions(workflow_id) + users.sort(key=lambda x: x.get("connected_at") or 0) + + leader_sid = self.get_current_leader(workflow_id) + previous_leader = leader_sid + active_sids = {user["sid"] for user in users} + if leader_sid and leader_sid not in active_sids: + self._repository.delete_leader(workflow_id) + leader_sid = None + + if not leader_sid and users: + leader_sid = self._select_graph_leader(workflow_id) + if leader_sid: + self._repository.set_leader(workflow_id, leader_sid) + + if leader_sid != previous_leader: + self.broadcast_leader_change(workflow_id, leader_sid) + + self._socketio.emit( + "online_users", + {"workflow_id": workflow_id, "users": users, "leader": leader_sid}, + room=workflow_id, + ) + + def refresh_session_state(self, workflow_id: str, sid: str) -> None: + self._repository.refresh_session_state(workflow_id, sid) + self._ensure_leader(workflow_id, sid) + + def _ensure_leader(self, workflow_id: str, sid: str) -> None: + current_leader = self._repository.get_current_leader(workflow_id) + if current_leader and self.is_session_active(workflow_id, current_leader): + self._repository.expire_leader(workflow_id) + return + + if 
current_leader: + self._repository.delete_leader(workflow_id) + + self._repository.set_leader(workflow_id, sid) + self.broadcast_leader_change(workflow_id, sid) + + def _select_graph_leader(self, workflow_id: str, preferred_sid: str | None = None) -> str | None: + session_sids = [ + session["sid"] + for session in self._repository.list_sessions(workflow_id) + if session.get("graph_active", True) and self.is_session_active(workflow_id, session["sid"]) + ] + if not session_sids: + return None + if preferred_sid and preferred_sid in session_sids: + return preferred_sid + return session_sids[0] + + def is_session_active(self, workflow_id: str, sid: str) -> bool: + if not sid: + return False + + try: + if not self._socketio.manager.is_connected(sid, "/"): + return False + except AttributeError: + return False + + if not self._repository.session_exists(workflow_id, sid): + return False + + if not self._repository.sid_mapping_exists(sid): + return False + + return True diff --git a/api/services/workflow_comment_service.py b/api/services/workflow_comment_service.py new file mode 100644 index 0000000000..ff47e4f253 --- /dev/null +++ b/api/services/workflow_comment_service.py @@ -0,0 +1,564 @@ +import logging +from collections.abc import Sequence + +from sqlalchemy import desc, select +from sqlalchemy.orm import Session, selectinload +from werkzeug.exceptions import Forbidden, NotFound + +from configs import dify_config +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now +from libs.helper import uuid_value +from models import App, TenantAccountJoin, WorkflowComment, WorkflowCommentMention, WorkflowCommentReply +from models.account import Account +from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task + +logger = logging.getLogger(__name__) + + +class WorkflowCommentService: + """Service for managing workflow comments.""" + + @staticmethod + def _validate_content(content: str) -> None: + if len(content.strip()) == 0: + raise ValueError("Comment content cannot be empty") + + if len(content) > 1000: + raise ValueError("Comment content cannot exceed 1000 characters") + + @staticmethod + def _filter_valid_mentioned_user_ids( + mentioned_user_ids: Sequence[str], *, session: Session, tenant_id: str + ) -> list[str]: + """Return deduplicated UUID user IDs that belong to the tenant, preserving input order.""" + unique_user_ids: list[str] = [] + seen: set[str] = set() + for user_id in mentioned_user_ids: + if not isinstance(user_id, str): + continue + if not uuid_value(user_id): + continue + if user_id in seen: + continue + seen.add(user_id) + unique_user_ids.append(user_id) + if not unique_user_ids: + return [] + + tenant_member_ids = { + str(account_id) + for account_id in session.scalars( + select(TenantAccountJoin.account_id).where( + TenantAccountJoin.tenant_id == tenant_id, + TenantAccountJoin.account_id.in_(unique_user_ids), + ) + ).all() + } + + return [user_id for user_id in unique_user_ids if user_id in tenant_member_ids] + + @staticmethod + def _format_comment_excerpt(content: str, max_length: int = 200) -> str: + """Trim comment content for email display.""" + trimmed = content.strip() + if len(trimmed) <= max_length: + return trimmed + if max_length <= 3: + return trimmed[:max_length] + return f"{trimmed[: max_length - 3].rstrip()}..." 
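The comment-service helpers above are pure functions, which makes their contracts easy to pin down in isolation. A minimal standalone sketch of the same trim-with-ellipsis behavior (the free-function name is illustrative; the service keeps it as a private static method):

```python
def format_excerpt(content: str, max_length: int = 200) -> str:
    """Trim text for display, reserving three characters for an ellipsis."""
    trimmed = content.strip()
    if len(trimmed) <= max_length:
        return trimmed
    if max_length <= 3:
        # No room for "...": fall back to a hard truncation.
        return trimmed[:max_length]
    return f"{trimmed[: max_length - 3].rstrip()}..."


assert format_excerpt("short") == "short"
assert format_excerpt("a" * 300) == "a" * 197 + "..."
assert format_excerpt("abcdef", max_length=3) == "abc"
```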
+ + @staticmethod + def _build_mention_email_payloads( + session: Session, + tenant_id: str, + app_id: str, + mentioner_id: str, + mentioned_user_ids: Sequence[str], + content: str, + ) -> list[dict[str, str]]: + """Prepare email payloads for mentioned users, including workflow app link.""" + if not mentioned_user_ids: + return [] + + candidate_user_ids = [user_id for user_id in mentioned_user_ids if user_id != mentioner_id] + if not candidate_user_ids: + return [] + + app_name_value = session.scalar(select(App.name).where(App.id == app_id, App.tenant_id == tenant_id)) + app_name = app_name_value if isinstance(app_name_value, str) and app_name_value else "Dify app" + commenter_name_value = session.scalar(select(Account.name).where(Account.id == mentioner_id)) + commenter_name = ( + commenter_name_value if isinstance(commenter_name_value, str) and commenter_name_value else "Dify user" + ) + comment_excerpt = WorkflowCommentService._format_comment_excerpt(content) + base_url = dify_config.CONSOLE_WEB_URL.rstrip("/") + app_url = f"{base_url}/app/{app_id}/workflow" + + accounts = session.scalars( + select(Account) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where(TenantAccountJoin.tenant_id == tenant_id, Account.id.in_(candidate_user_ids)) + ).all() + + payloads: list[dict[str, str]] = [] + for account in accounts: + email = account.email + if not isinstance(email, str) or not email: + continue + mentioned_name = account.name if isinstance(account.name, str) and account.name else email + language = ( + account.interface_language + if isinstance(account.interface_language, str) and account.interface_language + else "en-US" + ) + payloads.append( + { + "language": language, + "to": email, + "mentioned_name": mentioned_name, + "commenter_name": commenter_name, + "app_name": app_name, + "comment_content": comment_excerpt, + "app_url": app_url, + } + ) + return payloads + + @staticmethod + def _dispatch_mention_emails(payloads: Sequence[dict[str, str]]) -> None: + """Enqueue mention notification emails.""" + for payload in payloads: + send_workflow_comment_mention_email_task.delay(**payload) + + @staticmethod + def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]: + """Get all comments for a workflow.""" + with Session(db.engine) as session: + # Get all comments with eager loading + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id) + .order_by(desc(WorkflowComment.created_at)) + ) + + comments = session.scalars(stmt).all() + + # Batch preload all Account objects to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, comments) + + return comments + + @staticmethod + def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None: + """Batch preload Account objects for comments, replies, and mentions.""" + # Collect all user IDs + user_ids: set[str] = set() + for comment in comments: + user_ids.add(comment.created_by) + if comment.resolved_by: + user_ids.add(comment.resolved_by) + user_ids.update(reply.created_by for reply in comment.replies) + user_ids.update(mention.mentioned_user_id for mention in comment.mentions) + + if not user_ids: + return + + # Batch query all accounts + accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all() + account_map = {str(account.id): account for account in accounts} + + # Cache accounts on 
objects + for comment in comments: + comment.cache_created_by_account(account_map.get(comment.created_by)) + comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None) + for reply in comment.replies: + reply.cache_created_by_account(account_map.get(reply.created_by)) + for mention in comment.mentions: + mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id)) + + @staticmethod + def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment: + """Get a specific comment.""" + + def _get_comment(session: Session) -> WorkflowComment: + stmt = ( + select(WorkflowComment) + .options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions)) + .where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Preload accounts to avoid N+1 queries + WorkflowCommentService._preload_accounts(session, [comment]) + + return comment + + if session is not None: + return _get_comment(session) + else: + with Session(db.engine, expire_on_commit=False) as session: + return _get_comment(session) + + @staticmethod + def create_comment( + tenant_id: str, + app_id: str, + created_by: str, + content: str, + position_x: float, + position_y: float, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Create a new workflow comment and send mention notification emails.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine) as session: + comment = WorkflowComment( + tenant_id=tenant_id, + app_id=app_id, + position_x=position_x, + position_y=position_y, + content=content, + created_by=created_by, + ) + + session.add(comment) + session.flush() # Get the comment ID for mentions + + # Create mentions if specified + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids or [], + session=session, + tenant_id=tenant_id, + ) + for user_id in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention, not reply mention + mentioned_user_id=user_id, + ) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=tenant_id, + app_id=app_id, + mentioner_id=created_by, + mentioned_user_ids=mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + # Return only what we need - id and created_at + return {"id": comment.id, "created_at": comment.created_at} + + @staticmethod + def update_comment( + tenant_id: str, + app_id: str, + comment_id: str, + user_id: str, + content: str, + position_x: float | None = None, + position_y: float | None = None, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Update a workflow comment and notify newly mentioned users. + + `mentioned_user_ids=None` means "leave mentions unchanged". + Passing an explicit list replaces the existing comment mentions, including clearing them with `[]`. 
+ """ + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Get comment with validation + stmt = select(WorkflowComment).where( + WorkflowComment.id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + comment = session.scalar(stmt) + + if not comment: + raise NotFound("Comment not found") + + # Only the creator can update the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can update it") + + # Update comment fields + comment.content = content + if position_x is not None: + comment.position_x = position_x + if position_y is not None: + comment.position_y = position_y + + mention_email_payloads: list[dict[str, str]] = [] + if mentioned_user_ids is not None: + # Replace comment mentions only when the client explicitly sends the mention list. + existing_mentions = session.scalars( + select(WorkflowCommentMention).where( + WorkflowCommentMention.comment_id == comment.id, + WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions + ) + ).all() + existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions} + for mention in existing_mentions: + session.delete(mention) + + filtered_mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids, + session=session, + tenant_id=tenant_id, + ) + new_mentioned_user_ids = [ + mentioned_user_id + for mentioned_user_id in filtered_mentioned_user_ids + if mentioned_user_id not in existing_mentioned_user_ids + ] + for mentioned_user_id in filtered_mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=comment.id, + reply_id=None, # This is a comment mention + mentioned_user_id=mentioned_user_id, + ) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=tenant_id, + app_id=app_id, + mentioner_id=user_id, + mentioned_user_ids=new_mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": comment.id, "updated_at": comment.updated_at} + + @staticmethod + def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None: + """Delete a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + + # Only the creator can delete the comment + if comment.created_by != user_id: + raise Forbidden("Only the comment creator can delete it") + + # Delete associated mentions (both comment and reply mentions) + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id) + ).all() + for mention in mentions: + session.delete(mention) + + # Delete associated replies + replies = session.scalars( + select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id) + ).all() + for reply in replies: + session.delete(reply) + + session.delete(comment) + session.commit() + + @staticmethod + def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment: + """Resolve a workflow comment.""" + with Session(db.engine, expire_on_commit=False) as session: + comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session) + if comment.resolved: + return comment + + comment.resolved = True + 
comment.resolved_at = naive_utc_now() + comment.resolved_by = user_id + session.commit() + + return comment + + @staticmethod + def create_reply( + comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None + ) -> dict: + """Add a reply to a workflow comment and notify mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + # Check if comment exists + comment = session.get(WorkflowComment, comment_id) + if not comment: + raise NotFound("Comment not found") + + reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by) + + session.add(reply) + session.flush() # Get the reply ID for mentions + + # Create mentions if specified + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( + mentioned_user_ids or [], + session=session, + tenant_id=comment.tenant_id, + ) + for user_id in mentioned_user_ids: + # Create mention linking to specific reply + mention = WorkflowCommentMention(comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id) + session.add(mention) + + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=comment.tenant_id, + app_id=comment.app_id, + mentioner_id=created_by, + mentioned_user_ids=mentioned_user_ids, + content=content, + ) + + session.commit() + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": reply.id, "created_at": reply.created_at} + + @staticmethod + def _get_reply_in_comment_scope( + *, + session: Session, + tenant_id: str, + app_id: str, + comment_id: str, + reply_id: str, + ) -> WorkflowCommentReply: + """Get a reply scoped to tenant/app/comment to prevent cross-thread mutations.""" + stmt = ( + select(WorkflowCommentReply) + .join(WorkflowComment, WorkflowComment.id == WorkflowCommentReply.comment_id) + .where( + WorkflowCommentReply.id == reply_id, + WorkflowCommentReply.comment_id == comment_id, + WorkflowComment.tenant_id == tenant_id, + WorkflowComment.app_id == app_id, + ) + .limit(1) + ) + reply = session.scalar(stmt) + if not reply: + raise NotFound("Reply not found") + return reply + + @staticmethod + def update_reply( + tenant_id: str, + app_id: str, + comment_id: str, + reply_id: str, + user_id: str, + content: str, + mentioned_user_ids: list[str] | None = None, + ) -> dict: + """Update a comment reply and notify newly mentioned users.""" + WorkflowCommentService._validate_content(content) + + with Session(db.engine, expire_on_commit=False) as session: + reply = WorkflowCommentService._get_reply_in_comment_scope( + session=session, + tenant_id=tenant_id, + app_id=app_id, + comment_id=comment_id, + reply_id=reply_id, + ) + + # Only the creator can update the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can update it") + + reply.content = content + + # Update mentions - first remove existing mentions for this reply + existing_mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id) + ).all() + existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions} + for mention in existing_mentions: + session.delete(mention) + + # Add mentions + raw_mentioned_user_ids = mentioned_user_ids or [] + comment = session.get(WorkflowComment, reply.comment_id) + mentioned_user_ids = [] + if comment: + mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids( 
+ raw_mentioned_user_ids, + session=session, + tenant_id=comment.tenant_id, + ) + new_mentioned_user_ids = [ + user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids + ] + for user_id_str in mentioned_user_ids: + mention = WorkflowCommentMention( + comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str + ) + session.add(mention) + + mention_email_payloads: list[dict[str, str]] = [] + if comment: + mention_email_payloads = WorkflowCommentService._build_mention_email_payloads( + session=session, + tenant_id=comment.tenant_id, + app_id=comment.app_id, + mentioner_id=user_id, + mentioned_user_ids=new_mentioned_user_ids, + content=content, + ) + + session.commit() + session.refresh(reply) # Refresh to get updated timestamp + WorkflowCommentService._dispatch_mention_emails(mention_email_payloads) + + return {"id": reply.id, "updated_at": reply.updated_at} + + @staticmethod + def delete_reply(tenant_id: str, app_id: str, comment_id: str, reply_id: str, user_id: str) -> None: + """Delete a comment reply.""" + with Session(db.engine, expire_on_commit=False) as session: + reply = WorkflowCommentService._get_reply_in_comment_scope( + session=session, + tenant_id=tenant_id, + app_id=app_id, + comment_id=comment_id, + reply_id=reply_id, + ) + + # Only the creator can delete the reply + if reply.created_by != user_id: + raise Forbidden("Only the reply creator can delete it") + + # Delete associated mentions first + mentions = session.scalars( + select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id) + ).all() + for mention in mentions: + session.delete(mention) + + session.delete(reply) + session.commit() + + @staticmethod + def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment: + """Validate that a comment belongs to the specified tenant and app.""" + return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 1c1b94ae9d..a55448e352 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -3,23 +3,11 @@ import json import logging from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from enum import StrEnum -from typing import Any, ClassVar +from typing import Any, ClassVar, NotRequired, TypedDict -from graphon.enums import NodeType -from graphon.file import File -from graphon.nodes import BuiltinNodeTypes -from graphon.nodes.variable_assigner.common.helpers import get_updated_variables -from graphon.variable_loader import VariableLoader -from graphon.variables import Segment, StringSegment, VariableBase -from graphon.variables.consts import SELECTORS_LENGTH -from graphon.variables.segments import ( - ArrayFileSegment, - FileSegment, -) -from graphon.variables.types import SegmentType -from graphon.variables.utils import dumps_with_segments -from sqlalchemy import Engine, orm, select +from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.orm import Session, sessionmaker @@ -39,6 +27,19 @@ from core.workflow.variable_prefixes import ( from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, 
segment_to_variable +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation @@ -145,7 +146,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -179,7 +180,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -190,7 +191,7 @@ class DraftVarLoader(VariableLoader): variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), - id=draft_var.id, + variable_id=draft_var.id, name=draft_var.name, description=draft_var.description, ) @@ -222,11 +223,10 @@ class WorkflowDraftVariableService: ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: - return ( - self._session.query(WorkflowDraftVariable) + return self._session.scalar( + select(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where(WorkflowDraftVariable.id == variable_id) - .first() ) def get_draft_variables_by_selectors( @@ -254,20 +254,21 @@ class WorkflowDraftVariableService: # Alternatively, a `SELECT` statement could be constructed for each selector and # combined using `UNION` to fetch all rows. # Benchmarking indicates that both approaches yield comparable performance. 
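The hunks that follow apply one mechanical migration over and over: legacy `Session.query(...)` chains become 2.0-style `select()` statements executed through `scalar()`/`scalars()`. A runnable sketch of the before/after shape, assuming SQLAlchemy 2.x and using a throwaway model in place of `WorkflowDraftVariable`:

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Draft(Base):  # throwaway stand-in for WorkflowDraftVariable
    __tablename__ = "drafts"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    app_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Draft(id="d1", app_id="a1"), Draft(id="d2", app_id="a1")])
    session.commit()

    # Legacy 1.x style (what these hunks remove):
    #   session.query(Draft).where(Draft.id == "d1").first()
    #   session.query(Draft).where(Draft.app_id == "a1").all()

    # 2.0 style (what these hunks add):
    one = session.scalar(select(Draft).where(Draft.id == "d1"))
    many = list(session.scalars(select(Draft).where(Draft.app_id == "a1")))
    assert one is not None and len(many) == 2
```

`scalar()`/`scalars()` return ORM objects directly (the first column of each row), which is why the converted code can drop the trailing `.first()`/`.all()` calls.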
- query = ( - self._session.query(WorkflowDraftVariable) - .options( - orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( - WorkflowDraftVariableFile.upload_file + return list( + self._session.scalars( + select(WorkflowDraftVariable) + .options( + orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( + WorkflowDraftVariableFile.upload_file + ) + ) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.user_id == user_id, + or_(*ors), ) ) - .where( - WorkflowDraftVariable.app_id == app_id, - WorkflowDraftVariable.user_id == user_id, - or_(*ors), - ) ) - return query.all() def list_variables_without_values( self, app_id: str, page: int, limit: int, user_id: str @@ -277,18 +278,21 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.user_id == user_id, ] total = None - query = self._session.query(WorkflowDraftVariable).where(*criteria) + base_stmt = select(WorkflowDraftVariable).where(*criteria) if page == 1: - total = query.count() - variables = ( - # Do not load the `value` field - query.options( - orm.defer(WorkflowDraftVariable.value, raiseload=True), + from sqlalchemy import func as sa_func + + total = self._session.scalar(select(sa_func.count()).select_from(base_stmt.subquery())) + variables = list( + self._session.scalars( + # Do not load the `value` field + base_stmt.options( + orm.defer(WorkflowDraftVariable.value, raiseload=True), + ) + .order_by(WorkflowDraftVariable.created_at.desc()) + .limit(limit) + .offset((page - 1) * limit) ) - .order_by(WorkflowDraftVariable.created_at.desc()) - .limit(limit) - .offset((page - 1) * limit) - .all() ) return WorkflowDraftVariableList(variables=variables, total=total) @@ -299,11 +303,13 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.user_id == user_id, ] - query = self._session.query(WorkflowDraftVariable).where(*criteria) - variables = ( - query.options(orm.selectinload(WorkflowDraftVariable.variable_file)) - .order_by(WorkflowDraftVariable.created_at.desc()) - .all() + variables = list( + self._session.scalars( + select(WorkflowDraftVariable) + .options(orm.selectinload(WorkflowDraftVariable.variable_file)) + .where(*criteria) + .order_by(WorkflowDraftVariable.created_at.desc()) + ) ) return WorkflowDraftVariableList(variables=variables) @@ -326,8 +332,8 @@ class WorkflowDraftVariableService: return self._get_variable(app_id, node_id, name, user_id=user_id) def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None: - return ( - self._session.query(WorkflowDraftVariable) + return self._session.scalar( + select(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where( WorkflowDraftVariable.app_id == app_id, @@ -335,7 +341,6 @@ class WorkflowDraftVariableService: WorkflowDraftVariable.name == name, WorkflowDraftVariable.user_id == user_id, ) - .first() ) def update_variable( @@ -488,20 +493,20 @@ class WorkflowDraftVariableService: self._session.delete(variable) def delete_user_workflow_variables(self, app_id: str, user_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.user_id == user_id, ) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def delete_app_workflow_variables(self, app_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + 
delete(WorkflowDraftVariable) .where(WorkflowDraftVariable.app_id == app_id) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def delete_workflow_draft_variable_file(self, deletions: list[DraftVarFileDeletion]): @@ -540,14 +545,14 @@ class WorkflowDraftVariableService: return self._delete_node_variables(app_id, node_id, user_id=user_id) def _delete_node_variables(self, app_id: str, node_id: str, user_id: str): - ( - self._session.query(WorkflowDraftVariable) + self._session.execute( + delete(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.user_id == user_id, ) - .delete(synchronize_session=False) + .execution_options(synchronize_session=False) ) def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None: @@ -588,13 +593,11 @@ class WorkflowDraftVariableService: conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id) if conv_id is not None: - conversation = ( - self._session.query(Conversation) - .where( + conversation = self._session.scalar( + select(Conversation).where( Conversation.id == conv_id, Conversation.app_id == workflow.app_id, ) - .first() ) # Only return the conversation ID if it exists and is valid (has a correspond conversation record in DB). if conversation is not None: @@ -723,8 +726,27 @@ def _batch_upsert_draft_variable( session.execute(stmt) -def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: - d: dict[str, Any] = { +class _InsertionDict(TypedDict): + id: str + app_id: str + user_id: str | None + last_edited_at: datetime | None + node_id: str + name: str + selector: str + value_type: SegmentType + value: str + node_execution_id: str | None + file_id: str | None + visible: NotRequired[bool] + editable: NotRequired[bool] + created_at: NotRequired[datetime] + updated_at: NotRequired[datetime] + description: NotRequired[str] + + +def _model_to_insertion_dict(model: WorkflowDraftVariable) -> _InsertionDict: + d: _InsertionDict = { "id": model.id, "app_id": model.app_id, "user_id": model.user_id, @@ -1045,7 +1067,7 @@ class DraftVariableSaver: filename = f"{self._generate_filename(name)}.txt" else: # For other types, store as JSON - original_content_serialized = dumps_with_segments(value_seg.value, ensure_ascii=False) + original_content_serialized = dumps_with_segments(value_seg.value) content_type = "application/json" filename = f"{self._generate_filename(name)}.json" @@ -1061,10 +1083,9 @@ class DraftVariableSaver: mimetype=content_type, user=self._user, ) - + assert self._user.current_tenant_id # Create WorkflowDraftVariableFile record variable_file = WorkflowDraftVariableFile( - id=uuidv7(), upload_file_id=upload_file.id, size=original_size, length=original_length, @@ -1073,6 +1094,7 @@ class DraftVariableSaver: tenant_id=self._user.current_tenant_id, user_id=self._user.id, ) + variable_file.id = str(uuidv7()) engine = bind = self._session.get_bind() assert isinstance(engine, Engine) with sessionmaker(bind=engine, expire_on_commit=False).begin() as session: diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 601e9261fc..94f88f8c49 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,15 +9,12 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any -from graphon.entities 
import WorkflowStartReason -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker from core.app.apps.message_generator import MessageGenerator from core.app.entities.task_entities import ( + HumanInputRequiredResponse, MessageReplaceStreamResponse, NodeFinishStreamResponse, NodeStartStreamResponse, @@ -26,6 +23,14 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext +from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter +from models.human_input import HumanInputForm from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot @@ -59,8 +64,10 @@ def build_workflow_event_stream( tenant_id: str, app_id: str, session_maker: sessionmaker[Session], + human_input_surface: HumanInputSurface | None = None, idle_timeout: float = 300, ping_interval: float = 10.0, + close_on_pause: bool = True, ) -> Generator[Mapping[str, Any] | str, None, None]: topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id) workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) @@ -115,13 +122,15 @@ def build_workflow_event_stream( message_context=message_context, pause_entity=pause_entity, resumption_context=resumption_context, + session_maker=session_maker, + human_input_surface=human_input_surface, ) for event in snapshot_events: last_msg_time = time.time() last_ping_time = last_msg_time yield event - if _is_terminal_event(event, include_paused=True): + if _is_terminal_event(event, close_on_pause=close_on_pause): return while True: @@ -146,7 +155,7 @@ def build_workflow_event_stream( last_msg_time = time.time() last_ping_time = last_msg_time yield event - if _is_terminal_event(event, include_paused=True): + if _is_terminal_event(event, close_on_pause=close_on_pause): return finally: buffer_state.stop_event.set() @@ -207,6 +216,8 @@ def _build_snapshot_events( message_context: MessageContext | None, pause_entity: WorkflowPauseEntity | None, resumption_context: WorkflowResumptionContext | None, + session_maker: sessionmaker[Session] | None = None, + human_input_surface: HumanInputSurface | None = None, ) -> list[Mapping[str, Any]]: events: list[Mapping[str, Any]] = [] @@ -241,12 +252,24 @@ def _build_snapshot_events( events.append(node_finished) if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None: + for human_input_event in _build_human_input_required_events( + workflow_run_id=workflow_run.id, + task_id=task_id, + pause_entity=pause_entity, + session_maker=session_maker, + human_input_surface=human_input_surface, + ): + _apply_message_context(human_input_event, message_context) + events.append(human_input_event) + pause_event = 
_build_pause_event( workflow_run=workflow_run, workflow_run_id=workflow_run.id, task_id=task_id, pause_entity=pause_entity, resumption_context=resumption_context, + session_maker=session_maker, + human_input_surface=human_input_surface, ) if pause_event is not None: _apply_message_context(pause_event, message_context) @@ -314,6 +337,97 @@ def _build_node_started_event( return response.to_ignore_detail_dict() +def _build_human_input_required_events( + *, + workflow_run_id: str, + task_id: str, + pause_entity: WorkflowPauseEntity, + session_maker: sessionmaker[Session] | None, + human_input_surface: HumanInputSurface | None, +) -> list[dict[str, Any]]: + reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] + human_input_form_ids = [ + form_id + for reason in reasons + if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED + for form_id in [reason.get("form_id")] + if isinstance(form_id, str) + ] + + expiration_times_by_form_id: dict[str, int] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_tokens_by_form_id: dict[str, str] = {} + if human_input_form_ids and session_maker is not None: + stmt = select(HumanInputForm.id, HumanInputForm.expiration_time, HumanInputForm.form_definition).where( + HumanInputForm.id.in_(human_input_form_ids) + ) + with session_maker() as session: + for form_id, expiration_time, form_definition in session.execute(stmt): + expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp()) + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except (TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + form_tokens_by_form_id = load_form_tokens_by_form_id( + human_input_form_ids, + session=session, + surface=human_input_surface, + ) + + events: list[dict[str, Any]] = [] + for reason in reasons: + if reason.get("TYPE") != PauseReasonType.HUMAN_INPUT_REQUIRED: + continue + + form_id_raw = reason.get("form_id") + node_id_raw = reason.get("node_id") + node_title_raw = reason.get("node_title") + form_content_raw = reason.get("form_content") + if not isinstance(form_id_raw, str): + continue + if not isinstance(node_id_raw, str): + continue + if not isinstance(node_title_raw, str): + continue + if not isinstance(form_content_raw, str): + continue + form_id = form_id_raw + node_id = node_id_raw + node_title = node_title_raw + form_content = form_content_raw + + inputs = reason.get("inputs") + actions = reason.get("actions") + resolved_default_values = reason.get("resolved_default_values") + + expiration_time = expiration_times_by_form_id.get(form_id) + if expiration_time is None: + continue + + response = HumanInputRequiredResponse( + task_id=task_id, + workflow_run_id=workflow_run_id, + data=HumanInputRequiredResponse.Data( + form_id=form_id, + node_id=node_id, + node_title=node_title, + form_content=form_content, + inputs=inputs if isinstance(inputs, list) else [], + actions=actions if isinstance(actions, list) else [], + display_in_ui=display_in_ui_by_form_id.get(form_id, False), + form_token=form_tokens_by_form_id.get(form_id), + resolved_default_values=(resolved_default_values if isinstance(resolved_default_values, dict) else {}), + expiration_time=expiration_time, + ), + ) + payload = response.model_dump(mode="json") + payload["event"] = response.event.value + events.append(payload) + + return events + + def _build_node_finished_event( *, workflow_run_id: str, @@ -356,6 +470,8 
@@ def _build_pause_event( task_id: str, pause_entity: WorkflowPauseEntity, resumption_context: WorkflowResumptionContext | None, + session_maker: sessionmaker[Session] | None, + human_input_surface: HumanInputSurface | None = None, ) -> dict[str, Any] | None: paused_nodes: list[str] = [] outputs: dict[str, Any] = {} @@ -365,6 +481,36 @@ def _build_pause_event( outputs = dict(WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {})) reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()] + human_input_form_ids = [ + form_id + for reason in reasons + if reason.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED + for form_id in [reason.get("form_id")] + if isinstance(form_id, str) + ] + form_tokens_by_form_id: dict[str, str] = {} + expiration_times_by_form_id: dict[str, int] = {} + if human_input_form_ids and session_maker is not None: + with session_maker() as session: + form_tokens_by_form_id = load_form_tokens_by_form_id( + human_input_form_ids, + session=session, + surface=human_input_surface, + ) + stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( + HumanInputForm.id.in_(human_input_form_ids) + ) + for row in session.execute(stmt): + form_id, expiration_time, *_rest = row + expiration_times_by_form_id[str(form_id)] = int(expiration_time.timestamp()) + # Reconnect paths must preserve the same pause-reason contract as live streams; + # otherwise clients see schema drift after resume. + reasons = enrich_human_input_pause_reasons( + reasons, + form_tokens_by_form_id=form_tokens_by_form_id, + expiration_times_by_form_id=expiration_times_by_form_id, + ) + response = WorkflowPauseStreamResponse( task_id=task_id, workflow_run_id=workflow_run_id, @@ -449,12 +595,19 @@ def _parse_event_message(message: bytes) -> Mapping[str, Any] | None: return event -def _is_terminal_event(event: Mapping[str, Any] | str, include_paused=False) -> bool: +def _is_terminal_event( + event: Mapping[str, Any] | str, + close_on_pause: bool = True, + *, + include_paused: bool | None = None, +) -> bool: + if include_paused is not None: + close_on_pause = include_paused if not isinstance(event, Mapping): return False event_type = event.get("event") if event_type == StreamEvent.WORKFLOW_FINISHED.value: return True - if include_paused: + if close_on_pause: return event_type == StreamEvent.WORKFLOW_PAUSED.value return False diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index b903d8df5f..29b9e72a00 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,5 +1,6 @@ import threading from collections.abc import Sequence +from typing import TypedDict from sqlalchemy import Engine from sqlalchemy.orm import sessionmaker @@ -19,6 +20,14 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +class WorkflowRunListArgs(TypedDict, total=False): + """Expected shape of the args dict passed to workflow run pagination methods.""" + + limit: int + last_id: str + status: str + + class WorkflowRunService: _session_factory: sessionmaker _workflow_run_repo: APIWorkflowRunRepository @@ -37,7 +46,10 @@ class WorkflowRunService: self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_factory) def get_paginate_advanced_chat_workflow_runs( - self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING + self, + app_model: 
App, + args: WorkflowRunListArgs, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, ) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list @@ -73,7 +85,10 @@ class WorkflowRunService: return pagination def get_paginate_workflow_runs( - self, app_model: App, args: dict, triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING + self, + app_model: App, + args: WorkflowRunListArgs, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, ) -> InfiniteScrollPagination: """ Get workflow run list diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index c28704e83b..f97b85dc2b 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,6 +5,40 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast +from sqlalchemy import exists, select +from sqlalchemy.orm import Session, sessionmaker + +from configs import dify_config +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.entities import PluginCredentialType +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager +from core.repositories import DifyCoreRepositoryFactory +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.trigger.constants import is_trigger_node_type +from core.workflow.human_input_adapter import ( + DeliveryChannelConfig, + adapt_human_input_node_data_for_graph, + parse_human_input_delivery_methods, +) +from core.workflow.node_factory import ( + LATEST_VERSION, + DifyGraphInitContext, + get_node_type_classes_mapping, + is_start_node_type, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace +from enums.cloud_plan import CloudPlan +from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated +from extensions.ext_database import db +from extensions.ext_storage import storage +from factories.file_factory import build_from_mapping, build_from_mappings from graphon.entities import WorkflowNodeExecution from graphon.entities.graph_config import NodeConfigDict from graphon.entities.pause_reason import HumanInputRequired @@ -30,40 +64,6 @@ from graphon.variable_loader import load_into_variable_pool from graphon.variables import VariableBase from graphon.variables.input_entities import VariableEntityType from graphon.variables.variables import Variable -from sqlalchemy import exists, select -from sqlalchemy.orm import Session, sessionmaker - -from configs import dify_config -from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager -from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom, 
UserFrom, build_dify_run_context -from core.app.file_access import DatabaseFileAccessController -from core.entities import PluginCredentialType -from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager -from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.trigger.constants import is_trigger_node_type -from core.workflow.human_input_compat import ( - DeliveryChannelConfig, - normalize_human_input_node_data_for_graph, - parse_human_input_delivery_methods, -) -from core.workflow.node_factory import ( - LATEST_VERSION, - DifyGraphInitContext, - get_node_type_classes_mapping, - is_start_node_type, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool -from core.workflow.workflow_entry import WorkflowEntry -from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace -from enums.cloud_plan import CloudPlan -from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated -from extensions.ext_database import db -from extensions.ext_storage import storage -from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account from models.human_input import HumanInputFormRecipient, RecipientType @@ -156,11 +156,18 @@ class WorkflowService: # return draft workflow return workflow - def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None: + def get_published_workflow_by_id( + self, app_model: App, workflow_id: str, session: Session | None = None + ) -> Workflow | None: """ fetch published workflow by workflow_id + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, @@ -178,16 +185,20 @@ class WorkflowService: ) return workflow - def get_published_workflow(self, app_model: App) -> Workflow | None: + def get_published_workflow(self, app_model: App, session: Session | None = None) -> Workflow | None: """ Get published workflow + + When ``session`` is provided, reuse it so callers that already hold a + Session avoid checking out an extra request-scoped ``db.session`` + connection. Falls back to ``db.session`` for backward compatibility. """ if not app_model.workflow_id: return None - # fetch published workflow by workflow_id - workflow = db.session.scalar( + bind = session if session is not None else db.session + workflow = bind.scalar( select(Workflow) .where( Workflow.tenant_id == app_model.tenant_id, @@ -199,6 +210,16 @@ class WorkflowService: return workflow + def get_accessible_app_ids(self, app_ids: Sequence[str], tenant_id: str) -> set[str]: + """ + Return app IDs that belong to the given tenant. 
+ """ + if not app_ids: + return set() + + stmt = select(App.id).where(App.id.in_(app_ids), App.tenant_id == tenant_id) + return {str(app_id) for app_id in db.session.scalars(stmt).all()} + def get_all_published_workflow( self, *, @@ -241,8 +262,8 @@ class WorkflowService: self, *, app_model: App, - graph: dict, - features: dict, + graph: dict[str, Any], + features: dict[str, Any], unique_hash: str | None, account: Account, environment_variables: Sequence[VariableBase], @@ -296,6 +317,78 @@ class WorkflowService: # return draft workflow return workflow + def update_draft_workflow_environment_variables( + self, + *, + app_model: App, + environment_variables: Sequence[VariableBase], + account: Account, + ): + """ + Update draft workflow environment variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.environment_variables = environment_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_conversation_variables( + self, + *, + app_model: App, + conversation_variables: Sequence[VariableBase], + account: Account, + ): + """ + Update draft workflow conversation variables + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + workflow.conversation_variables = conversation_variables + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + + def update_draft_workflow_features( + self, + *, + app_model: App, + features: dict, + account: Account, + ): + """ + Update draft workflow features + """ + # fetch draft workflow by app_model + workflow = self.get_draft_workflow(app_model=app_model) + + if not workflow: + raise ValueError("No draft workflow found.") + + # validate features structure + self.validate_features_structure(app_model=app_model, features=features) + + workflow.features = json.dumps(features) + workflow.updated_by = account.id + workflow.updated_at = naive_utc_now() + + # commit db session changes + db.session.commit() + def restore_published_workflow_to_draft( self, *, @@ -576,7 +669,7 @@ class WorkflowService: except Exception as e: raise ValueError(f"Failed to validate default credential for tool provider {provider}: {str(e)}") - def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict, node_id: str) -> None: + def _validate_load_balancing_credentials(self, workflow: Workflow, node_data: dict[str, Any], node_id: str) -> None: """ Validate load balancing credentials for a workflow node. @@ -709,7 +802,7 @@ class WorkflowService: :param filters: filter by node config parameters. 
:return: """ - node_type_enum = NodeType(node_type) + node_type_enum: NodeType = node_type node_mapping = get_node_type_classes_mapping() # return default block config @@ -1014,7 +1107,7 @@ class WorkflowService: raise ValueError("Node type must be human-input.") node_data = HumanInputNodeData.model_validate( - normalize_human_input_node_data_for_graph(node_config["data"]), + adapt_human_input_node_data_for_graph(node_config["data"]), from_attributes=True, ) delivery_method = self._resolve_human_input_delivery_method( @@ -1155,9 +1248,10 @@ class WorkflowService: variable_pool=variable_pool, start_at=time.perf_counter(), ) + node_data = HumanInputNode.validate_node_data(adapt_human_input_node_data_for_graph(node_config["data"])) node = HumanInputNode( - id=node_config["id"], - config=node_config, + node_id=node_config["id"], + config=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=DifyHumanInputNodeRuntime(run_context), @@ -1214,7 +1308,7 @@ class WorkflowService: return variable_pool def run_free_workflow_node( - self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] + self, node_data: dict[str, Any], tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] ) -> WorkflowNodeExecution: """ Run free workflow node @@ -1361,7 +1455,7 @@ class WorkflowService: node_execution.status = WorkflowNodeExecutionStatus.FAILED node_execution.error = error - def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: + def convert_to_workflow(self, app_model: App, account: Account, args: dict[str, Any]) -> App: """ Basic mode of chatbot app(expert mode) to workflow Completion App to Workflow App @@ -1421,7 +1515,7 @@ class WorkflowService: if node_type == BuiltinNodeTypes.HUMAN_INPUT: self._validate_human_input_node_data(node_data) - def validate_features_structure(self, app_model: App, features: dict): + def validate_features_structure(self, app_model: App, features: dict[str, Any]): match app_model.mode: case AppMode.ADVANCED_CHAT: return AdvancedChatAppConfigManager.config_validate( @@ -1434,7 +1528,7 @@ class WorkflowService: case _: raise ValueError(f"Invalid app mode: {app_model.mode}") - def _validate_human_input_node_data(self, node_data: dict) -> None: + def _validate_human_input_node_data(self, node_data: dict[str, Any]) -> None: """ Validate HumanInput node data format. 
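A recurring theme across these service hunks is tightening bare `dict` annotations into `dict[str, Any]` or a `TypedDict` (`_InsertionDict` and `WorkflowRunListArgs` above, `WorkflowGeneratorArgsDict` further below). A small sketch of the two optionality spellings those TypedDicts rely on; the class and key names here are illustrative, and importing `NotRequired` from `typing` assumes Python 3.11+ (older versions take it from `typing_extensions`):

```python
from typing import Any, NotRequired, TypedDict


class RunListArgs(TypedDict, total=False):
    # total=False: every key is optional to the type checker.
    limit: int
    last_id: str
    status: str


class GeneratorArgs(TypedDict):
    # total=True (the default): keys are required unless marked NotRequired.
    inputs: dict[str, Any]
    files: list[Any]
    workflow_id: NotRequired[str]


args: RunListArgs = {"limit": 20}                 # ok: all keys optional
gen: GeneratorArgs = {"inputs": {}, "files": []}  # ok: workflow_id may be omitted
gen["workflow_id"] = "wf-1"                       # ok: still a GeneratorArgs
```

Neither form is enforced at runtime; the payoff is that type checkers can flag misspelled or missing keys at call sites that previously passed an opaque `dict`.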
@@ -1447,12 +1541,12 @@ class WorkflowService: from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) + HumanInputNodeData.model_validate(adapt_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") def update_workflow( - self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict + self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict[str, Any] ) -> Workflow | None: """ Update workflow attributes @@ -1512,14 +1606,12 @@ class WorkflowService: # Don't use workflow.tool_published as it's not accurate for specific workflow versions # Check if there's a tool provider using this specific workflow version - tool_provider = ( - session.query(WorkflowToolProvider) - .where( + tool_provider = session.scalar( + select(WorkflowToolProvider).where( WorkflowToolProvider.tenant_id == workflow.tenant_id, WorkflowToolProvider.app_id == workflow.app_id, WorkflowToolProvider.version == workflow.version, ) - .first() ) if tool_provider: diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 8f2f5f261e..5ceeb302c8 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,7 +7,6 @@ from typing import Annotated, Any from celery import shared_task from flask import current_app, json -from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -23,6 +22,7 @@ from core.app.entities.app_invoke_entities import ( from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from libs.flask_utils import set_login_user from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -399,6 +399,8 @@ def _resume_advanced_chat( workflow_run_id: str, workflow_run: WorkflowRun, ) -> None: + resumed_generate_entity = generate_entity.model_copy(update={"stream": True}) + try: triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) except ValueError: @@ -426,7 +428,7 @@ def _resume_advanced_chat( user=user, conversation=conversation, message=message, - application_generate_entity=generate_entity, + application_generate_entity=resumed_generate_entity, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, graph_runtime_state=graph_runtime_state, @@ -436,9 +438,8 @@ def _resume_advanced_chat( logger.exception("Failed to resume chatflow execution for workflow run %s", workflow_run_id) raise - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) + assert isinstance(response, Generator) + _publish_streaming_response(response, workflow_run_id, AppMode.ADVANCED_CHAT) def _resume_workflow( @@ -455,6 +456,8 @@ def _resume_workflow( workflow_run_repo, pause_entity, ) -> None: + resumed_generate_entity = generate_entity.model_copy(update={"stream": True}) + try: triggered_from = WorkflowRunTriggeredFrom(workflow_run.triggered_from) except 
ValueError: @@ -480,7 +483,7 @@ def _resume_workflow( app_model=app_model, workflow=workflow, user=user, - application_generate_entity=generate_entity, + application_generate_entity=resumed_generate_entity, graph_runtime_state=graph_runtime_state, workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, @@ -490,11 +493,18 @@ def _resume_workflow( logger.exception("Failed to resume workflow execution for workflow run %s", workflow_run_id) raise - if generate_entity.stream: - assert isinstance(response, Generator) - _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) + assert isinstance(response, Generator) + _publish_streaming_response(response, workflow_run_id, AppMode.WORKFLOW) - workflow_run_repo.delete_workflow_pause(pause_entity) + try: + workflow_run_repo.delete_workflow_pause(pause_entity) + except Exception as exc: + if exc.__class__.__name__ != "_WorkflowRunError" or "WorkflowPause not found" not in str(exc): + raise + logger.info( + "Skipped deleting workflow pause %s after resume because it was already replaced or removed", + pause_entity.id, + ) @shared_task(queue=WORKFLOW_BASED_APP_EXECUTION_QUEUE, name="resume_app_execution") diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index 0a73c91279..5809268992 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -7,15 +7,15 @@ with appropriate retry policies and error handling. import logging from datetime import UTC, datetime -from typing import Any +from typing import Any, NotRequired from celery import shared_task -from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker +from typing_extensions import TypedDict from configs import dify_config -from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.app.layers.timeslice_layer import TimeSliceLayer @@ -23,6 +23,7 @@ from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory from extensions.ext_database import db +from graphon.runtime import GraphRuntimeState from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus from models.model import App, EndUser, Tenant @@ -42,6 +43,13 @@ from tasks.workflow_cfs_scheduler.entities import AsyncWorkflowQueue, AsyncWorkf logger = logging.getLogger(__name__) +class WorkflowGeneratorArgsDict(TypedDict): + inputs: dict[str, Any] + files: list[Any] + _skip_prepare_user_inputs: bool + workflow_id: NotRequired[str] + + @shared_task(queue=AsyncWorkflowQueue.PROFESSIONAL_QUEUE) def execute_workflow_professional(task_data_dict: dict[str, Any]): """Execute workflow for professional tier with highest priority""" @@ -90,15 +98,13 @@ def execute_workflow_sandbox(task_data_dict: dict[str, Any]): ) -def _build_generator_args(trigger_data: TriggerData) -> dict[str, Any]: +def _build_generator_args(trigger_data: TriggerData) -> WorkflowGeneratorArgsDict: """Build args passed into WorkflowAppGenerator.generate for Celery executions.""" - - args: 
dict[str, Any] = { + return { "inputs": dict(trigger_data.inputs), "files": list(trigger_data.files), - SKIP_PREPARE_USER_INPUTS_KEY: True, + "_skip_prepare_user_inputs": True, } - return args def _execute_workflow_common( @@ -156,7 +162,12 @@ def _execute_workflow_common( state_owner_user_id=workflow.created_by, ) - # Execute the workflow with the trigger type + # NOTE (hj24) + # Release the transaction before the blocking generate() call, + # otherwise the connection stays "idle in transaction" for hours. + session.commit() + # NOTE END + generator.generate( app_model=app_model, workflow=workflow, diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 77feea47a2..beb23d8354 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -3,11 +3,11 @@ import tempfile import time import uuid from pathlib import Path +from typing import Any import click import pandas as pd from celery import shared_task -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func, select from core.db.session_factory import session_factory @@ -15,6 +15,7 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from graphon.model_runtime.entities.model_entities import ModelType from libs import helper from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -51,8 +52,8 @@ def batch_create_segment_to_index_task( # Initialize variables with default values upload_file_key: str | None = None - dataset_config: dict | None = None - document_config: dict | None = None + dataset_config: dict[str, Any] | None = None + document_config: dict[str, Any] | None = None with session_factory.create_session() as session: try: diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index a657cd553a..c8d0e31c06 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -61,13 +61,31 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i # check segment is exist if index_node_ids: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - with session_factory.create_session() as session: - dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) - if dataset: - index_processor.clean( - dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + # Wrap vector / keyword index cleanup in try/except so that a transient + # failure here (e.g. billing API hiccup propagated via FeatureService when + # ModelManager is initialized inside ``Vector(dataset)``) does not abort + # the entire task and leave document_segments / child_chunks / image_files + # / metadata bindings stranded in PG. Mirrors the pattern already used in + # ``clean_dataset_task`` so the document row's hard delete (already + # committed by the caller) does not produce orphan PG rows just because + # the vector backend or one of its transitive dependencies was unhappy. 
+ try: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + with session_factory.create_session() as session: + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) + if dataset: + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector / keyword index in clean_document_task, " + "document_id=%s, dataset_id=%s, index_node_ids_count=%d. " + "Continuing with PG / storage cleanup; vector orphans can be reaped later.", + document_id, + dataset_id, + len(index_node_ids), + ) total_image_files = [] with session_factory.create_session() as session, session.begin(): diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index e3be24ac74..017d60efac 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -40,12 +40,29 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() total_index_node_ids.extend([segment.index_node_id for segment in segments]) - with session_factory.create_session() as session: - dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) - if dataset: - index_processor.clean( - dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + # Wrap vector / keyword index cleanup in try/except so that a transient + # failure here (e.g. billing API hiccup propagated via FeatureService when + # ``ModelManager`` is initialized inside ``Vector(dataset)``) does not abort + # the task and leave the already-deleted documents' segments stranded in PG. + # The Document rows are hard-deleted in the previous session block, so any + # exception escaping this task would produce orphans that no later request + # can reference back. Mirrors the pattern in ``clean_dataset_task``. + try: + with session_factory.create_session() as session: + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) + if dataset: + index_processor.clean( + dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector / keyword index in clean_notion_document_task, " + "dataset_id=%s, document_ids=%s, index_node_ids_count=%d. 
" + "Continuing with segment deletion; vector orphans can be reaped later.", + dataset_id, + document_ids, + len(total_index_node_ids), + ) with session_factory.create_session() as session, session.begin(): segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 23a80fa106..31dad7937c 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -5,6 +5,7 @@ from typing import Any, Protocol import click from celery import current_app, shared_task +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory @@ -53,11 +54,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): Usage: _document_indexing(dataset_id, document_ids) """ - documents = [] start_at = time.perf_counter() with session_factory.create_session() as session: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow")) return @@ -79,8 +79,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): ) except Exception as e: for document_id in document_ids: - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) if document: document.indexing_status = IndexingStatus.ERROR @@ -92,8 +92,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # Phase 1: Update status to parsing (short transaction) with session_factory.create_session() as session, session.begin(): - documents = ( - session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all() + documents: list[Document] = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) for document in documents: @@ -122,7 +124,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # Trigger summary index generation for completed documents if enabled # Only generate for high_quality indexing technique and when summary_index_setting is enabled # Re-query dataset to get latest summary_index_setting (in case it was updated) - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.warning("Dataset %s not found after indexing", dataset_id) return @@ -134,10 +136,10 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.expire_all() # Check each document's indexing status and trigger summary generation if completed - documents = ( - session.query(Document) - .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) - .all() + documents = list( + session.scalars( + select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + ).all() ) for document in documents: diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index ca73b4d374..fd743205a1 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,8 +2,6 @@ import logging from datetime import 
timedelta from celery import shared_task -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker @@ -11,6 +9,8 @@ from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository from extensions.ext_database import db from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from models.human_input import HumanInputForm from models.workflow import WorkflowPause, WorkflowRun diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index a316eec7b9..2a60be7762 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,15 +6,15 @@ from typing import Any import click from celery import shared_task -from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail +from graphon.runtime import GraphRuntimeState, VariablePool from models.human_input import ( DeliveryMethodType, HumanInputDelivery, diff --git a/api/tasks/mail_workflow_comment_task.py b/api/tasks/mail_workflow_comment_task.py new file mode 100644 index 0000000000..36d51f0514 --- /dev/null +++ b/api/tasks/mail_workflow_comment_task.py @@ -0,0 +1,65 @@ +import logging +import time + +import click +from celery import shared_task + +from extensions.ext_mail import mail +from libs.email_i18n import EmailType, get_email_i18n_service + +logger = logging.getLogger(__name__) + + +@shared_task(queue="mail") +def send_workflow_comment_mention_email_task( + language: str, + to: str, + mentioned_name: str, + commenter_name: str, + app_name: str, + comment_content: str, + app_url: str, +): + """ + Send workflow comment mention email with internationalization support. 
+ + Args: + language: Language code for email localization + to: Recipient email address + mentioned_name: Name of the mentioned user + commenter_name: Name of the comment author + app_name: Name of the app where the comment was made + comment_content: Comment content excerpt + app_url: Link to the app workflow page + """ + if not mail.is_inited(): + return + + logger.info(click.style(f"Start workflow comment mention mail to {to}", fg="green")) + start_at = time.perf_counter() + + try: + email_service = get_email_i18n_service() + email_service.send_email( + email_type=EmailType.WORKFLOW_COMMENT_MENTION, + language_code=language, + to=to, + template_context={ + "to": to, + "mentioned_name": mentioned_name, + "commenter_name": commenter_name, + "app_name": app_name, + "comment_content": comment_content, + "app_url": app_url, + }, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Send workflow comment mention mail to {to} succeeded: latency: {end_at - start_at}", + fg="green", + ) + ) + except Exception: + logger.exception("workflow comment mention email to %s failed", to) diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 5d201bd801..48d1774ce3 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -11,6 +11,7 @@ from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller from extensions.ext_redis import redis_client from models.account import TenantPluginAutoUpgradeStrategy +from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) @@ -171,14 +172,13 @@ def process_tenant_plugin_autoupgrade_check_task( fg="green", ) ) - _ = manager.upgrade_plugin( + # Use the service that downloads and uploads the package to the daemon + # first; calling manager.upgrade_plugin directly skips that step and the + # daemon fails because the package never reaches its local bucket. 
+ _ = PluginService.upgrade_plugin_with_marketplace( tenant_id, original_unique_identifier, new_unique_identifier, - PluginInstallationSource.Marketplace, - { - "plugin_unique_identifier": new_unique_identifier, - }, ) except Exception as e: click.echo(click.style(f"Error when upgrading plugin: {e}", fg="red")) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index b1840662ff..5f1f0952af 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -6,7 +6,7 @@ from typing import Any, cast import click import sqlalchemy as sa from celery import shared_task -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.engine import CursorResult from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker @@ -99,7 +99,11 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): def _delete_app_model_configs(tenant_id: str, app_id: str): def del_model_config(session, model_config_id: str): - session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False) + session.execute( + delete(AppModelConfig) + .where(AppModelConfig.id == model_config_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from app_model_configs where app_id=:app_id limit 1000""", @@ -111,7 +115,7 @@ def _delete_app_model_configs(tenant_id: str, app_id: str): def _delete_app_site(tenant_id: str, app_id: str): def del_site(session, site_id: str): - session.query(Site).where(Site.id == site_id).delete(synchronize_session=False) + session.execute(delete(Site).where(Site.id == site_id).execution_options(synchronize_session=False)) _delete_records( """select id from sites where app_id=:app_id limit 1000""", @@ -123,7 +127,9 @@ def _delete_app_site(tenant_id: str, app_id: str): def _delete_app_mcp_servers(tenant_id: str, app_id: str): def del_mcp_server(session, mcp_server_id: str): - session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False) + session.execute( + delete(AppMCPServer).where(AppMCPServer.id == mcp_server_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from app_mcp_servers where app_id=:app_id limit 1000""", @@ -136,12 +142,14 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str): def _delete_app_api_tokens(tenant_id: str, app_id: str): def del_api_token(session, api_token_id: str): # Fetch token details for cache invalidation - token_obj = session.query(ApiToken).where(ApiToken.id == api_token_id).first() + token_obj = session.scalar(select(ApiToken).where(ApiToken.id == api_token_id).limit(1)) if token_obj: # Invalidate cache before deletion ApiTokenCache.delete(token_obj.token, token_obj.type) - session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False) + session.execute( + delete(ApiToken).where(ApiToken.id == api_token_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from api_tokens where app_id=:app_id limit 1000""", @@ -153,7 +161,9 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str): def _delete_installed_apps(tenant_id: str, app_id: str): def del_installed_app(session, installed_app_id: str): - session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False) + session.execute( + delete(InstalledApp).where(InstalledApp.id == 
installed_app_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -165,7 +175,11 @@ def _delete_installed_apps(tenant_id: str, app_id: str): def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(session, recommended_app_id: str): - session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False) + session.execute( + delete(RecommendedApp) + .where(RecommendedApp.id == recommended_app_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", @@ -177,8 +191,10 @@ def _delete_recommended_apps(tenant_id: str, app_id: str): def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(session, annotation_hit_history_id: str): - session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete( - synchronize_session=False + session.execute( + delete(AppAnnotationHitHistory) + .where(AppAnnotationHitHistory.id == annotation_hit_history_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -189,8 +205,10 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): ) def del_annotation_setting(session, annotation_setting_id: str): - session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete( - synchronize_session=False + session.execute( + delete(AppAnnotationSetting) + .where(AppAnnotationSetting.id == annotation_setting_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -203,7 +221,11 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str): def _delete_app_dataset_joins(tenant_id: str, app_id: str): def del_dataset_join(session, dataset_join_id: str): - session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False) + session.execute( + delete(AppDatasetJoin) + .where(AppDatasetJoin.id == dataset_join_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from app_dataset_joins where app_id=:app_id limit 1000""", @@ -215,7 +237,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str): def _delete_app_workflows(tenant_id: str, app_id: str): def del_workflow(session, workflow_id: str): - session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False) + session.execute(delete(Workflow).where(Workflow.id == workflow_id).execution_options(synchronize_session=False)) _delete_records( """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -255,7 +277,11 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(session, workflow_app_log_id: str): - session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowAppLog) + .where(WorkflowAppLog.id == workflow_app_log_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -267,8 +293,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): def del_workflow_archive_log(session, 
workflow_archive_log_id: str): - session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowArchiveLog) + .where(WorkflowArchiveLog.id == workflow_archive_log_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -306,10 +334,14 @@ def _delete_archived_workflow_run_files(tenant_id: str, app_id: str): def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(session, conversation_id: str): - session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False + session.execute( + delete(PinnedConversation) + .where(PinnedConversation.conversation_id == conversation_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(Conversation).where(Conversation.id == conversation_id).execution_options(synchronize_session=False) ) - session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", @@ -329,17 +361,35 @@ def _delete_conversation_variables(*, app_id: str): def _delete_app_messages(tenant_id: str, app_id: str): def del_message(session, message_id: str): - session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False) - session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete( - synchronize_session=False + session.execute( + delete(MessageFeedback) + .where(MessageFeedback.message_id == message_id) + .execution_options(synchronize_session=False) ) - session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False) - session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False + session.execute( + delete(MessageAnnotation) + .where(MessageAnnotation.message_id == message_id) + .execution_options(synchronize_session=False) ) - session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False) - session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False) - session.query(Message).where(Message.id == message_id).delete() + session.execute( + delete(MessageChain) + .where(MessageChain.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(MessageAgentThought) + .where(MessageAgentThought.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute( + delete(MessageFile).where(MessageFile.message_id == message_id).execution_options(synchronize_session=False) + ) + session.execute( + delete(SavedMessage) + .where(SavedMessage.message_id == message_id) + .execution_options(synchronize_session=False) + ) + session.execute(delete(Message).where(Message.id == message_id).execution_options(synchronize_session=False)) _delete_records( """select id from messages where app_id=:app_id limit 1000""", @@ -351,8 +401,10 @@ def _delete_app_messages(tenant_id: str, app_id: str): def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(session, tool_provider_id: str): - session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowToolProvider) + .where(WorkflowToolProvider.id == 
tool_provider_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -365,7 +417,9 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def _delete_app_tag_bindings(tenant_id: str, app_id: str): def del_tag_binding(session, tag_binding_id: str): - session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False) + session.execute( + delete(TagBinding).where(TagBinding.id == tag_binding_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", @@ -377,7 +431,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str): def _delete_end_users(tenant_id: str, app_id: str): def del_end_user(session, end_user_id: str): - session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False) + session.execute(delete(EndUser).where(EndUser.id == end_user_id).execution_options(synchronize_session=False)) _delete_records( """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -389,7 +443,11 @@ def _delete_end_users(tenant_id: str, app_id: str): def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(session, trace_app_config_id: str): - session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False) + session.execute( + delete(TraceAppConfig) + .where(TraceAppConfig.id == trace_app_config_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", @@ -545,7 +603,9 @@ def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int: def _delete_app_triggers(tenant_id: str, app_id: str): def del_app_trigger(session, trigger_id: str): - session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False) + session.execute( + delete(AppTrigger).where(AppTrigger.id == trigger_id).execution_options(synchronize_session=False) + ) _delete_records( """select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -557,8 +617,10 @@ def _delete_app_triggers(tenant_id: str, app_id: str): def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def del_plugin_trigger(session, trigger_id: str): - session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowPluginTrigger) + .where(WorkflowPluginTrigger.id == trigger_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -571,8 +633,10 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str): def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def del_webhook_trigger(session, trigger_id: str): - session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete( - synchronize_session=False + session.execute( + delete(WorkflowWebhookTrigger) + .where(WorkflowWebhookTrigger.id == trigger_id) + .execution_options(synchronize_session=False) ) _delete_records( @@ -585,7 +649,11 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str): def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def del_schedule_plan(session, plan_id: str): - session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowSchedulePlan) + 
.where(WorkflowSchedulePlan.id == plan_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -597,7 +665,11 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str): def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): def del_trigger_log(session, log_id: str): - session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False) + session.execute( + delete(WorkflowTriggerLog) + .where(WorkflowTriggerLog.id == log_id) + .execution_options(synchronize_session=False) + ) _delete_records( """select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", @@ -607,7 +679,7 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str): ) -def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None: +def _delete_records(query_sql: str, params: dict[str, Any], delete_func: Callable, name: str) -> None: while True: with session_factory.create_session() as session: rs = session.execute(sa.text(query_sql), params) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index 56626e372e..8505375b6a 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,7 +12,6 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -28,7 +27,8 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType +from graphon.enums import WorkflowExecutionStatus from models.enums import ( AppTriggerType, CreatorUserRole, @@ -42,6 +42,7 @@ from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, from services.async_workflow_service import AsyncWorkflowService from services.end_user_service import EndUserService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.trigger_provider_service import TriggerProviderService from services.trigger.trigger_request_service import TriggerHttpRequestCachingService @@ -258,59 +259,58 @@ def dispatch_triggered_workflow( tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id) ) trigger_entity: TriggerProviderEntity = provider_controller.entity + + # Ensure expire_on_commit is set to False to remain workflows available with session_factory.create_session() as session: workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers) - end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( - type=InvokeFrom.TRIGGER, - tenant_id=subscription.tenant_id, - app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], - user_id=user_id, - ) - for plugin_trigger in subscribers: - # Get workflow from mapping - workflow: Workflow | None = workflows.get(plugin_trigger.app_id) - if not workflow: - logger.error( - "Workflow not found for app %s", - plugin_trigger.app_id, - ) - 
continue + end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch( + type=InvokeFrom.TRIGGER, + tenant_id=subscription.tenant_id, + app_ids=[plugin_trigger.app_id for plugin_trigger in subscribers], + user_id=user_id, + ) - # Find the trigger node in the workflow - event_node = None - for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): - if node_id == plugin_trigger.node_id: - event_node = node_config - break - - if not event_node: - logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) - continue - - # invoke trigger - trigger_metadata = PluginTriggerMetadata( - plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", - endpoint_id=subscription.endpoint_id, - provider_id=subscription.provider_id, - event_name=event_name, - icon_filename=trigger_entity.identity.icon or "", - icon_dark_filename=trigger_entity.identity.icon_dark or "", + for plugin_trigger in subscribers: + workflow: Workflow | None = workflows.get(plugin_trigger.app_id) + if not workflow: + logger.error( + "Workflow not found for app %s", + plugin_trigger.app_id, ) + continue - # consume quota before invoking trigger - quota_charge = unlimited() - try: - quota_charge = QuotaType.TRIGGER.consume(subscription.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) - logger.info( - "Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id - ) - return 0 + event_node = None + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): + if node_id == plugin_trigger.node_id: + event_node = node_config + break - node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) - invoke_response: TriggerInvokeEventResponse | None = None + if not event_node: + logger.error("Trigger event node not found for app %s", plugin_trigger.app_id) + continue + + trigger_metadata = PluginTriggerMetadata( + plugin_unique_identifier=provider_controller.plugin_unique_identifier or "", + endpoint_id=subscription.endpoint_id, + provider_id=subscription.provider_id, + event_name=event_name, + icon_filename=trigger_entity.identity.icon or "", + icon_dark_filename=trigger_entity.identity.icon_dark or "", + ) + + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, subscription.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(subscription.tenant_id) + logger.info("Tenant %s rate limited, skipping plugin trigger %s", subscription.tenant_id, plugin_trigger.id) + return dispatched_count + + node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(event_node) + invoke_response: TriggerInvokeEventResponse | None = None + + with session_factory.create_session() as session: try: invoke_response = TriggerManager.invoke_trigger_event( tenant_id=subscription.tenant_id, @@ -387,6 +387,7 @@ def dispatch_triggered_workflow( raise ValueError(f"End user not found for app {plugin_trigger.app_id}") AsyncWorkflowService.trigger_workflow_async(session=session, user=end_user, trigger_data=trigger_data) + quota_charge.commit() dispatched_count += 1 logger.info( "Triggered workflow for app %s with trigger event %s", @@ -401,7 +402,7 @@ def dispatch_triggered_workflow( plugin_trigger.app_id, ) - return dispatched_count + return dispatched_count def dispatch_triggered_workflows( diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py 
index 0c7f74c180..5ca04fd7c2 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -7,13 +7,14 @@ improving performance by offloading storage operations to background workers. import json import logging +from typing import Any from celery import shared_task -from graphon.entities import WorkflowExecution -from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom @@ -23,7 +24,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) def save_workflow_execution_task( self, - execution_data: dict, + execution_data: dict[str, Any], tenant_id: str, app_id: str, triggered_from: str, diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index f25ebe3bae..0d5475a56d 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -7,15 +7,16 @@ improving performance by offloading storage operations to background workers. import json import logging +from typing import Any from celery import shared_task +from sqlalchemy import select + +from core.db.session_factory import session_factory from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, ) from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter -from sqlalchemy import select - -from core.db.session_factory import session_factory from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) def save_workflow_node_execution_task( self, - execution_data: dict, + execution_data: dict[str, Any], tenant_id: str, app_id: str, triggered_from: str, diff --git a/api/tasks/workflow_schedule_tasks.py b/api/tasks/workflow_schedule_tasks.py index 8c64d3ab27..7638652000 100644 --- a/api/tasks/workflow_schedule_tasks.py +++ b/api/tasks/workflow_schedule_tasks.py @@ -8,10 +8,11 @@ from core.workflow.nodes.trigger_schedule.exc import ( ScheduleNotFoundError, TenantOwnerNotFoundError, ) -from enums.quota_type import QuotaType, unlimited +from enums.quota_type import QuotaType from models.trigger import WorkflowSchedulePlan from services.async_workflow_service import AsyncWorkflowService from services.errors.app import QuotaExceededError +from services.quota_service import QuotaService, unlimited from services.trigger.app_trigger_service import AppTriggerService from services.trigger.schedule_service import ScheduleService from services.workflow.entities import ScheduleTriggerData @@ -32,6 +33,7 @@ def run_schedule_trigger(schedule_id: str) -> None: TenantOwnerNotFoundError: If no owner/admin for tenant ScheduleExecutionError: If workflow trigger fails """ + # Ensure expire_on_commit is set to False to remain schedule/tenant_owner available with session_factory.create_session() as session: schedule = session.get(WorkflowSchedulePlan, schedule_id) if not schedule: @@ -41,16 +43,16 @@ def run_schedule_trigger(schedule_id: str) -> None: if not tenant_owner: raise TenantOwnerNotFoundError(f"No owner or admin found for 
tenant {schedule.tenant_id}") - quota_charge = unlimited() - try: - quota_charge = QuotaType.TRIGGER.consume(schedule.tenant_id) - except QuotaExceededError: - AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) - logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) - return + quota_charge = unlimited() + try: + quota_charge = QuotaService.reserve(QuotaType.TRIGGER, schedule.tenant_id) + except QuotaExceededError: + AppTriggerService.mark_tenant_triggers_rate_limited(schedule.tenant_id) + logger.info("Tenant %s rate limited, skipping schedule trigger %s", schedule.tenant_id, schedule_id) + return - try: - # Production dispatch: Trigger the workflow normally + try: + with session_factory.create_session() as session: response = AsyncWorkflowService.trigger_workflow_async( session=session, user=tenant_owner, @@ -61,9 +63,10 @@ def run_schedule_trigger(schedule_id: str) -> None: tenant_id=schedule.tenant_id, ), ) - logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) - except Exception as e: - quota_charge.refund() - raise ScheduleExecutionError( - f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" - ) from e + quota_charge.commit() + logger.info("Schedule %s triggered workflow: %s", schedule_id, response.workflow_trigger_log_id) + except Exception as e: + quota_charge.refund() + raise ScheduleExecutionError( + f"Failed to trigger workflow for schedule {schedule_id}, app {schedule.app_id}" + ) from e diff --git a/api/templates/without-brand/workflow_comment_mention_template_en-US.html b/api/templates/without-brand/workflow_comment_mention_template_en-US.html new file mode 100644 index 0000000000..1ef8fe4e3f --- /dev/null +++ b/api/templates/without-brand/workflow_comment_mention_template_en-US.html @@ -0,0 +1,119 @@ + + + + + + + + +
+ [HTML email body; markup lost in extraction, recoverable text follows]
+ Dify Logo
+ You were mentioned in a workflow comment
+ Hi {{ mentioned_name }},
+ {{ commenter_name }} mentioned you in {{ app_name }}.
+ {{ comment_content }}
+ Open {{ application_title }} to reply to the comment.
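This template is rendered by `send_workflow_comment_mention_email_task`, defined earlier in this diff. A minimal sketch of how a caller might enqueue it; the service function, its arguments, and the `interface_language` fallback are assumptions, while the task name and keyword arguments are taken from the diff:

```python
# Hypothetical caller; only send_workflow_comment_mention_email_task and its
# keyword arguments are defined in this diff.
from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task


def notify_mentioned_user(account, commenter_name: str, app_name: str,
                          comment_excerpt: str, app_url: str) -> None:
    # .delay() enqueues onto the "mail" queue declared on the shared_task,
    # so the HTTP request path never blocks on SMTP.
    send_workflow_comment_mention_email_task.delay(
        language=getattr(account, "interface_language", "en-US"),  # assumed field
        to=account.email,
        mentioned_name=account.name,
        commenter_name=commenter_name,
        app_name=app_name,
        comment_content=comment_excerpt,
        app_url=app_url,
    )
```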
+ + + diff --git a/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html b/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html new file mode 100644 index 0000000000..8b9b2dbe71 --- /dev/null +++ b/api/templates/without-brand/workflow_comment_mention_template_zh-CN.html @@ -0,0 +1,119 @@ + + + + + + + + +
+ [HTML email body; markup lost in extraction, recoverable text follows]
+ Dify Logo
+ 你在工作流评论中被提及
+ 你好,{{ mentioned_name }}:
+ {{ commenter_name }} 在 {{ app_name }} 中提及了你。
+ {{ comment_content }}
+ 请在 {{ application_title }} 中查看并回复此评论。
+ + + diff --git a/api/templates/workflow_comment_mention_template_en-US.html b/api/templates/workflow_comment_mention_template_en-US.html new file mode 100644 index 0000000000..1ef8fe4e3f --- /dev/null +++ b/api/templates/workflow_comment_mention_template_en-US.html @@ -0,0 +1,119 @@ + + + + + + + + +
+ [HTML email body; markup lost in extraction, recoverable text follows]
+ Dify Logo
+ You were mentioned in a workflow comment
+ Hi {{ mentioned_name }},
+ {{ commenter_name }} mentioned you in {{ app_name }}.
+ {{ comment_content }}
+ Open {{ application_title }} to reply to the comment.
+ + + diff --git a/api/templates/workflow_comment_mention_template_zh-CN.html b/api/templates/workflow_comment_mention_template_zh-CN.html new file mode 100644 index 0000000000..8b9b2dbe71 --- /dev/null +++ b/api/templates/workflow_comment_mention_template_zh-CN.html @@ -0,0 +1,119 @@ + + + + + + + + +
+ [HTML email body; markup lost in extraction, recoverable text follows]
+ Dify Logo
+ 你在工作流评论中被提及
+ 你好,{{ mentioned_name }}:
+ {{ commenter_name }} 在 {{ app_name }} 中提及了你。
+ {{ comment_content }}
+ 请在 {{ application_title }} 中查看并回复此评论。
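All four templates share the same context keys and differ only by language and by the branded versus `without-brand/` directory. A hedged sketch of how a template path could be resolved from those two axes; the resolution logic here is an assumption, only the file layout comes from this diff:

```python
from pathlib import Path

TEMPLATE_ROOT = Path("api/templates")


def resolve_mention_template(language_code: str, branded: bool) -> Path:
    # Files in this diff follow "workflow_comment_mention_template_<lang>.html",
    # with unbranded variants under "without-brand/". Fall back to en-US when a
    # locale-specific template is missing.
    base = TEMPLATE_ROOT if branded else TEMPLATE_ROOT / "without-brand"
    candidate = base / f"workflow_comment_mention_template_{language_code}.html"
    if candidate.exists():
        return candidate
    return base / "workflow_comment_mention_template_en-US.html"
```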
+ + + diff --git a/api/tests/__init__.py b/api/tests/__init__.py index e69de29bb2..ced6188ce8 100644 --- a/api/tests/__init__.py +++ b/api/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite root package (enables ``import tests.integration_tests...`` with ``pythonpath = .``).""" diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index f84d39aeb5..c07ab6d6bf 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -33,6 +33,7 @@ REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false REDIS_DB=0 +REDIS_KEY_PREFIX= # PostgreSQL database configuration DB_USERNAME=postgres diff --git a/api/tests/integration_tests/__init__.py b/api/tests/integration_tests/__init__.py index e69de29bb2..c66cd71b7e 100644 --- a/api/tests/integration_tests/__init__.py +++ b/api/tests/integration_tests/__init__.py @@ -0,0 +1 @@ +"""Integration tests package.""" diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 44adadeaa5..09078d196d 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -8,6 +8,7 @@ from collections.abc import Generator import pytest from flask import Flask from flask.testing import FlaskClient +from sqlalchemy import delete, select from sqlalchemy.orm import Session from app_factory import create_app @@ -47,7 +48,7 @@ os.environ["OPENDAL_FS_ROOT"] = "/tmp/dify-storage" os.environ.setdefault("STORAGE_TYPE", "opendal") os.environ.setdefault("OPENDAL_SCHEME", "fs") -_CACHED_APP = create_app() +_SIO_APP, _CACHED_APP = create_app() @pytest.fixture(scope="session") @@ -83,15 +84,15 @@ def setup_account(request) -> Generator[Account, None, None]: with _CACHED_APP.test_request_context(): with Session(bind=db.engine, expire_on_commit=False) as session: - account = session.query(Account).filter_by(email=email).one() + account = session.scalars(select(Account).filter_by(email=email)).one() yield account with _CACHED_APP.test_request_context(): - db.session.query(DifySetup).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Account).delete() - db.session.query(Tenant).delete() + db.session.execute(delete(DifySetup)) + db.session.execute(delete(TenantAccountJoin)) + db.session.execute(delete(Account)) + db.session.execute(delete(Tenant)) db.session.commit() diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index d10e5ed13c..3b5e822b90 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -171,35 +171,13 @@ class TestChatMessageApiPermissions: parent_message_id=None, ) - class MockQuery: - def __init__(self, model): - self.model = model - - def where(self, *args, **kwargs): - return self - - def first(self): - if getattr(self.model, "__name__", "") == "Conversation": - return mock_conversation - return None - - def order_by(self, *args, **kwargs): - return self - - def limit(self, *_): - return self - - def all(self): - if getattr(self.model, "__name__", "") == "Message": - return [mock_message] - return [] - mock_session = mock.Mock() - mock_session.query.side_effect = MockQuery - mock_session.scalar.return_value = False + mock_session.scalar.return_value = mock_conversation + mock_session.scalars.return_value.all.return_value = 
[mock_message] monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session)) monkeypatch.setattr(message_api, "current_user", mock_account) + monkeypatch.setattr(message_api, "attach_message_extra_contents", mock.Mock()) class DummyPagination: def __init__(self, data, limit, has_more): diff --git a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py deleted file mode 100644 index 038f37af5f..0000000000 --- a/api/tests/integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ /dev/null @@ -1,47 +0,0 @@ -import uuid -from unittest import mock - -from controllers.console.app import workflow_draft_variable as draft_variable_api -from controllers.console.app import wraps -from factories.variable_factory import build_segment -from models import App, AppMode -from models.workflow import WorkflowDraftVariable -from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService - - -def _get_mock_srv_class() -> type[WorkflowDraftVariableService]: - return mock.create_autospec(WorkflowDraftVariableService) - - -class TestWorkflowDraftNodeVariableListApi: - def test_get(self, test_client, auth_header, monkeypatch): - srv_class = _get_mock_srv_class() - mock_app_model: App = App() - mock_app_model.id = str(uuid.uuid4()) - test_node_id = "test_node_id" - mock_app_model.mode = AppMode.ADVANCED_CHAT - mock_load_app_model = mock.Mock(return_value=mock_app_model) - - monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) - - var1 = WorkflowDraftVariable.new_node_variable( - app_id="test_app_1", - node_id="test_node_1", - name="str_var", - value=build_segment("str_value"), - node_execution_id=str(uuid.uuid4()), - ) - srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True) - srv_class.return_value = srv_instance - srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1]) - - response = test_client.get( - f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables", - headers=auth_header, - ) - assert response.status_code == 200 - response_dict = response.json - assert isinstance(response_dict, dict) - assert "items" in response_dict - assert len(response_dict["items"]) == 1 diff --git a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py b/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py deleted file mode 100644 index e55c12e678..0000000000 --- a/api/tests/integration_tests/controllers/console/workspace/test_trigger_provider_permissions.py +++ /dev/null @@ -1,244 +0,0 @@ -"""Integration tests for Trigger Provider subscription permission verification.""" - -import uuid -from unittest import mock - -import pytest -from flask.testing import FlaskClient - -from controllers.console.workspace import trigger_providers as trigger_providers_api -from libs.datetime_utils import naive_utc_now -from models import Tenant -from models.account import Account, TenantAccountJoin, TenantAccountRole - - -class TestTriggerProviderSubscriptionPermissions: - """Test permission verification for Trigger Provider subscription endpoints.""" - - @pytest.fixture - def mock_account(self, monkeypatch: pytest.MonkeyPatch): - """Create a mock Account for testing.""" - - account = 
Account(name="Test User", email="test@example.com") - account.id = str(uuid.uuid4()) - account.last_active_at = naive_utc_now() - account.created_at = naive_utc_now() - account.updated_at = naive_utc_now() - - # Create mock tenant - tenant = Tenant(name="Test Tenant") - tenant.id = str(uuid.uuid4()) - - mock_session_instance = mock.Mock() - - mock_tenant_join = TenantAccountJoin(role=TenantAccountRole.OWNER) - monkeypatch.setattr(mock_session_instance, "scalar", mock.Mock(return_value=mock_tenant_join)) - - mock_scalars_result = mock.Mock() - mock_scalars_result.one.return_value = tenant - monkeypatch.setattr(mock_session_instance, "scalars", mock.Mock(return_value=mock_scalars_result)) - - mock_session_context = mock.Mock() - mock_session_context.__enter__.return_value = mock_session_instance - monkeypatch.setattr("models.account.Session", lambda _, expire_on_commit: mock_session_context) - - account.current_tenant = tenant - account.current_tenant_id = tenant.id - return account - - @pytest.mark.parametrize( - ("role", "list_status", "get_status", "update_status", "create_status", "build_status", "delete_status"), - [ - # Admin/Owner can do everything - (TenantAccountRole.OWNER, 200, 200, 200, 200, 200, 200), - (TenantAccountRole.ADMIN, 200, 200, 200, 200, 200, 200), - # Editor can list, get, update (parameters), but not create, build, or delete - (TenantAccountRole.EDITOR, 200, 200, 200, 403, 403, 403), - # Normal user cannot do anything - (TenantAccountRole.NORMAL, 403, 403, 403, 403, 403, 403), - # Dataset operator cannot do anything - (TenantAccountRole.DATASET_OPERATOR, 403, 403, 403, 403, 403, 403), - ], - ) - def test_trigger_subscription_permissions( - self, - test_client: FlaskClient, - auth_header, - monkeypatch, - mock_account, - role: TenantAccountRole, - list_status: int, - get_status: int, - update_status: int, - create_status: int, - build_status: int, - delete_status: int, - ): - """Test that different roles have appropriate permissions for trigger subscription operations.""" - # Set user role - mock_account.role = role - - # Mock current user - monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) - - # Mock AccountService.load_user to prevent authentication issues - from services.account_service import AccountService - - mock_load_user = mock.Mock(return_value=mock_account) - monkeypatch.setattr(AccountService, "load_user", mock_load_user) - - # Test data - provider = "some_provider/some_trigger" - subscription_builder_id = str(uuid.uuid4()) - subscription_id = str(uuid.uuid4()) - - # Mock service methods - mock_list_subscriptions = mock.Mock(return_value=[]) - monkeypatch.setattr( - "services.trigger.trigger_provider_service.TriggerProviderService.list_trigger_provider_subscriptions", - mock_list_subscriptions, - ) - - mock_get_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", - mock_get_subscription_builder, - ) - - mock_update_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", - mock_update_subscription_builder, - ) - - mock_create_subscription_builder = mock.Mock(return_value={"id": subscription_builder_id}) - monkeypatch.setattr( - 
"services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", - mock_create_subscription_builder, - ) - - mock_update_and_build_builder = mock.Mock() - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.update_and_build_builder", - mock_update_and_build_builder, - ) - - mock_delete_provider = mock.Mock() - mock_delete_plugin_trigger = mock.Mock() - mock_db_session = mock.Mock() - mock_db_session.commit = mock.Mock() - - def mock_session_func(engine=None): - return mock_session_context - - mock_session_context = mock.Mock() - mock_session_context.__enter__.return_value = mock_db_session - mock_session_context.__exit__.return_value = None - - monkeypatch.setattr("services.trigger.trigger_provider_service.Session", mock_session_func) - monkeypatch.setattr("services.trigger.trigger_subscription_operator_service.Session", mock_session_func) - - monkeypatch.setattr( - "services.trigger.trigger_provider_service.TriggerProviderService.delete_trigger_provider", - mock_delete_provider, - ) - monkeypatch.setattr( - "services.trigger.trigger_subscription_operator_service.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription", - mock_delete_plugin_trigger, - ) - - # Test 1: List subscriptions (should work for Editor, Admin, Owner) - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/list", - headers=auth_header, - ) - assert response.status_code == list_status - - # Test 2: Get subscription builder (should work for Editor, Admin, Owner) - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/{subscription_builder_id}", - headers=auth_header, - ) - assert response.status_code == get_status - - # Test 3: Update subscription builder parameters (should work for Editor, Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/update/{subscription_builder_id}", - headers=auth_header, - json={"parameters": {"webhook_url": "https://example.com/webhook"}}, - ) - assert response.status_code == update_status - - # Test 4: Create subscription builder (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/create", - headers=auth_header, - json={"credential_type": "api_key"}, - ) - assert response.status_code == create_status - - # Test 5: Build/activate subscription (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/build/{subscription_builder_id}", - headers=auth_header, - json={"name": "Test Subscription"}, - ) - assert response.status_code == build_status - - # Test 6: Delete subscription (should only work for Admin, Owner) - response = test_client.post( - f"/console/api/workspaces/current/trigger-provider/{subscription_id}/subscriptions/delete", - headers=auth_header, - ) - assert response.status_code == delete_status - - @pytest.mark.parametrize( - ("role", "status"), - [ - (TenantAccountRole.OWNER, 200), - (TenantAccountRole.ADMIN, 200), - # Editor should be able to access logs for debugging - (TenantAccountRole.EDITOR, 200), - (TenantAccountRole.NORMAL, 403), - (TenantAccountRole.DATASET_OPERATOR, 403), - ], - ) - def test_trigger_subscription_logs_permissions( - self, - test_client: 
FlaskClient, - auth_header, - monkeypatch, - mock_account, - role: TenantAccountRole, - status: int, - ): - """Test that different roles have appropriate permissions for accessing subscription logs.""" - # Set user role - mock_account.role = role - - # Mock current user - monkeypatch.setattr(trigger_providers_api, "current_user", mock_account) - - # Mock AccountService.load_user to prevent authentication issues - from services.account_service import AccountService - - mock_load_user = mock.Mock(return_value=mock_account) - monkeypatch.setattr(AccountService, "load_user", mock_load_user) - - # Test data - provider = "some_provider/some_trigger" - subscription_builder_id = str(uuid.uuid4()) - - # Mock service method - mock_list_logs = mock.Mock(return_value=[]) - monkeypatch.setattr( - "services.trigger.trigger_subscription_builder_service.TriggerSubscriptionBuilderService.list_logs", - mock_list_logs, - ) - - # Test access to logs - response = test_client.get( - f"/console/api/workspaces/current/trigger-provider/{provider}/subscriptions/builder/logs/{subscription_builder_id}", - headers=auth_header, - ) - assert response.status_code == status diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 91245e879e..a876b0c4aa 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,9 +1,8 @@ from collections.abc import Generator -from graphon.node_events import StreamCompletedEvent - from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage +from graphon.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3fdea10976..2392084c36 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,8 +1,8 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent class _Seg: @@ -70,19 +70,16 @@ def test_node_integration_minimal_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=_GP(), graph_runtime_state=_GS(vp), ) diff --git 
a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py deleted file mode 100644 index c1bb8e1245..0000000000 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ /dev/null @@ -1,375 +0,0 @@ -import unittest -from datetime import UTC, datetime -from unittest.mock import patch -from uuid import uuid4 - -import pytest -from graphon.file import File, FileTransferMethod, FileType -from sqlalchemy.orm import Session - -from core.app.file_access import DatabaseFileAccessController -from extensions.ext_database import db -from extensions.storage.storage_type import StorageType -from factories.file_factory import StorageKeyLoader -from models import ToolFile, UploadFile -from models.enums import CreatorUserRole - - -@pytest.mark.usefixtures("flask_req_ctx") -class TestStorageKeyLoader(unittest.TestCase): - """ - Integration tests for StorageKeyLoader class. - - Tests the batched loading of storage keys from the database for files - with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE. - """ - - def setUp(self): - """Set up test data before each test method.""" - self.session = db.session() - self.tenant_id = str(uuid4()) - self.user_id = str(uuid4()) - self.conversation_id = str(uuid4()) - - # Create test data that will be cleaned up after each test - self.test_upload_files = [] - self.test_tool_files = [] - - # Create StorageKeyLoader instance - self.loader = StorageKeyLoader( - self.session, - self.tenant_id, - access_controller=DatabaseFileAccessController(), - ) - - def tearDown(self): - """Clean up test data after each test method.""" - self.session.rollback() - - def _create_upload_file( - self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None - ) -> UploadFile: - """Helper method to create an UploadFile record for testing.""" - if file_id is None: - file_id = str(uuid4()) - if storage_key is None: - storage_key = f"test_storage_key_{uuid4()}" - if tenant_id is None: - tenant_id = self.tenant_id - - upload_file = UploadFile( - tenant_id=tenant_id, - storage_type=StorageType.LOCAL, - key=storage_key, - name="test_file.txt", - size=1024, - extension=".txt", - mime_type="text/plain", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=self.user_id, - created_at=datetime.now(UTC), - used=False, - ) - upload_file.id = file_id - - self.session.add(upload_file) - self.session.flush() - self.test_upload_files.append(upload_file) - - return upload_file - - def _create_tool_file( - self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None - ) -> ToolFile: - """Helper method to create a ToolFile record for testing.""" - if file_id is None: - file_id = str(uuid4()) - if file_key is None: - file_key = f"test_file_key_{uuid4()}" - if tenant_id is None: - tenant_id = self.tenant_id - - tool_file = ToolFile( - user_id=self.user_id, - tenant_id=tenant_id, - conversation_id=self.conversation_id, - file_key=file_key, - mimetype="text/plain", - original_url="http://example.com/file.txt", - name="test_tool_file.txt", - size=2048, - ) - tool_file.id = file_id - self.session.add(tool_file) - self.session.flush() - self.test_tool_files.append(tool_file) - - return tool_file - - def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: - """Helper method to create a File object for testing.""" - if tenant_id is None: - tenant_id = self.tenant_id - - # Set related_id 
for LOCAL_FILE and TOOL_FILE transfer methods - file_related_id = None - remote_url = None - - if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): - file_related_id = related_id - elif transfer_method == FileTransferMethod.REMOTE_URL: - remote_url = "https://example.com/test_file.txt" - file_related_id = related_id - - return File( - id=str(uuid4()), # Generate new UUID for File.id - tenant_id=tenant_id, - type=FileType.DOCUMENT, - transfer_method=transfer_method, - related_id=file_related_id, - remote_url=remote_url, - filename="test_file.txt", - extension=".txt", - mime_type="text/plain", - size=1024, - storage_key="initial_key", - ) - - def test_load_storage_keys_local_file(self): - """Test loading storage keys for LOCAL_FILE transfer method.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == upload_file.key - - def test_load_storage_keys_remote_url(self): - """Test loading storage keys for REMOTE_URL transfer method.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == upload_file.key - - def test_load_storage_keys_tool_file(self): - """Test loading storage keys for TOOL_FILE transfer method.""" - # Create test data - tool_file = self._create_tool_file() - file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) - - # Load storage keys - self.loader.load_storage_keys([file]) - - # Verify storage key was loaded correctly - assert file._storage_key == tool_file.file_key - - def test_load_storage_keys_mixed_methods(self): - """Test batch loading with mixed transfer methods.""" - # Create test data for different transfer methods - upload_file1 = self._create_upload_file() - upload_file2 = self._create_upload_file() - tool_file = self._create_tool_file() - - file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) - file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) - - files = [file1, file2, file3] - - # Load storage keys - self.loader.load_storage_keys(files) - - # Verify all storage keys were loaded correctly - assert file1._storage_key == upload_file1.key - assert file2._storage_key == upload_file2.key - assert file3._storage_key == tool_file.file_key - - def test_load_storage_keys_empty_list(self): - """Test with empty file list.""" - # Should not raise any exceptions - self.loader.load_storage_keys([]) - - def test_load_storage_keys_ignores_legacy_file_tenant_id(self): - """Legacy file tenant_id should not override the loader tenant scope.""" - upload_file = self._create_upload_file() - file = self._create_file( - related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) - ) - - self.loader.load_storage_keys([file]) - - assert file._storage_key == upload_file.key - - def test_load_storage_keys_missing_file_id(self): - """Test with None file.related_id.""" - # Create a file with valid 
parameters first, then manually set related_id to None - file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) - file.related_id = None - - # Should raise ValueError for None file related_id - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) - - assert str(context.value) == "file id should not be None." - - def test_load_storage_keys_nonexistent_upload_file_records(self): - """Test with missing UploadFile database records.""" - # Create file with non-existent upload file id - non_existent_id = str(uuid4()) - file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Should raise ValueError for missing record - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_nonexistent_tool_file_records(self): - """Test with missing ToolFile database records.""" - # Create file with non-existent tool file id - non_existent_id = str(uuid4()) - file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) - - # Should raise ValueError for missing record - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_invalid_uuid(self): - """Test with invalid UUID format.""" - # Create a file with valid parameters first, then manually set invalid related_id - file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) - file.related_id = "invalid-uuid-format" - - # Should raise ValueError for invalid UUID - with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) - - def test_load_storage_keys_batch_efficiency(self): - """Test batched operations use efficient queries.""" - # Create multiple files of different types - upload_files = [self._create_upload_file() for _ in range(3)] - tool_files = [self._create_tool_file() for _ in range(2)] - - files = [] - files.extend( - [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] - ) - files.extend( - [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] - ) - - # Mock the session to count queries - with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: - self.loader.load_storage_keys(files) - - # Should make exactly 2 queries (one for upload_files, one for tool_files) - assert mock_scalars.call_count == 2 - - # Verify all storage keys were loaded correctly - for i, file in enumerate(files[:3]): - assert file._storage_key == upload_files[i].key - for i, file in enumerate(files[3:]): - assert file._storage_key == tool_files[i].file_key - - def test_load_storage_keys_tenant_isolation(self): - """Test that tenant isolation works correctly.""" - # Create files for different tenants - other_tenant_id = str(uuid4()) - - # Create upload file for current tenant - upload_file_current = self._create_upload_file() - file_current = self._create_file( - related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE - ) - - # Create upload file for other tenant (but don't add to cleanup list) - upload_file_other = UploadFile( - tenant_id=other_tenant_id, - storage_type=StorageType.LOCAL, - key="other_tenant_key", - name="other_file.txt", - size=1024, - extension=".txt", - mime_type="text/plain", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=self.user_id, - created_at=datetime.now(UTC), - used=False, - ) - 
upload_file_other.id = str(uuid4()) - self.session.add(upload_file_other) - self.session.flush() - - # Create file for other tenant but try to load with current tenant's loader - file_other = self._create_file( - related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id - ) - - # Should raise ValueError due to tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file_other]) - - assert "Upload file not found for id:" in str(context.value) - - # Current tenant's file should still work - self.loader.load_storage_keys([file_current]) - assert file_current._storage_key == upload_file_current.key - - def test_load_storage_keys_mixed_tenant_batch(self): - """Test batch with mixed tenant files (should fail on first mismatch).""" - # Create files for current tenant - upload_file_current = self._create_upload_file() - file_current = self._create_file( - related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE - ) - - # Create file for different tenant - other_tenant_id = str(uuid4()) - file_other = self._create_file( - related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id - ) - - # Should raise ValueError on tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file_current, file_other]) - - assert "Upload file not found for id:" in str(context.value) - - def test_load_storage_keys_duplicate_file_ids(self): - """Test handling of duplicate file IDs in the batch.""" - # Create upload file - upload_file = self._create_upload_file() - - # Create two File objects with same related_id - file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Should handle duplicates gracefully - self.loader.load_storage_keys([file1, file2]) - - # Both files should have the same storage key - assert file1._storage_key == upload_file.key - assert file2._storage_key == upload_file.key - - def test_load_storage_keys_session_isolation(self): - """Test that the loader uses the provided session correctly.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Create loader with different session (same underlying connection) - - with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader( - other_session, - self.tenant_id, - access_controller=DatabaseFileAccessController(), - ) - with pytest.raises(ValueError): - other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index ce04a158a8..c4146d5ccd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,6 +4,9 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + # import monkeypatch from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.llm_entities import ( @@ -23,9 +26,6 @@ from graphon.model_runtime.entities.model_entities import ( ) from 
graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - class MockModelClass(PluginModelClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: diff --git a/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py b/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py index 951a5ab4b4..0a19debc39 100644 --- a/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py +++ b/api/tests/integration_tests/services/plugin/test_plugin_lifecycle.py @@ -1,5 +1,5 @@ import pytest -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from models import Tenant @@ -61,7 +61,11 @@ class TestPluginPermissionLifecycle: assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS with session_factory.create_session() as session: - count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count() + count = session.scalar( + select(func.count()) + .select_from(TenantPluginPermission) + .where(TenantPluginPermission.tenant_id == tenant) + ) assert count == 1 diff --git a/api/tests/integration_tests/services/retention/test_messages_clean_service.py b/api/tests/integration_tests/services/retention/test_messages_clean_service.py index 348bb0af4a..352960bcc2 100644 --- a/api/tests/integration_tests/services/retention/test_messages_clean_service.py +++ b/api/tests/integration_tests/services/retention/test_messages_clean_service.py @@ -3,7 +3,7 @@ import math import uuid import pytest -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from models import Tenant @@ -210,7 +210,7 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(all_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids))) assert remaining == len(all_ids) def test_billing_disabled_deletes_all_in_range(self, seed_messages): @@ -231,7 +231,7 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == len(all_ids) with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(all_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids))) assert remaining == 0 def test_start_from_filters_correctly(self, seed_messages): @@ -254,7 +254,7 @@ class TestMessagesCleanServiceIntegration: with session_factory.create_session() as session: all_ids = list(msg_ids.values()) - remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()} + remaining_ids = set(session.scalars(select(Message.id).where(Message.id.in_(all_ids))).all()) assert msg_ids["old"] not in remaining_ids assert msg_ids["very_old"] in remaining_ids @@ -282,7 +282,7 @@ class TestMessagesCleanServiceIntegration: assert stats["batches"] >= expected_batches with session_factory.create_session() as session: - remaining = session.query(Message).where(Message.id.in_(msg_ids)).count() + remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(msg_ids))) assert remaining == 0 def 
test_no_messages_in_range_returns_empty_stats(self, seed_messages): @@ -319,9 +319,17 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 with session_factory.create_session() as session: - assert session.query(Message).where(Message.id == msg_id).count() == 0 - assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0 - assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0 + assert session.scalar(select(func.count()).select_from(Message).where(Message.id == msg_id)) == 0 + assert ( + session.scalar(select(func.count()).select_from(MessageFeedback).where(MessageFeedback.id == fb_id)) + == 0 + ) + assert ( + session.scalar( + select(func.count()).select_from(MessageAnnotation).where(MessageAnnotation.id == ann_id) + ) + == 0 + ) def test_factory_from_time_range_validation(self): with pytest.raises(ValueError, match="start_from"): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 5c6636f31e..e130644338 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,11 +3,7 @@ import unittest import uuid import pytest -from graphon.nodes import BuiltinNodeTypes -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType -from graphon.variables.variables import StringVariable -from sqlalchemy import delete +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID @@ -15,6 +11,10 @@ from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from libs import datetime_utils from models.enums import CreatorUserRole from models.model import UploadFile @@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def setUp(self): self._test_app_id = str(uuid.uuid4()) + self._test_user_id = str(uuid.uuid4()) self._session: Session = db.session() sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="sys_var", value=build_segment("sys_value"), node_execution_id=self._node_exec_id, ) conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="conv_var", value=build_segment("conv_value"), ) node2_vars = [ WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node2_id, name="int_var", value=build_segment(1), @@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): ), WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node2_id, name="str_var", value=build_segment("str_value"), @@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): ] node1_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node1_id, name="str_var", 
value=build_segment("str_value"), @@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_delete_node_variables(self): srv = self._get_test_srv() srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id) - node2_var_count = ( - self._session.query(WorkflowDraftVariable) + node2_var_count = self._session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariable) .where( WorkflowDraftVariable.app_id == self._test_app_id, WorkflowDraftVariable.node_id == self._node2_id, + WorkflowDraftVariable.user_id == self._test_user_id, ) - .count() ) assert node2_var_count == 0 def test_delete_variable(self): srv = self._get_test_srv() - node_1_var = ( - self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one() - ) + node_1_var = self._session.scalars( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id) + ).one() srv.delete_variable(node_1_var) exists = bool( - self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first() + self._session.scalars( + select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id) + ).first() ) assert exists is False @@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase): def tearDown(self): with Session(bind=db.engine, expire_on_commit=False) as session: - session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete( - synchronize_session=False - ) + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id)) session.commit() def test_variable_loader_with_empty_selector(self): @@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up with Session(bind=db.engine) as session: # Query and delete by ID to ensure they're tracked in this session - session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() - session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() - session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id)) + session.execute( + delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id) + ) + session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id)) session.commit() # Clean up storage try: @@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up with Session(bind=db.engine) as session: # Query and delete by ID to ensure they're tracked in this session - session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete() - session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete() - session.query(UploadFile).filter_by(id=upload_file.id).delete() + session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id)) + session.execute( + delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id) + ) + session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id)) session.commit() # Clean up storage try: diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index 38dc8bbb28..4f444598b1 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ 
b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,11 @@ import uuid from unittest.mock import patch import pytest -from graphon.variables.segments import StringSegment -from sqlalchemy import delete +from sqlalchemy import delete, func, select from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -108,8 +108,12 @@ class TestDeleteDraftVariablesIntegration: app2_id = data["app2"].id with session_factory.create_session() as session: - app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + app1_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) + app2_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id) + ) assert app1_vars_before == 5 assert app2_vars_before == 5 @@ -117,8 +121,12 @@ class TestDeleteDraftVariablesIntegration: assert deleted_count == 5 with session_factory.create_session() as session: - app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() - app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count() + app1_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) + app2_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id) + ) assert app1_vars_after == 0 assert app2_vars_after == 5 @@ -130,7 +138,9 @@ class TestDeleteDraftVariablesIntegration: assert deleted_count == 5 with session_factory.create_session() as session: - remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + remaining_vars = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert remaining_vars == 0 def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data): @@ -143,14 +153,18 @@ class TestDeleteDraftVariablesIntegration: app1_id = data["app1"].id with session_factory.create_session() as session: - vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert vars_before == 5 deleted_count = _delete_draft_variables(app1_id) assert deleted_count == 5 with session_factory.create_session() as session: - vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count() + vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id) + ) assert vars_after == 0 def test_batch_deletion_handles_large_dataset(self, app_and_tenant): @@ -175,7 +189,9 @@ class TestDeleteDraftVariablesIntegration: deleted_count = delete_draft_variables_batch(app.id, batch_size=8) assert deleted_count == 25 with session_factory.create_session() as session: - remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count() + remaining = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app.id) + ) assert remaining == 0 finally: with session_factory.create_session() as session: @@ 
-193,7 +209,6 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant from graphon.variables.types import SegmentType - from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -307,13 +322,17 @@ class TestDeleteDraftVariablesWithOffloadIntegration: mock_storage.delete.return_value = None with session_factory.create_session() as session: - draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = ( - session.query(WorkflowDraftVariableFile) - .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() + draft_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) + var_files_before = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + ) + upload_files_before = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) ) - upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -322,16 +341,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert deleted_count == 3 with session_factory.create_session() as session: - draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + draft_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = ( - session.query(WorkflowDraftVariableFile) + var_files_after = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() ) - upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + upload_files_after = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) assert var_files_after == 0 assert upload_files_after == 0 @@ -352,16 +375,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert deleted_count == 3 with session_factory.create_session() as session: - draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + draft_vars_after = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = ( - session.query(WorkflowDraftVariableFile) + var_files_after = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) - .count() ) - upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + upload_files_after = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) assert var_files_after == 0 assert upload_files_after == 0 @@ -425,7 +452,6 @@ class TestDeleteDraftVariablesSessionCommit: def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" from graphon.variables.types import SegmentType - from libs.datetime_utils import naive_utc_now tenant, app = 
app_and_tenant @@ -579,7 +605,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify all data was deleted (proves transaction was committed) with session_factory.create_session() as session: - remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + remaining_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert deleted_count == 10 assert remaining_count == 0 @@ -592,7 +620,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify initial state with session_factory.create_session() as session: - initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + initial_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert initial_count == 10 # Perform deletion with small batch size to force multiple commits @@ -602,13 +632,17 @@ class TestDeleteDraftVariablesSessionCommit: # Verify all data is deleted in a new session (proves commits worked) with session_factory.create_session() as session: - final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + final_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert final_count == 0 # Verify specific IDs are deleted with session_factory.create_session() as session: - remaining_vars = ( - session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count() + remaining_vars = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) ) assert remaining_vars == 0 @@ -626,7 +660,9 @@ class TestDeleteDraftVariablesSessionCommit: app_id = data["app"].id with session_factory.create_session() as session: - initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + initial_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert initial_count == 10 # Delete all in a single batch @@ -635,7 +671,9 @@ class TestDeleteDraftVariablesSessionCommit: # Verify data is persisted with session_factory.create_session() as session: - final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + final_count = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) assert final_count == 0 def test_invalid_batch_size_raises_error(self, setup_commit_test_data): @@ -659,13 +697,17 @@ class TestDeleteDraftVariablesSessionCommit: # Verify initial state with session_factory.create_session() as session: - draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = ( - session.query(WorkflowDraftVariableFile) - .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) - .count() + draft_vars_before = session.scalar( + select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id) + ) + var_files_before = session.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + ) + upload_files_before = session.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) ) - upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 
            assert upload_files_before == 2
@@ -676,13 +718,17 @@ class TestDeleteDraftVariablesSessionCommit:

         # Verify all data is persisted (deleted) in new session
         with session_factory.create_session() as session:
-            draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
-            var_files_after = (
-                session.query(WorkflowDraftVariableFile)
-                .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
-                .count()
+            draft_vars_after = session.scalar(
+                select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
+            )
+            var_files_after = session.scalar(
+                select(func.count())
+                .select_from(WorkflowDraftVariableFile)
+                .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
+            )
+            upload_files_after = session.scalar(
+                select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
             )
-            upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
             assert draft_vars_after == 0
             assert var_files_after == 0
             assert upload_files_after == 0
diff --git a/api/tests/integration_tests/vdb/matrixone/__init__.py b/api/tests/integration_tests/vdb/matrixone/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/milvus/__init__.py b/api/tests/integration_tests/vdb/milvus/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/myscale/__init__.py b/api/tests/integration_tests/vdb/myscale/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/oceanbase/__init__.py b/api/tests/integration_tests/vdb/oceanbase/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/opengauss/__init__.py b/api/tests/integration_tests/vdb/opengauss/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/opensearch/__init__.py b/api/tests/integration_tests/vdb/opensearch/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py b/api/tests/integration_tests/vdb/opensearch/test_opensearch.py
deleted file mode 100644
index 81ebb1d2f7..0000000000
--- a/api/tests/integration_tests/vdb/opensearch/test_opensearch.py
+++ /dev/null
@@ -1,235 +0,0 @@
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from core.rag.datasource.vdb.field import Field
-from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchConfig, OpenSearchVector
-from core.rag.models.document import Document
-from extensions import ext_redis
-
-
-def get_example_text() -> str:
-    return "This is a sample text for testing purposes."
-
-
-@pytest.fixture(scope="module")
-def setup_mock_redis():
-    ext_redis.redis_client.get = MagicMock(return_value=None)
-    ext_redis.redis_client.set = MagicMock(return_value=None)
-
-    mock_redis_lock = MagicMock()
-    mock_redis_lock.__enter__ = MagicMock()
-    mock_redis_lock.__exit__ = MagicMock()
-    ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
-
-
-class TestOpenSearchConfig:
-    def test_to_opensearch_params(self):
-        config = OpenSearchConfig(
-            host="localhost",
-            port=9200,
-            secure=True,
-            user="admin",
-            password="password",
-        )
-
-        params = config.to_opensearch_params()
-
-        assert params["hosts"] == [{"host": "localhost", "port": 9200}]
-        assert params["use_ssl"] is True
-        assert params["verify_certs"] is True
-        assert params["connection_class"].__name__ == "Urllib3HttpConnection"
-        assert params["http_auth"] == ("admin", "password")
-
-    @patch("boto3.Session", autospec=True)
-    @patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth", autospec=True)
-    def test_to_opensearch_params_with_aws_managed_iam(
-        self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
-    ):
-        mock_credentials = MagicMock()
-        mock_boto_session.return_value.get_credentials.return_value = mock_credentials
-
-        mock_auth_instance = mock_aws_signer_auth.return_value
-        aws_region = "ap-southeast-2"
-        aws_service = "aoss"
-        host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
-        port = 9201
-
-        config = OpenSearchConfig(
-            host=host,
-            port=port,
-            secure=True,
-            auth_method="aws_managed_iam",
-            aws_region=aws_region,
-            aws_service=aws_service,
-        )
-
-        params = config.to_opensearch_params()
-
-        assert params["hosts"] == [{"host": host, "port": port}]
-        assert params["use_ssl"] is True
-        assert params["verify_certs"] is True
-        assert params["connection_class"].__name__ == "Urllib3HttpConnection"
-        assert params["http_auth"] is mock_auth_instance
-
-        mock_aws_signer_auth.assert_called_once_with(
-            credentials=mock_credentials, region=aws_region, service=aws_service
-        )
-        assert mock_boto_session.return_value.get_credentials.called
-
-
-class TestOpenSearchVector:
-    def setup_method(self):
-        self.collection_name = "test_collection"
-        self.example_doc_id = "example_doc_id"
-        self.vector = OpenSearchVector(
-            collection_name=self.collection_name,
-            config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
-        )
-        self.vector._client = MagicMock()
-
-    @pytest.mark.parametrize(
-        ("search_response", "expected_length", "expected_doc_id"),
-        [
-            (
-                {
-                    "hits": {
-                        "total": {"value": 1},
-                        "hits": [
-                            {
-                                "_source": {
-                                    "page_content": get_example_text(),
-                                    "metadata": {"document_id": "example_doc_id"},
-                                }
-                            }
-                        ],
-                    }
-                },
-                1,
-                "example_doc_id",
-            ),
-            ({"hits": {"total": {"value": 0}, "hits": []}}, 0, None),
-        ],
-    )
-    def test_search_by_full_text(self, search_response, expected_length, expected_doc_id):
-        self.vector._client.search.return_value = search_response
-
-        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
-        assert len(hits_by_full_text) == expected_length
-        if expected_length > 0:
-            assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id
-
-    def test_search_by_vector(self):
-        vector = [0.1] * 128
-        mock_response = {
-            "hits": {
-                "total": {"value": 1},
-                "hits": [
-                    {
-                        "_source": {
-                            Field.CONTENT_KEY: get_example_text(),
-                            Field.METADATA_KEY: {"document_id": self.example_doc_id},
-                        },
-                        "_score": 1.0,
-                    }
-                ],
-            }
-        }
-        self.vector._client.search.return_value = mock_response
-
-        hits_by_vector = self.vector.search_by_vector(query_vector=vector)
-
-        print("Hits by vector:", hits_by_vector)
-        print("Expected document ID:", self.example_doc_id)
-        print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits")
-
-        assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
-        assert hits_by_vector[0].metadata["document_id"] == self.example_doc_id, (
-            f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
-        )
-
-    def test_get_ids_by_metadata_field(self):
-        mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
-        self.vector._client.search.return_value = mock_response
-
-        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
-        embedding = [0.1] * 128
-
-        with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk:
-            mock_bulk.return_value = ([], [])
-            self.vector.add_texts([doc], [embedding])
-
-        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
-        assert len(ids) == 1
-        assert ids[0] == "mock_id"
-
-    def test_add_texts(self):
-        self.vector._client.index.return_value = {"result": "created"}
-
-        doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
-        embedding = [0.1] * 128
-
-        with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk:
-            mock_bulk.return_value = ([], [])
-            self.vector.add_texts([doc], [embedding])
-
-        mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
-        self.vector._client.search.return_value = mock_response
-
-        ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
-        assert len(ids) == 1
-        assert ids[0] == "mock_id"
-
-    def test_delete_nonexistent_index(self):
-        """Test deleting a non-existent index."""
-        # Create a vector instance with a non-existent collection name
-        self.vector._client.indices.exists.return_value = False
-
-        # Should not raise an exception
-        self.vector.delete()
-
-        # Verify that exists was called but delete was not
-        self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
-        self.vector._client.indices.delete.assert_not_called()
-
-    def test_delete_existing_index(self):
-        """Test deleting an existing index."""
-        self.vector._client.indices.exists.return_value = True
-
-        self.vector.delete()
-
-        # Verify both exists and delete were called
-        self.vector._client.indices.exists.assert_called_once_with(index=self.collection_name.lower())
-        self.vector._client.indices.delete.assert_called_once_with(index=self.collection_name.lower())
-
-
-@pytest.mark.usefixtures("setup_mock_redis")
-class TestOpenSearchVectorWithRedis:
-    def setup_method(self):
-        self.tester = TestOpenSearchVector()
-
-    def test_search_by_full_text(self):
-        self.tester.setup_method()
-        search_response = {
-            "hits": {
-                "total": {"value": 1},
-                "hits": [
-                    {"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}}
-                ],
-            }
-        }
-        expected_length = 1
-        expected_doc_id = "example_doc_id"
-        self.tester.test_search_by_full_text(search_response, expected_length, expected_doc_id)
-
-    def test_get_ids_by_metadata_field(self):
-        self.tester.setup_method()
-        self.tester.test_get_ids_by_metadata_field()
-
-    def test_add_texts(self):
-        self.tester.setup_method()
-        self.tester.test_add_texts()
-
-    def test_search_by_vector(self):
-        self.tester.setup_method()
-        self.tester.test_search_by_vector()
diff --git a/api/tests/integration_tests/vdb/oracle/__init__.py b/api/tests/integration_tests/vdb/oracle/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/pgvecto_rs/__init__.py b/api/tests/integration_tests/vdb/pgvecto_rs/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/pgvector/__init__.py b/api/tests/integration_tests/vdb/pgvector/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/pyvastbase/__init__.py b/api/tests/integration_tests/vdb/pyvastbase/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/qdrant/__init__.py b/api/tests/integration_tests/vdb/qdrant/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/tablestore/__init__.py b/api/tests/integration_tests/vdb/tablestore/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/tcvectordb/__init__.py b/api/tests/integration_tests/vdb/tcvectordb/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/tidb_vector/__init__.py b/api/tests/integration_tests/vdb/tidb_vector/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/upstash/__init__.py b/api/tests/integration_tests/vdb/upstash/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/vikingdb/__init__.py b/api/tests/integration_tests/vdb/vikingdb/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/vdb/weaviate/__init__.py b/api/tests/integration_tests/vdb/weaviate/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py
index c0143faa85..a9a2617bae 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/model.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py
@@ -1,12 +1,11 @@
 from unittest.mock import MagicMock

-from graphon.model_runtime.entities.model_entities import ModelType
-
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
 from core.model_manager import ModelInstance
 from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
+from graphon.model_runtime.entities.model_entities import ModelType
 from models.provider import ProviderType
diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py
deleted file mode 100644
index 487178ff58..0000000000
--- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import pytest
-
-from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
-
-CODE_LANGUAGE = "unsupported_language"
-
-
-def test_unsupported_with_code_template():
-    with pytest.raises(CodeExecutionError) as e:
-        CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
-    assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"
diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
deleted file mode 100644
index c8eb9ec3e4..0000000000
--- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import base64
-
-from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
-from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
-
-CODE_LANGUAGE = CodeLanguage.JINJA2
-
-
-def test_jinja2():
-    """Test basic Jinja2 template rendering."""
-    template = "Hello {{template}}"
-    # Template must be base64 encoded to match the new safe embedding approach
-    template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8")
-    inputs = base64.b64encode(b'{"template": "World"}').decode("utf-8")
-    code = (
-        Jinja2TemplateTransformer.get_runner_script()
-        .replace(Jinja2TemplateTransformer._template_b64_placeholder, template_b64)
-        .replace(Jinja2TemplateTransformer._inputs_placeholder, inputs)
-    )
-    result = CodeExecutor.execute_code(
-        language=CODE_LANGUAGE, preload=Jinja2TemplateTransformer.get_preload_script(), code=code
-    )
-    assert result == "<>Hello World<>\n"
-
-
-def test_jinja2_with_code_template():
-    """Test template rendering via the high-level workflow API."""
-    result = CodeExecutor.execute_workflow_code_template(
-        language=CODE_LANGUAGE, code="Hello {{template}}", inputs={"template": "World"}
-    )
-    assert result == {"result": "Hello World"}
-
-
-def test_jinja2_get_runner_script():
-    """Test that runner script contains required placeholders."""
-    runner_script = Jinja2TemplateTransformer.get_runner_script()
-    assert runner_script.count(Jinja2TemplateTransformer._template_b64_placeholder) == 1
-    assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
-    assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
-
-
-def test_jinja2_template_with_special_characters():
-    """
-    Test that templates with special characters (quotes, newlines) render correctly.
-    This is a regression test for issue #26818 where textarea pre-fill values
-    containing special characters would break template rendering.
-    """
-    # Template with triple quotes, single quotes, double quotes, and newlines
-    template = """
-
-
-
-
-Status: "{{ status }}"
-
-
-'''code block'''
- -""" - inputs = {"task": {"Task ID": "TASK-123", "Issues": "Line 1\nLine 2\nLine 3"}, "status": "completed"} - - result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs) - - # Verify the template rendered correctly with all special characters - output = result["result"] - assert 'value="TASK-123"' in output - assert "" in output - assert 'Status: "completed"' in output - assert "'''code block'''" in output - - -def test_jinja2_template_with_html_textarea_prefill(): - """ - Specific test for HTML textarea with Jinja2 variable pre-fill. - Verifies fix for issue #26818. - """ - template = "" - notes_content = "This is a multi-line note.\nWith special chars: 'single' and \"double\" quotes." - inputs = {"notes": notes_content} - - result = CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code=template, inputs=inputs) - - expected_output = f"" - assert result["result"] == expected_output - - -def test_jinja2_assemble_runner_script_encodes_template(): - """Test that assemble_runner_script properly base64 encodes the template.""" - template = "Hello {{ name }}!" - inputs = {"name": "World"} - - script = Jinja2TemplateTransformer.assemble_runner_script(template, inputs) - - # The template should be base64 encoded in the script - template_b64 = base64.b64encode(template.encode("utf-8")).decode("utf-8") - assert template_b64 in script - # The raw template should NOT appear in the script (it's encoded) - assert "Hello {{ name }}!" not in script diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py deleted file mode 100644 index 25af312afa..0000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ /dev/null @@ -1,36 +0,0 @@ -from textwrap import dedent - -from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer - -CODE_LANGUAGE = CodeLanguage.PYTHON3 - - -def test_python3_plain(): - code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == "Hello World\n" - - -def test_python3_json(): - code = dedent(""" - import json - print(json.dumps({'Hello': 'World'})) - """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == '{"Hello": "World"}\n' - - -def test_python3_with_code_template(): - result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} - ) - assert result == {"result": "HelloWorld"} - - -def test_python3_get_runner_script(): - runner_script = Python3TemplateTransformer.get_runner_script() - assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 4f41396c22..aaa6092993 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,18 @@ import time import uuid 
import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.node_events import NodeRunResult -from graphon.nodes.code.code_node import CodeNode -from graphon.nodes.code.limits import CodeNodeLimits -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import NodeRunResult +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeNodeData +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.code_executor",) @@ -64,8 +65,8 @@ def init_code_node(code_config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( - id=str(uuid.uuid4()), - config=code_config, + node_id=str(uuid.uuid4()), + config=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index b1f937e738..b9f7b9575b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,11 +3,6 @@ import uuid from urllib.parse import urlencode import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.graph import Graph -from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -16,6 +11,11 @@ from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig, HttpRequestNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params pytest_plugins = ("tests.integration_tests.workflow.nodes.__mock.http",) @@ -75,8 +75,8 @@ def init_http_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -192,6 +192,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In 
custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from core.workflow.system_variables import build_system_variables from graphon.enums import BuiltinNodeTypes from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -202,8 +203,6 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool - from core.workflow.system_variables import build_system_variables - # Create variable pool variable_pool = VariablePool( system_variables=build_system_variables(user_id="test", files=[]), @@ -724,8 +723,8 @@ def test_nested_object_variable_selector(setup_http_mock): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( - id=str(uuid.uuid4()), - config=graph_config["nodes"][1], + node_id=str(uuid.uuid4()), + config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index f0f3fcead1..3eead70163 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,20 +4,20 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.llm_generator.output_parser.structured_output import _parse_structured_output +from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables +from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.entities import LLMNodeData from graphon.nodes.llm.file_saver import LLMFileSaver from graphon.nodes.llm.node import LLMNode from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from graphon.nodes.protocols import HttpClientProtocol from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.llm_generator.output_parser.structured_output import _parse_structured_output -from core.model_manager import ModelInstance -from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" @@ -76,8 +76,8 @@ def init_llm_node(config: dict) -> LLMNode: llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=LLMNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index fe512c2585..f2eabb86c3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,17 +3,17 @@ import 
time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory -from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_runtime import DifyPromptMessageSerializer from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -70,8 +70,8 @@ def init_parameter_extractor_node(config: dict, memory=None): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = ParameterExtractorNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ParameterExtractorNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 2d728569be..e2e0723fb8 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,15 +1,15 @@ import time import uuid -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.template_rendering import TemplateRenderError - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -87,8 +87,8 @@ def test_execute_template_transform(): assert graph is not None node = TemplateTransformNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=TemplateTransformNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, jinja2_template_renderer=_SimpleJinja2Renderer(), diff --git 
a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 750ced7075..a8e9422c1e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,18 +2,18 @@ import time import uuid from unittest.mock import MagicMock, patch -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.node_events import StreamCompletedEvent -from graphon.nodes.protocols import ToolFileManagerProtocol -from graphon.nodes.tool.tool_node import ToolNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.entities import ToolNodeData +from graphon.nodes.tool.tool_node import ToolNode +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -61,8 +61,8 @@ def init_tool_node(config: dict): tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) node = ToolNode( - id=str(uuid.uuid4()), - config=config, + node_id=str(uuid.uuid4()), + config=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index ef74893f07..b4482674da 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -369,7 +369,7 @@ def _create_app_with_containers() -> Flask: # Create and configure the Flask application logger.info("Initializing Flask application...") - app = create_app() + sio_app, app = create_app() logger.info("Flask application created successfully") # Initialize database schema @@ -433,7 +433,7 @@ def flask_app_with_containers(set_up_containers_and_env) -> Flask: @pytest.fixture -def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]: +def flask_req_ctx_with_containers(flask_app_with_containers: Flask) -> Generator[None, None, None]: """ Request context fixture for containerized Flask application. @@ -454,7 +454,7 @@ def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, @pytest.fixture -def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]: +def test_client_with_containers(flask_app_with_containers: Flask) -> Generator[FlaskClient, None, None]: """ Test client fixture for containerized Flask application. @@ -475,7 +475,7 @@ def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskCli @pytest.fixture -def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]: +def db_session_with_containers(flask_app_with_containers: Flask) -> Generator[Session, None, None]: """ Database session fixture for containerized testing. 
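Note on the constructor migration that runs through the workflow-node test hunks above: every node is now built with `node_id=` instead of `id=`, and the raw `config["data"]` dict is validated into the node's typed Pydantic model before construction. A minimal sketch of the new call shape, using only names visible in this diff; the `init_params`/`code_executor` objects and the `user_inputs` field are assumed to come from the surrounding fixtures as shown in the hunks:

```python
# Sketch of the post-migration node construction (names taken from the diff;
# fixture wiring simplified).
import time
import uuid

from core.workflow.system_variables import build_system_variables
from graphon.nodes.code.code_node import CodeNode
from graphon.nodes.code.entities import CodeNodeData
from graphon.runtime import GraphRuntimeState, VariablePool


def build_code_node(code_config: dict, init_params, code_executor) -> CodeNode:
    variable_pool = VariablePool(
        system_variables=build_system_variables(user_id="test", files=[]),
        user_inputs={},  # assumed, mirroring the fixture setup in the diff
    )
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
    return CodeNode(
        node_id=str(uuid.uuid4()),                                # was: id=str(uuid.uuid4())
        config=CodeNodeData.model_validate(code_config["data"]),  # was: config=code_config
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
        code_executor=code_executor,
    )
```

The same shape applies to `HttpRequestNode`, `LLMNode`, `ParameterExtractorNode`, `TemplateTransformNode`, and `ToolNode`, each with its own `*NodeData` entity.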
diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index c3a861c3e1..bb737754a1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -69,7 +70,7 @@ def _unwrap(func): class TestCompletionEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_completion_create_payload(self): @@ -86,7 +87,7 @@ class TestCompletionEndpoints: ) assert payload.query == "hi" - def test_completion_api_success(self, app, monkeypatch): + def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -116,7 +117,7 @@ class TestCompletionEndpoints: assert resp == {"result": {"text": "ok"}} - def test_completion_api_conversation_not_exists(self, app, monkeypatch): + def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -142,7 +143,7 @@ class TestCompletionEndpoints: with pytest.raises(NotFound): method(app_model=MagicMock(id="app-1")) - def test_completion_api_provider_not_initialized(self, app, monkeypatch): + def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -166,7 +167,7 @@ class TestCompletionEndpoints: with pytest.raises(completion_module.ProviderNotInitializeError): method(app_model=MagicMock(id="app-1")) - def test_completion_api_quota_exceeded(self, app, monkeypatch): + def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -193,10 +194,10 @@ class TestCompletionEndpoints: class TestAppEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppApi() method = _unwrap(api.put) payload = { @@ -234,10 +235,39 @@ class TestAppEndpoints: } ) + def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch): + api = app_module.AppIconApi() + method = _unwrap(api.post) + payload = { + "icon": "https://example.com/icon.png", + "icon_type": "image", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app_icon.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetail, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1/icon", method="POST", 
json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace()) + + assert response == {"id": "app-1"} + assert app_service.update_app_icon.call_args.args[1:] == ( + payload["icon"], + payload["icon_background"], + app_module.IconType.IMAGE, + ) + class TestOpsTraceEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_ops_trace_query_basic(self): @@ -248,7 +278,7 @@ class TestOpsTraceEndpoints: payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"}) assert payload.tracing_config["api_key"] == "k" - def test_trace_app_config_get_empty(self, app, monkeypatch): + def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.get) @@ -263,7 +293,7 @@ class TestOpsTraceEndpoints: assert result == {"has_not_configured": True} - def test_trace_app_config_post_invalid(self, app, monkeypatch): + def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.post) @@ -280,7 +310,7 @@ class TestOpsTraceEndpoints: with pytest.raises(BadRequest): method(app_id="app-1") - def test_trace_app_config_delete_not_found(self, app, monkeypatch): + def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.delete) @@ -297,7 +327,7 @@ class TestOpsTraceEndpoints: class TestSiteEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_site_response_structure(self): @@ -308,11 +338,26 @@ class TestSiteEndpoints: payload = AppSiteUpdatePayload(default_language="en-US") assert payload.default_language == "en-US" - def test_app_site_update_post(self, app, monkeypatch): + def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSite() method = _unwrap(api.post) site = MagicMock() + site.app_id = "app-1" + site.code = "test-code" + site.title = "My Site" + site.icon = None + site.icon_background = None + site.description = "Test site" + site.default_language = "en-US" + site.customize_domain = None + site.copyright = None + site.privacy_policy = None + site.custom_disclaimer = "" + site.customize_token_strategy = "not_allow" + site.prompt_public = False + site.show_workflow_steps = True + site.use_icon_as_answer_icon = False monkeypatch.setattr( site_module.db, "session", @@ -328,13 +373,29 @@ class TestSiteEndpoints: with app.test_request_context("/", json={"title": "My Site"}): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is site + assert isinstance(result, dict) + assert result["title"] == "My Site" - def test_app_site_access_token_reset(self, app, monkeypatch): + def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSiteAccessTokenReset() method = _unwrap(api.post) site = MagicMock() + site.app_id = "app-1" + site.code = "old-code" + site.title = "My Site" + site.icon = None + site.icon_background = None + site.description = None + site.default_language = "en-US" + site.customize_domain = None + site.copyright = None + site.privacy_policy = None + site.custom_disclaimer = "" + site.customize_token_strategy = 
"not_allow" + site.prompt_public = False + site.show_workflow_steps = True + site.use_icon_as_answer_icon = False monkeypatch.setattr( site_module.db, "session", @@ -351,7 +412,8 @@ class TestSiteEndpoints: with app.test_request_context("/"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is site + assert isinstance(result, dict) + assert result["access_token"] == "code" class TestWorkflowEndpoints: @@ -366,7 +428,7 @@ class TestWorkflowEndpoints: class TestWorkflowAppLogEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_app_log_query(self): @@ -377,7 +439,7 @@ class TestWorkflowAppLogEndpoints: query = WorkflowAppLogQuery(detail="true") assert query.detail is True - def test_workflow_app_log_api_get(self, app, monkeypatch): + def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_app_log_module.WorkflowAppLogApi() method = _unwrap(api.get) @@ -400,7 +462,7 @@ class TestWorkflowAppLogEndpoints: monkeypatch.setattr(workflow_app_log_module, "sessionmaker", DummySessionMaker) def fake_get_paginate(self, **_kwargs): - return {"items": [], "total": 0} + return {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} monkeypatch.setattr( workflow_app_log_module.WorkflowAppService, @@ -411,19 +473,19 @@ class TestWorkflowAppLogEndpoints: with app.test_request_context("/?page=1&limit=20"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result == {"items": [], "total": 0} + assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} class TestWorkflowDraftVariableEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_variable_creation(self): payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test") assert payload.name == "var1" - def test_workflow_variable_collection_get(self, app, monkeypatch): + def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_draft_variable_module.WorkflowVariableCollectionApi() method = _unwrap(api.get) @@ -468,7 +530,7 @@ class TestWorkflowDraftVariableEndpoints: class TestWorkflowStatisticEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_statistic_time_range(self): @@ -480,7 +542,7 @@ class TestWorkflowStatisticEndpoints: assert query.start is None assert query.end is None - def test_workflow_daily_runs_statistic(self, app, monkeypatch): + def test_workflow_daily_runs_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr( workflow_statistic_module.DifyAPIRepositoryFactory, @@ -506,7 +568,7 @@ class TestWorkflowStatisticEndpoints: assert response.get_json() == {"data": [{"date": "2024-01-01"}]} - def test_workflow_daily_terminals_statistic(self, app, monkeypatch): + def test_workflow_daily_terminals_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr( workflow_statistic_module.DifyAPIRepositoryFactory, @@ -537,7 +599,7 @@ class TestWorkflowStatisticEndpoints: class TestWorkflowTriggerEndpoints: 
@pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_webhook_trigger_payload(self): @@ -547,7 +609,7 @@ class TestWorkflowTriggerEndpoints: enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True) assert enable_payload.enable_trigger is True - def test_webhook_trigger_api_get(self, app, monkeypatch): + def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_trigger_module.WebhookTriggerApi() method = _unwrap(api.get) @@ -576,7 +638,8 @@ class TestWorkflowTriggerEndpoints: with app.test_request_context("/?node_id=node-1"): result = method(app_model=SimpleNamespace(id="app-1")) - assert result is trigger + assert isinstance(result, dict) + assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys()) class TestWrapsEndpoints: diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index d8c6821f8d..bcb6e41ef7 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from flask import Flask from controllers.console.app import app_import as app_import_module from services.app_dsl_service import ImportStatus @@ -36,10 +37,10 @@ def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: class TestAppImportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -57,7 +58,7 @@ class TestAppImportApi: assert status == 400 assert response["status"] == ImportStatus.FAILED - def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -75,7 +76,7 @@ class TestAppImportApi: assert status == 202 assert response["status"] == ImportStatus.PENDING - def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -96,13 +97,63 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED + def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), 
"t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.commit.assert_called_once_with() + fake_session.rollback.assert_not_called() + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.rollback.assert_called_once_with() + fake_session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + class TestAppImportConfirmApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportConfirmApi() method = _unwrap(api.post) @@ -122,10 +173,10 @@ class TestAppImportConfirmApi: class TestAppImportCheckDependenciesApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportCheckDependenciesApi() method = _unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 5cc458fe2e..5a22f81a69 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -4,15 +4,15 @@ import json import uuid from flask.testing import FlaskClient -from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin -from models.account 
import AccountStatus, TenantAccountRole +from models.account import AccountStatus, TenantAccountRole, TenantStatus from models.enums import ConversationFromSource, CreatorUserRole from models.model import App, AppMode, Conversation, Message from models.workflow import WorkflowRun @@ -30,7 +30,7 @@ def _create_account_and_tenant(db_session: Session) -> tuple[Account, Tenant]: db_session.add(account) db_session.commit() - tenant = Tenant(name="Test Tenant", status="normal") + tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL) db_session.add(tenant) db_session.commit() diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py new file mode 100644 index 0000000000..fad0b8b10e --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_conversation_read_timestamp.py @@ -0,0 +1,73 @@ +from datetime import datetime +from unittest.mock import patch + +import pytest +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from controllers.console.app.conversation import _get_conversation +from models.enums import ConversationFromSource +from models.model import AppMode, Conversation +from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + create_console_app, +) + + +def test_get_conversation_mark_read_keeps_updated_at_unchanged( + db_session_with_containers: Session, +): + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + original_updated_at = datetime(2026, 2, 8, 0, 0, 0) + conversation = Conversation( + app_id=app.id, + name="read timestamp test", + inputs={}, + status="normal", + mode=AppMode.CHAT, + from_source=ConversationFromSource.CONSOLE, + from_account_id=account.id, + updated_at=original_updated_at, + ) + db_session_with_containers.add(conversation) + db_session_with_containers.commit() + + read_at = datetime(2026, 2, 9, 0, 0, 0) + + with ( + patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, tenant.id), + autospec=True, + ), + patch( + "controllers.console.app.conversation.naive_utc_now", + return_value=read_at, + autospec=True, + ), + ): + loaded = _get_conversation(app, conversation.id) + + db_session_with_containers.refresh(conversation) + + assert loaded.id == conversation.id + assert conversation.read_at == read_at + assert conversation.read_account_id == account.id + assert conversation.updated_at == original_updated_at + + +def test_get_conversation_raises_not_found_for_missing_conversation( + db_session_with_containers: Session, +): + account, tenant = create_console_account_and_tenant(db_session_with_containers) + app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.CHAT) + + with patch( + "controllers.console.app.conversation.current_account_with_tenant", + return_value=(account, tenant.id), + autospec=True, + ): + with pytest.raises(NotFound): + _get_conversation(app, "00000000-0000-0000-0000-000000000000") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py index 6b51ec98bc..eff6dd789d 100644 --- 
a/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_message.py @@ -148,14 +148,18 @@ def test_chat_message_list_success( account.id, created_at_offset_seconds=1, ) + # Capture IDs before the HTTP request detaches ORM instances from the session + app_id = app.id + conversation_id = conversation.id + second_id = second.id with patch( "controllers.console.app.message.attach_message_extra_contents", side_effect=_attach_message_extra_contents, ): response = test_client_with_containers.get( - f"/console/api/apps/{app.id}/chat-messages", - query_string={"conversation_id": conversation.id, "limit": 1}, + f"/console/api/apps/{app_id}/chat-messages", + query_string={"conversation_id": conversation_id, "limit": 1}, headers=authenticate_console_client(test_client_with_containers, account), ) @@ -165,7 +169,7 @@ def test_chat_message_list_success( assert payload["limit"] == 1 assert payload["has_more"] is True assert len(payload["data"]) == 1 - assert payload["data"][0]["id"] == second.id + assert payload["data"][0]["id"] == second_id def test_message_feedback_not_found( diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py index 8ddf867370..290be87697 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -3,12 +3,12 @@ import uuid from flask.testing import FlaskClient -from graphon.variables.segments import StringSegment from sqlalchemy import select from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from factories.variable_factory import segment_to_variable +from graphon.variables.segments import StringSegment from models import Workflow from models.model import AppMode from models.workflow import WorkflowDraftVariable diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 879c337319..1fcce9ca44 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -16,7 +17,7 @@ from services.account_service import AccountService @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): return flask_app_with_containers @@ -33,7 +34,7 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - app, + app: Flask, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False @@ -75,7 +76,7 @@ class TestEmailRegisterCheckApi: mock_revoke, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_rate_limit_check.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"} @@ -120,7 +121,7 @@ class TestEmailRegisterResetApi: mock_create_account, mock_login, 
mock_reset_login_rate, - app, + app: Flask, ): mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} mock_create_account.return_value = MagicMock() @@ -158,7 +159,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 7b7393dade..014c1588fe 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -16,7 +17,7 @@ from services.account_service import AccountService @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): return flask_app_with_containers @@ -31,7 +32,7 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - app, + app: Flask, ): mock_account = MagicMock() mock_get_account.return_value = mock_account @@ -80,7 +81,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_rate_limit_check.return_value = False mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"} @@ -113,19 +114,22 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") def test_reset_fetches_account_with_original_email( self, mock_get_reset_data, mock_revoke_token, + mock_db, mock_get_account, mock_update_account, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account wraps_features = SimpleNamespace(enable_email_password_login=True) with ( @@ -161,7 +165,10 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + with 
patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a2f1328579..55b6a919d8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -21,7 +22,7 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.mark.parametrize( @@ -65,7 +66,7 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -89,7 +90,7 @@ class TestOAuthLogin: mock_redirect, mock_get_providers, resource, - app, + app: Flask, mock_oauth_provider, invite_token, expected_token, @@ -130,7 +131,7 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -164,7 +165,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_config.CONSOLE_WEB_URL = "http://localhost:3000" @@ -217,7 +218,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_config.CONSOLE_WEB_URL = "http://localhost:3000" @@ -261,7 +262,7 @@ class TestOAuthCallback: mock_tenant_service, mock_account_service, resource, - app, + app: Flask, oauth_setup, account_status, expected_redirect, @@ -300,7 +301,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_get_providers.return_value = {"github": oauth_setup["provider"]} @@ -336,7 +337,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): """Defensive test for CLOSED account status handling in OAuth callback. 
@@ -394,7 +395,7 @@ class TestOAuthCallback: class TestAccountGeneration: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -437,7 +438,10 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + with patch("services.account_service.session_factory") as mock_factory: + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 @@ -462,7 +466,7 @@ class TestAccountGeneration: mock_register_service, mock_feature_service, mock_get_account, - app, + app: Flask, user_info, mock_account, allow_register, @@ -501,7 +505,7 @@ class TestAccountGeneration: mock_register_service, mock_feature_service, mock_get_account, - app, + app: Flask, ): user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com") mock_feature_service.get_system_features.return_value.is_allow_register = True @@ -526,7 +530,7 @@ class TestAccountGeneration: mock_feature_service, mock_tenant_service, mock_get_account, - app, + app: Flask, user_info, mock_account, ): diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8f9db287e3..d017e8f2bd 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -25,7 +26,7 @@ class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -46,7 +47,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - app, + app: Flask, mock_account, ): # Arrange @@ -68,7 +69,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email.assert_called_once() @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app: Flask): """ Test password reset email blocked by IP rate limit. 
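Most of the remaining hunks in these controller tests are mechanical type annotations on fixtures and parameters. Reduced to a sketch (the `flask_app_with_containers` fixture comes from the containerized conftest shown earlier; the return annotation is an addition for illustration):

```python
import pytest
from flask import Flask


@pytest.fixture
def app(flask_app_with_containers: Flask) -> Flask:
    return flask_app_with_containers


def test_example(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
    # Annotations buy IDE and type-checker support; runtime behavior is unchanged.
    assert isinstance(app, Flask)
```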
@@ -104,7 +105,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - app, + app: Flask, mock_account, language_input, expected_language, @@ -138,7 +139,7 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @@ -153,7 +154,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - app, + app: Flask, ): """ Test successful verification code validation. @@ -200,7 +201,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} @@ -221,7 +222,7 @@ class TestForgotPasswordCheckApi: mock_reset_rate_limit.assert_called_once_with("user@example.com") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app: Flask): """ Test code verification blocked by rate limit. @@ -244,7 +245,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with invalid token. @@ -267,7 +268,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with mismatched email. @@ -292,7 +293,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with incorrect code. 
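These tests invoke view methods through `_unwrap(...)` to strip the auth/setup decorators before calling them directly. The helper's body is not shown in this diff; a common implementation, offered here only as an assumption, walks the `__wrapped__` chain that `functools.wraps` records:

```python
# Assumed implementation of the _unwrap helper used throughout these tests;
# the actual body lives in the test modules and is not shown in the diff.
def _unwrap(func):
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func
```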
@@ -321,7 +322,7 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -335,14 +336,16 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") + @patch("controllers.console.auth.forgot_password.db") @patch("controllers.console.auth.forgot_password.TenantService.get_join_tenants") def test_reset_password_success( self, mock_get_tenants, + mock_db, mock_get_account, mock_revoke_token, mock_get_data, - app, + app: Flask, mock_account, ): """ @@ -356,6 +359,7 @@ class TestForgotPasswordResetApi: # Arrange mock_get_data.return_value = {"email": "test@example.com", "phase": "reset"} mock_get_account.return_value = mock_account + mock_db.session.merge.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] # Act @@ -372,7 +376,7 @@ class TestForgotPasswordResetApi: mock_revoke_token.assert_called_once_with("valid_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, app): + def test_reset_password_mismatch(self, mock_get_data, app: Flask): """ Test password reset with mismatched passwords. @@ -394,7 +398,7 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, app): + def test_reset_password_invalid_token(self, mock_get_data, app: Flask): """ Test password reset with invalid token. @@ -415,7 +419,7 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, app): + def test_reset_password_wrong_phase(self, mock_get_data, app: Flask): """ Test password reset with token not in reset phase. @@ -439,7 +443,7 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app: Flask): """ Test password reset for non-existent account. 
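The new `@patch("controllers.console.auth.forgot_password.db")` plus `mock_db.session.merge.return_value = mock_account` lines suggest the reset handler now merges the looked-up account into the active session before mutating it; that reading is an inference from the test wiring, not documented behavior. A minimal sketch of the arrangement:

```python
# Sketch of the db patch added to the reset-password tests. The merge()
# routing mirrors the diff; the handler invocation itself is elided.
from unittest.mock import MagicMock, patch

from controllers.console.auth import forgot_password as forgot_password_module


def with_patched_db(mock_account: MagicMock):
    with patch.object(forgot_password_module, "db") as mock_db:
        mock_db.session.merge.return_value = mock_account
        # ...invoke the unwrapped ForgotPasswordResetApi.post() here...
```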
diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index d5ae95dfb7..7aa4aff1cc 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from controllers.console import console_ns @@ -26,10 +27,10 @@ def unwrap(func): class TestPipelineTemplateListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = PipelineTemplateListApi() method = unwrap(api.get) @@ -50,10 +51,10 @@ class TestPipelineTemplateListApi: class TestPipelineTemplateDetailApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -74,7 +75,7 @@ class TestPipelineTemplateDetailApi: assert status == 200 assert response == template - def test_get_returns_404_when_template_not_found(self, app): + def test_get_returns_404_when_template_not_found(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -93,7 +94,7 @@ class TestPipelineTemplateDetailApi: assert status == 404 assert "error" in response - def test_get_returns_404_for_customized_type_not_found(self, app): + def test_get_returns_404_for_customized_type_not_found(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -115,10 +116,10 @@ class TestPipelineTemplateDetailApi: class TestCustomizedPipelineTemplateApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.patch) @@ -140,7 +141,7 @@ class TestCustomizedPipelineTemplateApi: update_mock.assert_called_once() assert response == 200 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.delete) @@ -155,7 +156,7 @@ class TestCustomizedPipelineTemplateApi: delete_mock.assert_called_once_with("tpl-1") assert response == 200 - def test_post_success(self, app, db_session_with_containers: Session): + def test_post_success(self, app: Flask, db_session_with_containers: Session): api = CustomizedPipelineTemplateApi() method = unwrap(api.post) @@ -182,7 +183,7 @@ class TestCustomizedPipelineTemplateApi: assert status == 200 assert response == {"data": "yaml-data"} - def test_post_template_not_found(self, app): + def test_post_template_not_found(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.post) @@ -193,10 +194,10 @@ class TestCustomizedPipelineTemplateApi: class TestPublishCustomizedPipelineTemplateApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return 
flask_app_with_containers - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = PublishCustomizedPipelineTemplateApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py index 64e3de2ca3..7624c1150f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden import services @@ -24,13 +25,13 @@ def unwrap(func): class TestCreateRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _valid_payload(self): return {"yaml_content": "name: test"} - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -58,7 +59,7 @@ class TestCreateRagPipelineDatasetApi: assert status == 201 assert response == import_info - def test_post_forbidden_non_editor(self, app): + def test_post_forbidden_non_editor(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -76,7 +77,7 @@ class TestCreateRagPipelineDatasetApi: with pytest.raises(Forbidden): method(api) - def test_post_dataset_name_duplicate(self, app): + def test_post_dataset_name_duplicate(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -101,7 +102,7 @@ class TestCreateRagPipelineDatasetApi: with pytest.raises(DatasetNameDuplicateError): method(api) - def test_post_invalid_payload(self, app): + def test_post_invalid_payload(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -122,10 +123,10 @@ class TestCreateRagPipelineDatasetApi: class TestCreateEmptyRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) @@ -152,7 +153,7 @@ class TestCreateEmptyRagPipelineDatasetApi: assert status == 201 assert response == {"id": "ds-1"} - def test_post_forbidden_non_editor(self, app): + def test_post_forbidden_non_editor(self, app: Flask): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index cb67892878..44eb5c336c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console import console_ns from 
controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( @@ -25,7 +26,7 @@ def unwrap(func): class TestRagPipelineImportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _payload(self, mode="create"): @@ -35,7 +36,7 @@ class TestRagPipelineImportApi: "name": "Test", } - def test_post_success_200(self, app): + def test_post_success_200(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -65,7 +66,7 @@ class TestRagPipelineImportApi: assert status == 200 assert response == {"status": "success"} - def test_post_failed_400(self, app): + def test_post_failed_400(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -95,7 +96,7 @@ class TestRagPipelineImportApi: assert status == 400 assert response == {"status": "failed"} - def test_post_pending_202(self, app): + def test_post_pending_202(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -128,10 +129,10 @@ class TestRagPipelineImportApi: class TestRagPipelineImportConfirmApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_confirm_success(self, app): + def test_confirm_success(self, app: Flask): api = RagPipelineImportConfirmApi() method = unwrap(api.post) @@ -159,7 +160,7 @@ class TestRagPipelineImportConfirmApi: assert status == 200 assert response == {"ok": True} - def test_confirm_failed(self, app): + def test_confirm_failed(self, app: Flask): api = RagPipelineImportConfirmApi() method = unwrap(api.post) @@ -190,10 +191,10 @@ class TestRagPipelineImportConfirmApi: class TestRagPipelineImportCheckDependenciesApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = RagPipelineImportCheckDependenciesApi() method = unwrap(api.get) @@ -219,10 +220,10 @@ class TestRagPipelineImportCheckDependenciesApi: class TestRagPipelineExportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_with_include_secret(self, app): + def test_get_with_include_secret(self, app: Flask): api = RagPipelineExportApi() method = unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index c1f3122c2b..c17a83cad3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound @@ -45,10 +46,10 @@ def unwrap(func): class TestDraftWorkflowApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_draft_success(self, app): + def test_get_draft_success(self, app: Flask): api = 
DraftRagPipelineApi() method = unwrap(api.get) @@ -68,7 +69,7 @@ class TestDraftWorkflowApi: result = method(api, pipeline) assert result == workflow - def test_get_draft_not_exist(self, app): + def test_get_draft_not_exist(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -86,7 +87,7 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotExist): method(api, pipeline) - def test_sync_hash_not_match(self, app): + def test_sync_hash_not_match(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.post) @@ -111,7 +112,7 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotSync): method(api, pipeline) - def test_sync_invalid_text_plain(self, app): + def test_sync_invalid_text_plain(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.post) @@ -128,7 +129,7 @@ class TestDraftWorkflowApi: response, status = method(api, pipeline) assert status == 400 - def test_restore_published_workflow_to_draft_success(self, app): + def test_restore_published_workflow_to_draft_success(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -155,7 +156,7 @@ class TestDraftWorkflowApi: assert result["result"] == "success" assert result["hash"] == "restored-hash" - def test_restore_published_workflow_to_draft_not_found(self, app): + def test_restore_published_workflow_to_draft_not_found(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -179,7 +180,7 @@ class TestDraftWorkflowApi: with pytest.raises(NotFound): method(api, pipeline, "published-workflow") - def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app): + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -211,10 +212,10 @@ class TestDraftWorkflowApi: class TestDraftRunNodes: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_iteration_node_success(self, app): + def test_iteration_node_success(self, app: Flask): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -240,7 +241,7 @@ class TestDraftRunNodes: result = method(api, pipeline, "node") assert result == {"ok": True} - def test_iteration_node_conversation_not_exists(self, app): + def test_iteration_node_conversation_not_exists(self, app: Flask): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -262,7 +263,7 @@ class TestDraftRunNodes: with pytest.raises(NotFound): method(api, pipeline, "node") - def test_loop_node_success(self, app): + def test_loop_node_success(self, app: Flask): api = RagPipelineDraftRunLoopNodeApi() method = unwrap(api.post) @@ -290,10 +291,10 @@ class TestDraftRunNodes: class TestPipelineRunApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_draft_run_success(self, app): + def test_draft_run_success(self, app: Flask): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -325,7 +326,7 @@ class TestPipelineRunApis: ): assert method(api, pipeline) == {"ok": True} - def test_draft_run_rate_limit(self, app): + def test_draft_run_rate_limit(self, app: Flask): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -356,10 +357,10 @@ class TestPipelineRunApis: class TestDraftNodeRun: @pytest.fixture - def app(self, 
flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_execution_not_found(self, app): + def test_execution_not_found(self, app: Flask): api = RagPipelineDraftNodeRunApi() method = unwrap(api.post) @@ -387,10 +388,10 @@ class TestDraftNodeRun: class TestPublishedPipelineApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_publish_success(self, app, db_session_with_containers: Session): + def test_publish_success(self, app: Flask, db_session_with_containers: Session): from models.dataset import Pipeline api = PublishedRagPipelineApi() @@ -436,10 +437,10 @@ class TestPublishedPipelineApis: class TestMiscApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_task_stop(self, app): + def test_task_stop(self, app: Flask): api = RagPipelineTaskStopApi() method = unwrap(api.post) @@ -460,7 +461,7 @@ class TestMiscApis: stop_mock.assert_called_once() assert result["result"] == "success" - def test_transform_forbidden(self, app): + def test_transform_forbidden(self, app: Flask): api = RagPipelineTransformApi() method = unwrap(api.post) @@ -476,7 +477,7 @@ class TestMiscApis: with pytest.raises(Forbidden): method(api, "ds1") - def test_recommended_plugins(self, app): + def test_recommended_plugins(self, app: Flask): api = RagPipelineRecommendedPluginApi() method = unwrap(api.get) @@ -496,10 +497,10 @@ class TestMiscApis: class TestPublishedRagPipelineRunApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_published_run_success(self, app): + def test_published_run_success(self, app: Flask): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -533,7 +534,7 @@ class TestPublishedRagPipelineRunApi: result = method(api, pipeline) assert result == {"ok": True} - def test_published_run_rate_limit(self, app): + def test_published_run_rate_limit(self, app: Flask): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -565,10 +566,10 @@ class TestPublishedRagPipelineRunApi: class TestDefaultBlockConfigApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_block_config_success(self, app): + def test_get_block_config_success(self, app: Flask): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -587,7 +588,7 @@ class TestDefaultBlockConfigApi: result = method(api, pipeline, "llm") assert result == {"k": "v"} - def test_get_block_config_invalid_json(self, app): + def test_get_block_config_invalid_json(self, app: Flask): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -600,10 +601,10 @@ class TestDefaultBlockConfigApi: class TestPublishedAllRagPipelineApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_published_workflows_success(self, app): + def test_get_published_workflows_success(self, app: Flask): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -629,7 +630,7 @@ class TestPublishedAllRagPipelineApi: assert result["items"] == [{"id": "w1"}] assert result["has_more"] is False - def test_get_published_workflows_forbidden(self, app): + def 
test_get_published_workflows_forbidden(self, app: Flask): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -649,10 +650,10 @@ class TestPublishedAllRagPipelineApi: class TestRagPipelineByIdApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -682,7 +683,7 @@ class TestRagPipelineByIdApi: assert result == workflow - def test_patch_no_fields(self, app): + def test_patch_no_fields(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -700,7 +701,7 @@ class TestRagPipelineByIdApi: result, status = method(api, pipeline, "w1") assert status == 400 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.delete) @@ -720,7 +721,7 @@ class TestRagPipelineByIdApi: workflow_service.delete_workflow.assert_called_once() assert result == (None, 204) - def test_delete_active_workflow_rejected(self, app): + def test_delete_active_workflow_rejected(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.delete) @@ -733,10 +734,10 @@ class TestRagPipelineByIdApi: class TestRagPipelineWorkflowLastRunApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_last_run_success(self, app): + def test_last_run_success(self, app: Flask): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -758,7 +759,7 @@ class TestRagPipelineWorkflowLastRunApi: result = method(api, pipeline, "node1") assert result == node_exec - def test_last_run_not_found(self, app): + def test_last_run_not_found(self, app: Flask): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -780,10 +781,10 @@ class TestRagPipelineWorkflowLastRunApi: class TestRagPipelineDatasourceVariableApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_set_datasource_variables_success(self, app): + def test_set_datasource_variables_success(self, app: Flask): api = RagPipelineDatasourceVariableApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 1c4c6a899f..b59009f7c4 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.console.datasets import data_source @@ -51,10 +52,10 @@ def mock_engine(): class TestDataSourceApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceApi() method = unwrap(api.get) @@ -78,7 +79,7 @@ class TestDataSourceApi: assert status == 200 assert response["data"][0]["is_bound"] is True - def test_get_no_bindings(self, app, 
patch_tenant): + def test_get_no_bindings(self, app: Flask, patch_tenant): api = DataSourceApi() method = unwrap(api.get) @@ -94,7 +95,7 @@ class TestDataSourceApi: assert status == 200 assert response["data"] == [] - def test_patch_enable_binding(self, app, patch_tenant, mock_engine): + def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -115,7 +116,7 @@ class TestDataSourceApi: assert status == 200 assert binding.disabled is False - def test_patch_disable_binding(self, app, patch_tenant, mock_engine): + def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -136,7 +137,7 @@ class TestDataSourceApi: assert status == 200 assert binding.disabled is True - def test_patch_binding_not_found(self, app, patch_tenant, mock_engine): + def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -151,7 +152,7 @@ class TestDataSourceApi: with pytest.raises(NotFound): method(api, "b1", "enable") - def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine): + def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -168,7 +169,7 @@ class TestDataSourceApi: with pytest.raises(ValueError): method(api, "b1", "enable") - def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine): + def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -188,10 +189,10 @@ class TestDataSourceApi: class TestDataSourceNotionListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_credential_not_found(self, app, patch_tenant): + def test_get_credential_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -205,7 +206,7 @@ class TestDataSourceNotionListApi: with pytest.raises(NotFound): method(api) - def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine): + def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -246,7 +247,7 @@ class TestDataSourceNotionListApi: assert status == 200 - def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine): + def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -299,7 +300,7 @@ class TestDataSourceNotionListApi: assert status == 200 - def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine): + def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -323,10 +324,10 @@ class TestDataSourceNotionListApi: class TestDataSourceNotionApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_preview_success(self, app, patch_tenant): + def test_get_preview_success(self, app: Flask, patch_tenant): api = DataSourceNotionApi() method = unwrap(api.get) @@ -347,7 +348,7 @@ class TestDataSourceNotionApi: assert status == 200 - def test_post_indexing_estimate_success(self, app, patch_tenant): + def 
test_post_indexing_estimate_success(self, app: Flask, patch_tenant): api = DataSourceNotionApi() method = unwrap(api.post) @@ -381,10 +382,10 @@ class TestDataSourceNotionApi: class TestDataSourceNotionDatasetSyncApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceNotionDatasetSyncApi() method = unwrap(api.get) @@ -407,7 +408,7 @@ class TestDataSourceNotionDatasetSyncApi: assert status == 200 - def test_get_dataset_not_found(self, app, patch_tenant): + def test_get_dataset_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionDatasetSyncApi() method = unwrap(api.get) @@ -424,10 +425,10 @@ class TestDataSourceNotionDatasetSyncApi: class TestDataSourceNotionDocumentSyncApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceNotionDocumentSyncApi() method = unwrap(api.get) @@ -450,7 +451,7 @@ class TestDataSourceNotionDocumentSyncApi: assert status == 200 - def test_get_document_not_found(self, app, patch_tenant): + def test_get_document_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionDocumentSyncApi() method = unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index 83492048ef..917aa35fe6 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound import controllers.console.explore.conversation as conversation_module @@ -53,10 +54,10 @@ def user(): class TestConversationListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, chat_app, user): + def test_get_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -81,7 +82,7 @@ class TestConversationListApi: assert result["has_more"] is False assert len(result["data"]) == 2 - def test_last_conversation_not_exists(self, app, chat_app, user): + def test_last_conversation_not_exists(self, app: Flask, chat_app, user): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -97,7 +98,7 @@ class TestConversationListApi: with pytest.raises(NotFound): method(chat_app) - def test_wrong_app_mode(self, app, non_chat_app): + def test_wrong_app_mode(self, app: Flask, non_chat_app): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -108,10 +109,10 @@ class TestConversationListApi: class TestConversationApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_delete_success(self, app, chat_app, user): + def test_delete_success(self, app: Flask, chat_app, user): api = 
conversation_module.ConversationApi() method = unwrap(api.delete) @@ -129,7 +130,7 @@ class TestConversationApi: assert status == 204 assert body["result"] == "success" - def test_delete_not_found(self, app, chat_app, user): + def test_delete_not_found(self, app: Flask, chat_app, user): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -145,7 +146,7 @@ class TestConversationApi: with pytest.raises(NotFound): method(chat_app, "cid") - def test_delete_wrong_app_mode(self, app, non_chat_app): + def test_delete_wrong_app_mode(self, app: Flask, non_chat_app): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -156,10 +157,10 @@ class TestConversationApi: class TestConversationRenameApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_rename_success(self, app, chat_app, user): + def test_rename_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationRenameApi() method = unwrap(api.post) @@ -178,7 +179,7 @@ class TestConversationRenameApi: assert result["id"] == "cid" - def test_rename_not_found(self, app, chat_app, user): + def test_rename_not_found(self, app: Flask, chat_app, user): api = conversation_module.ConversationRenameApi() method = unwrap(api.post) @@ -197,10 +198,10 @@ class TestConversationRenameApi: class TestConversationPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_pin_success(self, app, chat_app, user): + def test_pin_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationPinApi() method = unwrap(api.patch) @@ -219,10 +220,10 @@ class TestConversationPinApi: class TestConversationUnPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_unpin_success(self, app, chat_app, user): + def test_unpin_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationUnPinApi() method = unwrap(api.patch) diff --git a/api/tests/test_containers_integration_tests/controllers/console/helpers.py b/api/tests/test_containers_integration_tests/controllers/console/helpers.py index 9e2084f393..a8ecf94da1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/helpers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/helpers.py @@ -11,7 +11,7 @@ from constants import HEADER_NAME_CSRF_TOKEN from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin -from models.account import AccountStatus, TenantAccountRole +from models.account import AccountStatus, TenantAccountRole, TenantStatus from models.model import App, AppMode from services.account_service import AccountService @@ -37,7 +37,7 @@ def create_console_account_and_tenant(db_session: Session) -> tuple[Account, Ten db_session.add(account) db_session.commit() - tenant = Tenant(name="Test Tenant", status="normal") + tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL) db_session.add(tenant) db_session.commit() diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py index f2e7104b18..d944613886 
100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py @@ -6,6 +6,7 @@ import json from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden from controllers.console.workspace.tool_providers import ( @@ -60,7 +61,7 @@ def _mock_user_tenant(): @pytest.fixture -def client(flask_app_with_containers): +def client(flask_app_with_containers: Flask): return flask_app_with_containers.test_client() @@ -147,10 +148,10 @@ class TestUtils: class TestToolProviderListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ToolProviderListApi() method = unwrap(api.get) @@ -170,10 +171,10 @@ class TestToolProviderListApi: class TestBuiltinProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_list_tools(self, app): + def test_list_tools(self, app: Flask): api = ToolBuiltinProviderListToolsApi() method = unwrap(api.get) @@ -190,7 +191,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == [{"a": 1}] - def test_info(self, app): + def test_info(self, app: Flask): api = ToolBuiltinProviderInfoApi() method = unwrap(api.get) @@ -207,7 +208,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"x": 1} - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolBuiltinProviderDeleteApi() method = unwrap(api.post) @@ -224,7 +225,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["result"] == "success" - def test_add_invalid_type(self, app): + def test_add_invalid_type(self, app: Flask): api = ToolBuiltinProviderAddApi() method = unwrap(api.post) @@ -238,7 +239,7 @@ class TestBuiltinProviderApis: with pytest.raises(ValueError): method(api, "provider") - def test_add_success(self, app): + def test_add_success(self, app: Flask): api = ToolBuiltinProviderAddApi() method = unwrap(api.post) @@ -257,7 +258,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["id"] == 1 - def test_update(self, app): + def test_update(self, app: Flask): api = ToolBuiltinProviderUpdateApi() method = unwrap(api.post) @@ -276,7 +277,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["ok"] - def test_get_credentials(self, app): + def test_get_credentials(self, app: Flask): api = ToolBuiltinProviderGetCredentialsApi() method = unwrap(api.get) @@ -293,7 +294,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"k": "v"} - def test_icon(self, app): + def test_icon(self, app: Flask): api = ToolBuiltinProviderIconApi() method = unwrap(api.get) @@ -307,7 +308,7 @@ class TestBuiltinProviderApis: response = method(api, "provider") assert response.mimetype == "image/png" - def test_credentials_schema(self, app): + def test_credentials_schema(self, app: Flask): api = ToolBuiltinProviderCredentialsSchemaApi() method = unwrap(api.get) @@ -324,7 +325,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider", "oauth2") == {"schema": {}} - def test_set_default_credential(self, app): + def test_set_default_credential(self, app: Flask): api = ToolBuiltinProviderSetDefaultApi() method = unwrap(api.post) @@ -341,7 +342,7 
@@ class TestBuiltinProviderApis: ): assert method(api, "provider")["ok"] - def test_get_credential_info(self, app): + def test_get_credential_info(self, app: Flask): api = ToolBuiltinProviderGetCredentialInfoApi() method = unwrap(api.get) @@ -358,7 +359,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"info": "x"} - def test_get_oauth_client_schema(self, app): + def test_get_oauth_client_schema(self, app: Flask): api = ToolBuiltinProviderGetOauthClientSchemaApi() method = unwrap(api.get) @@ -378,10 +379,10 @@ class TestBuiltinProviderApis: class TestApiProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_add(self, app): + def test_add(self, app: Flask): api = ToolApiProviderAddApi() method = unwrap(api.post) @@ -406,7 +407,7 @@ class TestApiProviderApis: ): assert method(api)["id"] == 1 - def test_remote_schema(self, app): + def test_remote_schema(self, app: Flask): api = ToolApiProviderGetRemoteSchemaApi() method = unwrap(api.get) @@ -423,7 +424,7 @@ class TestApiProviderApis: ): assert method(api)["schema"] == "x" - def test_list_tools(self, app): + def test_list_tools(self, app: Flask): api = ToolApiProviderListToolsApi() method = unwrap(api.get) @@ -440,7 +441,7 @@ class TestApiProviderApis: ): assert method(api) == [{"tool": 1}] - def test_update(self, app): + def test_update(self, app: Flask): api = ToolApiProviderUpdateApi() method = unwrap(api.post) @@ -468,7 +469,7 @@ class TestApiProviderApis: ): assert method(api)["ok"] - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolApiProviderDeleteApi() method = unwrap(api.post) @@ -485,7 +486,7 @@ class TestApiProviderApis: ): assert method(api)["result"] == "success" - def test_get(self, app): + def test_get(self, app: Flask): api = ToolApiProviderGetApi() method = unwrap(api.get) @@ -505,10 +506,10 @@ class TestApiProviderApis: class TestWorkflowApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_create(self, app): + def test_create(self, app: Flask): api = ToolWorkflowProviderCreateApi() method = unwrap(api.post) @@ -534,7 +535,7 @@ class TestWorkflowApis: ): assert method(api)["id"] == 1 - def test_update_invalid(self, app): + def test_update_invalid(self, app: Flask): api = ToolWorkflowProviderUpdateApi() method = unwrap(api.post) @@ -560,7 +561,7 @@ class TestWorkflowApis: result = method(api) assert result["ok"] - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolWorkflowProviderDeleteApi() method = unwrap(api.post) @@ -577,7 +578,7 @@ class TestWorkflowApis: ): assert method(api)["ok"] - def test_get_error(self, app): + def test_get_error(self, app: Flask): api = ToolWorkflowProviderGetApi() method = unwrap(api.get) @@ -594,10 +595,10 @@ class TestWorkflowApis: class TestLists: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_builtin_list(self, app): + def test_builtin_list(self, app: Flask): api = ToolBuiltinListApi() method = unwrap(api.get) @@ -617,7 +618,7 @@ class TestLists: ): assert method(api) == [{"x": 1}] - def test_api_list(self, app): + def test_api_list(self, app: Flask): api = ToolApiListApi() method = unwrap(api.get) @@ -637,7 +638,7 @@ class TestLists: ): assert method(api) == [{"x": 1}] - def 
test_workflow_list(self, app): + def test_workflow_list(self, app: Flask): api = ToolWorkflowListApi() method = unwrap(api.get) @@ -660,10 +661,10 @@ class TestLists: class TestLabels: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_labels(self, app): + def test_labels(self, app: Flask): api = ToolLabelsApi() method = unwrap(api.get) @@ -679,10 +680,10 @@ class TestLabels: class TestOAuth: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_oauth_no_client(self, app): + def test_oauth_no_client(self, app: Flask): api = ToolPluginOAuthApi() method = unwrap(api.get) @@ -700,7 +701,7 @@ class TestOAuth: with pytest.raises(Forbidden): method(api, "provider") - def test_oauth_callback_no_cookie(self, app): + def test_oauth_callback_no_cookie(self, app: Flask): api = ToolOAuthCallback() method = unwrap(api.get) @@ -711,10 +712,10 @@ class TestOAuth: class TestOAuthCustomClient: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_save_custom_client(self, app): + def test_save_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.post) @@ -731,7 +732,7 @@ class TestOAuthCustomClient: ): assert method(api, "provider")["ok"] - def test_get_custom_client(self, app): + def test_get_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.get) @@ -748,7 +749,7 @@ class TestOAuthCustomClient: ): assert method(api, "provider") == {"client_id": "x"} - def test_delete_custom_client(self, app): + def test_delete_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.delete) diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py index ca8195af53..e41adccf3c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, Forbidden from controllers.console.workspace.trigger_providers import ( @@ -45,10 +46,10 @@ def mock_user(): class TestTriggerProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_icon_success(self, app): + def test_icon_success(self, app: Flask): api = TriggerProviderIconApi() method = unwrap(api.get) @@ -62,7 +63,7 @@ class TestTriggerProviderApis: ): assert method(api, "github") == "icon" - def test_list_providers(self, app): + def test_list_providers(self, app: Flask): api = TriggerProviderListApi() method = unwrap(api.get) @@ -76,7 +77,7 @@ class TestTriggerProviderApis: ): assert method(api) == [] - def test_provider_info(self, app): + def test_provider_info(self, app: Flask): api = TriggerProviderInfoApi() method = unwrap(api.get) @@ -93,10 +94,10 @@ class TestTriggerProviderApis: class TestTriggerSubscriptionListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, 
flask_app_with_containers: Flask): return flask_app_with_containers - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = TriggerSubscriptionListApi() method = unwrap(api.get) @@ -110,7 +111,7 @@ class TestTriggerSubscriptionListApi: ): assert method(api, "github") == [] - def test_list_invalid_provider(self, app): + def test_list_invalid_provider(self, app: Flask): api = TriggerSubscriptionListApi() method = unwrap(api.get) @@ -128,10 +129,10 @@ class TestTriggerSubscriptionListApi: class TestTriggerSubscriptionBuilderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_create_builder(self, app): + def test_create_builder(self, app: Flask): api = TriggerSubscriptionBuilderCreateApi() method = unwrap(api.post) @@ -146,7 +147,7 @@ class TestTriggerSubscriptionBuilderApis: result = method(api, "github") assert "subscription_builder" in result - def test_get_builder(self, app): + def test_get_builder(self, app: Flask): api = TriggerSubscriptionBuilderGetApi() method = unwrap(api.get) @@ -159,7 +160,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"id": "b1"} - def test_verify_builder(self, app): + def test_verify_builder(self, app: Flask): api = TriggerSubscriptionBuilderVerifyApi() method = unwrap(api.post) @@ -173,7 +174,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"ok": True} - def test_verify_builder_error(self, app): + def test_verify_builder_error(self, app: Flask): api = TriggerSubscriptionBuilderVerifyApi() method = unwrap(api.post) @@ -188,7 +189,7 @@ class TestTriggerSubscriptionBuilderApis: with pytest.raises(ValueError): method(api, "github", "b1") - def test_update_builder(self, app): + def test_update_builder(self, app: Flask): api = TriggerSubscriptionBuilderUpdateApi() method = unwrap(api.post) @@ -202,7 +203,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"id": "b1"} - def test_logs(self, app): + def test_logs(self, app: Flask): api = TriggerSubscriptionBuilderLogsApi() method = unwrap(api.get) @@ -219,7 +220,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert "logs" in method(api, "github", "b1") - def test_build(self, app): + def test_build(self, app: Flask): api = TriggerSubscriptionBuilderBuildApi() method = unwrap(api.post) @@ -236,10 +237,10 @@ class TestTriggerSubscriptionBuilderApis: class TestTriggerSubscriptionCrud: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_update_rename_only(self, app): + def test_update_rename_only(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -258,7 +259,7 @@ class TestTriggerSubscriptionCrud: ): assert method(api, "s1") == 200 - def test_update_not_found(self, app): + def test_update_not_found(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -273,7 +274,7 @@ class TestTriggerSubscriptionCrud: with pytest.raises(NotFoundError): method(api, "x") - def test_update_rebuild(self, app): + def test_update_rebuild(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -296,7 +297,7 @@ class TestTriggerSubscriptionCrud: ): assert method(api, "s1") == 200 - def test_delete_subscription(self, app): + def test_delete_subscription(self, app: Flask): api = 
TriggerSubscriptionDeleteApi() method = unwrap(api.post) @@ -319,7 +320,7 @@ class TestTriggerSubscriptionCrud: assert result["result"] == "success" - def test_delete_subscription_value_error(self, app): + def test_delete_subscription_value_error(self, app: Flask): api = TriggerSubscriptionDeleteApi() method = unwrap(api.post) @@ -342,10 +343,10 @@ class TestTriggerSubscriptionCrud: class TestTriggerOAuthApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_oauth_authorize_success(self, app): + def test_oauth_authorize_success(self, app: Flask): api = TriggerOAuthAuthorizeApi() method = unwrap(api.get) @@ -372,7 +373,7 @@ class TestTriggerOAuthApis: resp = method(api, "github") assert resp.status_code == 200 - def test_oauth_authorize_no_client(self, app): + def test_oauth_authorize_no_client(self, app: Flask): api = TriggerOAuthAuthorizeApi() method = unwrap(api.get) @@ -387,7 +388,7 @@ class TestTriggerOAuthApis: with pytest.raises(NotFoundError): method(api, "github") - def test_oauth_callback_forbidden(self, app): + def test_oauth_callback_forbidden(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -395,7 +396,7 @@ class TestTriggerOAuthApis: with pytest.raises(Forbidden): method(api, "github") - def test_oauth_callback_success(self, app): + def test_oauth_callback_success(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -425,7 +426,7 @@ class TestTriggerOAuthApis: resp = method(api, "github") assert resp.status_code == 302 - def test_oauth_callback_no_oauth_client(self, app): + def test_oauth_callback_no_oauth_client(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -449,7 +450,7 @@ class TestTriggerOAuthApis: with pytest.raises(Forbidden): method(api, "github") - def test_oauth_callback_empty_credentials(self, app): + def test_oauth_callback_empty_credentials(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -480,10 +481,10 @@ class TestTriggerOAuthApis: class TestTriggerOAuthClientManageApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_client(self, app): + def test_get_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.get) @@ -510,7 +511,7 @@ class TestTriggerOAuthClientManageApi: result = method(api, "github") assert "configured" in result - def test_post_client(self, app): + def test_post_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.post) @@ -524,7 +525,7 @@ class TestTriggerOAuthClientManageApi: ): assert method(api, "github") == {"ok": True} - def test_delete_client(self, app): + def test_delete_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.delete) @@ -538,7 +539,7 @@ class TestTriggerOAuthClientManageApi: ): assert method(api, "github") == {"ok": True} - def test_oauth_client_post_value_error(self, app): + def test_oauth_client_post_value_error(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.post) @@ -556,10 +557,10 @@ class TestTriggerOAuthClientManageApi: class TestTriggerSubscriptionVerifyApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_verify_success(self, app): + def test_verify_success(self, app: 
Flask): api = TriggerSubscriptionVerifyApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 9b913d6d3d..b73d28e4c4 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -18,6 +18,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound @@ -217,10 +218,20 @@ class TestTagUnbindingPayload: """Test suite for TagUnbindingPayload Pydantic model.""" def test_payload_with_valid_data(self): - payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456") - assert payload.tag_id == "tag_123" + payload = TagUnbindingPayload(tag_ids=["tag_123"], target_id="dataset_456") + assert payload.tag_ids == ["tag_123"] assert payload.target_id == "dataset_456" + def test_payload_normalizes_legacy_tag_id(self): + payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456") + assert payload.tag_ids == ["tag_123"] + assert payload.target_id == "dataset_456" + + def test_payload_rejects_empty_tag_ids(self): + with pytest.raises(ValueError) as exc_info: + TagUnbindingPayload(tag_ids=[], target_id="dataset_456") + assert "Tag IDs is required" in str(exc_info.value) + # --------------------------------------------------------------------------- # Helpers @@ -236,7 +247,7 @@ def _unwrap(method): @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): # Uses the full containerised app so that Flask config, extensions, and # blueprint registrations match production. Most tests mock the service # layer to isolate controller logic; a few (e.g. 
test_list_tags_from_db) @@ -280,7 +291,7 @@ class TestDatasetListApiGet: mock_current_user, mock_provider_mgr, mock_marshal, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -315,7 +326,7 @@ class TestDatasetListApiPost: mock_dataset_svc, mock_current_user, mock_marshal, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -341,7 +352,7 @@ class TestDatasetListApiPost: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -379,7 +390,7 @@ class TestDatasetApiGet: mock_provider_mgr, mock_marshal, mock_perm_svc, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -429,7 +440,7 @@ class TestDatasetApiGet: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -457,7 +468,7 @@ class TestDatasetApiDelete: mock_dataset_svc, mock_current_user, mock_perm_svc, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -479,7 +490,7 @@ class TestDatasetApiDelete: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -500,7 +511,7 @@ class TestDatasetApiDelete: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -532,7 +543,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -563,7 +574,7 @@ class TestDocumentStatusApiPatch: def test_batch_update_status_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -592,7 +603,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -625,7 +636,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -658,7 +669,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -698,7 +709,7 @@ class TestDatasetTagsApiGet: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -720,7 +731,7 @@ class TestDatasetTagsApiGet: def test_list_tags_from_db( self, mock_current_user, - app, + app: Flask, db_session_with_containers: Session, ): """Integration test: creates real Tag rows and retrieves them @@ -763,7 +774,7 @@ class TestDatasetTagsApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -786,7 +797,7 @@ class TestDatasetTagsApiPost: mock_tag_svc.save_tags.assert_called_once() @patch("controllers.service_api.dataset.dataset.current_user") - def test_create_tag_forbidden(self, mock_current_user, app): + def test_create_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -815,7 +826,7 @@ class TestDatasetTagsApiPatch: mock_current_user, mock_service_api_ns, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset 
import DatasetTagsApi @@ -841,7 +852,7 @@ class TestDatasetTagsApiPatch: mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1") @patch("controllers.service_api.dataset.dataset.current_user") - def test_update_tag_forbidden(self, mock_current_user, app): + def test_update_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -869,7 +880,7 @@ class TestDatasetTagsApiDelete: mock_current_user, mock_service_api_ns, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -894,7 +905,7 @@ class TestDatasetTagsApiDelete: mock_tag_svc.delete_tag.assert_called_once_with("tag-1") @patch("libs.login.current_user") - def test_delete_tag_forbidden(self, mock_current_user, app): + def test_delete_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi user_obj = Mock(spec=Account) @@ -922,7 +933,7 @@ class TestDatasetTagsBindingStatusApi: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi @@ -952,7 +963,7 @@ class TestDatasetTagBindingApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagBindingApi @@ -977,7 +988,7 @@ class TestDatasetTagBindingApiPost: ) @patch("controllers.service_api.dataset.dataset.current_user") - def test_bind_tags_forbidden(self, mock_current_user, app): + def test_bind_tags_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagBindingApi mock_current_user.__class__ = Account @@ -1003,7 +1014,37 @@ class TestDatasetTagUnbindingApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, + ): + from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi + + mock_current_user.__class__ = Account + mock_current_user.has_edit_permission = True + mock_current_user.is_dataset_editor = True + mock_tag_svc.delete_tag_binding.return_value = None + + with app.test_request_context( + "/datasets/tags/unbinding", + method="POST", + json={"tag_ids": ["tag-1"], "target_id": "ds-1"}, + ): + api = DatasetTagUnbindingApi() + result = api.post(_=None) + + assert result == ("", 204) + from services.tag_service import TagBindingDeletePayload + + mock_tag_svc.delete_tag_binding.assert_called_once_with( + TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge") + ) + + @patch("controllers.service_api.dataset.dataset.TagService") + @patch("controllers.service_api.dataset.dataset.current_user") + def test_unbind_legacy_tag_id_success( + self, + mock_current_user, + mock_tag_svc, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi @@ -1024,11 +1065,11 @@ class TestDatasetTagUnbindingApiPost: from services.tag_service import TagBindingDeletePayload mock_tag_svc.delete_tag_binding.assert_called_once_with( - TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge") + TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge") ) @patch("controllers.service_api.dataset.dataset.current_user") - def test_unbind_tag_forbidden(self, mock_current_user, app): + def test_unbind_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi 
mock_current_user.__class__ = Account @@ -1038,7 +1079,7 @@ class TestDatasetTagUnbindingApiPost: with app.test_request_context( "/datasets/tags/unbinding", method="POST", - json={"tag_id": "tag-1", "target_id": "ds-1"}, + json={"tag_ids": ["tag-1"], "target_id": "ds-1"}, ): api = DatasetTagUnbindingApi() with pytest.raises(Forbidden): diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py new file mode 100644 index 0000000000..4e884626a7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py @@ -0,0 +1,110 @@ +""" +Testcontainers integration tests for Service API Site controller. +""" + +from __future__ import annotations + +import pytest +from flask import Flask +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.service_api.app.site import AppSiteApi +from models.account import Tenant, TenantStatus +from models.model import App, AppMode, Site + + +@pytest.fixture +def app(flask_app_with_containers) -> Flask: + return flask_app_with_containers + + +def _unwrap(method): + fn = method + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + return fn + + +def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant: + tenant = Tenant(name="service-api-site-tenant", status=status) + db_session.add(tenant) + db_session.commit() + return tenant + + +def _create_app(db_session: Session, tenant_id: str) -> App: + app_model = App( + tenant_id=tenant_id, + mode=AppMode.CHAT, + name="service-api-site-app", + enable_site=True, + enable_api=True, + status="normal", + ) + db_session.add(app_model) + db_session.commit() + return app_model + + +def _create_site(db_session: Session, app_id: str) -> Site: + site = Site( + app_id=app_id, + title="Service API Site", + icon_type="emoji", + icon="robot", + icon_background="#ffffff", + description="Service API test site", + default_language="en-US", + prompt_public=True, + show_workflow_steps=True, + customize_token_strategy="not_allow", + use_icon_as_answer_icon=False, + chat_color_theme="light", + chat_color_theme_inverted=False, + ) + db_session.add(site) + db_session.commit() + return site + + +class TestAppSiteApi: + def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + response = _unwrap(api.get)(api, app_model=app_model) + + assert response["title"] == "Service API Site" + assert response["icon"] == "robot" + assert response["description"] == "Service API test site" + + def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + with pytest.raises(Forbidden): + _unwrap(api.get)(api, app_model=app_model) + + def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = 
diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py
new file mode 100644
index 0000000000..4e884626a7
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py
@@ -0,0 +1,110 @@
+"""
+Testcontainers integration tests for Service API Site controller.
+"""
+
+from __future__ import annotations
+
+import pytest
+from flask import Flask
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from controllers.service_api.app.site import AppSiteApi
+from models.account import Tenant, TenantStatus
+from models.model import App, AppMode, Site
+
+
+@pytest.fixture
+def app(flask_app_with_containers) -> Flask:
+    return flask_app_with_containers
+
+
+def _unwrap(method):
+    fn = method
+    while hasattr(fn, "__wrapped__"):
+        fn = fn.__wrapped__
+    return fn
+
+
+def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
+    tenant = Tenant(name="service-api-site-tenant", status=status)
+    db_session.add(tenant)
+    db_session.commit()
+    return tenant
+
+
+def _create_app(db_session: Session, tenant_id: str) -> App:
+    app_model = App(
+        tenant_id=tenant_id,
+        mode=AppMode.CHAT,
+        name="service-api-site-app",
+        enable_site=True,
+        enable_api=True,
+        status="normal",
+    )
+    db_session.add(app_model)
+    db_session.commit()
+    return app_model
+
+
+def _create_site(db_session: Session, app_id: str) -> Site:
+    site = Site(
+        app_id=app_id,
+        title="Service API Site",
+        icon_type="emoji",
+        icon="robot",
+        icon_background="#ffffff",
+        description="Service API test site",
+        default_language="en-US",
+        prompt_public=True,
+        show_workflow_steps=True,
+        customize_token_strategy="not_allow",
+        use_icon_as_answer_icon=False,
+        chat_color_theme="light",
+        chat_color_theme_inverted=False,
+    )
+    db_session.add(site)
+    db_session.commit()
+    return site
+
+
+class TestAppSiteApi:
+    def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None:
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+        _create_site(db_session_with_containers, app_model.id)
+
+        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
+            api = AppSiteApi()
+            response = _unwrap(api.get)(api, app_model=app_model)
+
+        assert response["title"] == "Service API Site"
+        assert response["icon"] == "robot"
+        assert response["description"] == "Service API test site"
+
+    def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None:
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+
+        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
+            api = AppSiteApi()
+            with pytest.raises(Forbidden):
+                _unwrap(api.get)(api, app_model=app_model)
+
+    def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None:
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+        _create_site(db_session_with_containers, app_model.id)
+
+        archived_tenant = db_session_with_containers.get(Tenant, tenant.id)
+        assert archived_tenant is not None
+        archived_tenant.status = TenantStatus.ARCHIVE
+        db_session_with_containers.commit()
+
+        app_model = db_session_with_containers.get(App, app_model.id)
+        assert app_model is not None
+
+        with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
+            api = AppSiteApi()
+            with pytest.raises(Forbidden):
+                _unwrap(api.get)(api, app_model=app_model)
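The new `test_site.py` calls view methods through `_unwrap`, which peels decorator layers off via the `__wrapped__` attribute that `functools.wraps` records. A minimal, self-contained sketch of why that works (hypothetical decorator and view, not Dify code):

```python
import functools


def require_token(fn):
    @functools.wraps(fn)  # records fn on wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        raise RuntimeError("no token in test context")
    return wrapper


@require_token
def get_site():
    return {"title": "Service API Site"}


def _unwrap(method):
    fn = method
    while hasattr(fn, "__wrapped__"):  # walk down stacked decorators
        fn = fn.__wrapped__
    return fn


# The decorated callable would raise; the unwrapped original runs directly.
assert _unwrap(get_site)() == {"title": "Service API Site"}
```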
diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py
index e1e6741014..c34da27ebe 100644
--- a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py
+++ b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py
@@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
 import pytest
+from flask import Flask
 from werkzeug.exceptions import NotFound
 
 from controllers.web.conversation import (
@@ -34,16 +35,16 @@ def _end_user() -> SimpleNamespace:
 
 class TestConversationListApi:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
-    def test_non_chat_mode_raises(self, app) -> None:
+    def test_non_chat_mode_raises(self, app: Flask) -> None:
         with app.test_request_context("/conversations"):
             with pytest.raises(NotChatAppError):
                 ConversationListApi().get(_completion_app(), _end_user())
 
     @patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
-    def test_happy_path(self, mock_paginate: MagicMock, app) -> None:
+    def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
         conv_id = str(uuid4())
         conv = SimpleNamespace(
             id=conv_id,
@@ -65,16 +66,16 @@ class TestConversationListApi:
 
 class TestConversationApi:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
-    def test_non_chat_mode_raises(self, app) -> None:
+    def test_non_chat_mode_raises(self, app: Flask) -> None:
         with app.test_request_context(f"/conversations/{uuid4()}"):
             with pytest.raises(NotChatAppError):
                 ConversationApi().delete(_completion_app(), _end_user(), uuid4())
 
     @patch("controllers.web.conversation.ConversationService.delete")
-    def test_delete_success(self, mock_delete: MagicMock, app) -> None:
+    def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
         c_id = uuid4()
         with app.test_request_context(f"/conversations/{c_id}"):
             result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
@@ -83,7 +84,7 @@ class TestConversationApi:
         assert result["result"] == "success"
 
     @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
-    def test_delete_not_found(self, mock_delete: MagicMock, app) -> None:
+    def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
         c_id = uuid4()
         with app.test_request_context(f"/conversations/{c_id}"):
             with pytest.raises(NotFound, match="Conversation Not Exists"):
@@ -92,17 +93,17 @@ class TestConversationApi:
 
 class TestConversationRenameApi:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
-    def test_non_chat_mode_raises(self, app) -> None:
+    def test_non_chat_mode_raises(self, app: Flask) -> None:
         with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
             with pytest.raises(NotChatAppError):
                 ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
 
     @patch("controllers.web.conversation.ConversationService.rename")
     @patch("controllers.web.conversation.web_ns")
-    def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
+    def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
         c_id = uuid4()
         mock_ns.payload = {"name": "New Name", "auto_generate": False}
         conv = SimpleNamespace(
@@ -126,7 +127,7 @@ class TestConversationRenameApi:
         side_effect=ConversationNotExistsError(),
     )
     @patch("controllers.web.conversation.web_ns")
-    def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
+    def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
         c_id = uuid4()
         mock_ns.payload = {"name": "X", "auto_generate": False}
 
@@ -137,16 +138,16 @@ class TestConversationRenameApi:
 
 class TestConversationPinApi:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
-    def test_non_chat_mode_raises(self, app) -> None:
+    def test_non_chat_mode_raises(self, app: Flask) -> None:
         with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
             with pytest.raises(NotChatAppError):
                 ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
 
     @patch("controllers.web.conversation.WebConversationService.pin")
-    def test_pin_success(self, mock_pin: MagicMock, app) -> None:
+    def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
         c_id = uuid4()
         with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
             result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
@@ -154,7 +155,7 @@ class TestConversationPinApi:
         assert result["result"] == "success"
 
     @patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
-    def test_pin_not_found(self, mock_pin: MagicMock, app) -> None:
+    def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
         c_id = uuid4()
         with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
             with pytest.raises(NotFound):
@@ -163,16 +164,16 @@ class TestConversationPinApi:
 
 class TestConversationUnPinApi:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
-    def test_non_chat_mode_raises(self, app) -> None:
+    def test_non_chat_mode_raises(self, app: Flask) -> None:
         with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
             with pytest.raises(NotChatAppError):
                 ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
 
     @patch("controllers.web.conversation.WebConversationService.unpin")
-    def test_unpin_success(self, mock_unpin: MagicMock, app) -> None:
+    def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
         c_id = uuid4()
         with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
             result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)
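Every test above drives a controller inside `app.test_request_context(...)`, which pushes a simulated request without exercising the full routing stack. A small standalone example of the pattern (plain Flask, no Dify fixtures):

```python
from flask import Flask, request

app = Flask(__name__)

# Push a fake POST request; the body becomes visible via flask.request,
# exactly as it is for the controller methods called in the tests above.
with app.test_request_context("/conversations/abc/name", method="POST", json={"name": "x"}):
    assert request.method == "POST"
    assert request.get_json() == {"name": "x"}
```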
diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/test_containers_integration_tests/controllers/web/test_site.py
similarity index 51%
rename from api/tests/unit_tests/controllers/web/test_site.py
rename to api/tests/test_containers_integration_tests/controllers/web/test_site.py
index 6e9d754c43..9adb26ff3d 100644
--- a/api/tests/unit_tests/controllers/web/test_site.py
+++ b/api/tests/test_containers_integration_tests/controllers/web/test_site.py
@@ -1,28 +1,48 @@
-"""Unit tests for controllers.web.site endpoints."""
+"""Testcontainers integration tests for controllers.web.site endpoints."""
 
 from __future__ import annotations
 
 from types import SimpleNamespace
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
 
 import pytest
 from flask import Flask
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
 from controllers.web.site import AppSiteApi, AppSiteInfo
+from models import Tenant, TenantStatus
+from models.model import App, AppMode, CustomizeTokenStrategy, Site
 
 
-def _tenant(*, status: str = "normal") -> SimpleNamespace:
-    return SimpleNamespace(
-        id="tenant-1",
-        status=status,
-        plan="basic",
-        custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
+@pytest.fixture
+def app(flask_app_with_containers) -> Flask:
+    return flask_app_with_containers
+
+
+def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
+    tenant = Tenant(name="test-tenant", status=status)
+    db_session.add(tenant)
+    db_session.commit()
+    return tenant
+
+
+def _create_app(db_session: Session, tenant_id: str, *, enable_site: bool = True) -> App:
+    app_model = App(
+        tenant_id=tenant_id,
+        mode=AppMode.CHAT,
+        name="test-app",
+        enable_site=enable_site,
+        enable_api=True,
     )
+    db_session.add(app_model)
+    db_session.commit()
+    return app_model
 
 
-def _site() -> SimpleNamespace:
-    return SimpleNamespace(
+def _create_site(db_session: Session, app_id: str) -> Site:
+    site = Site(
+        app_id=app_id,
         title="Site",
         icon_type="emoji",
         icon="robot",
@@ -31,77 +51,64 @@
         default_language="en",
         chat_color_theme="light",
         chat_color_theme_inverted=False,
-        copyright=None,
-        privacy_policy=None,
-        custom_disclaimer=None,
+        customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
+        code=f"code-{app_id[-6:]}",
         prompt_public=False,
         show_workflow_steps=True,
         use_icon_as_answer_icon=False,
     )
+    db_session.add(site)
+    db_session.commit()
+    return site
 
 
-# ---------------------------------------------------------------------------
-# AppSiteApi
-# ---------------------------------------------------------------------------
 class TestAppSiteApi:
     @patch("controllers.web.site.FeatureService.get_features")
-    @patch("controllers.web.site.db")
-    def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
+    def test_happy_path(self, mock_features, app: Flask, db_session_with_containers: Session) -> None:
         app.config["RESTX_MASK_HEADER"] = "X-Fields"
-        mock_features.return_value = SimpleNamespace(can_replace_logo=False)
-        site_obj = _site()
-        mock_db.session.scalar.return_value = site_obj
-        tenant = _tenant()
-        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+        _create_site(db_session_with_containers, app_model.id)
         end_user = SimpleNamespace(id="eu-1")
+        mock_features.return_value = SimpleNamespace(can_replace_logo=False)
 
         with app.test_request_context("/site"):
             result = AppSiteApi().get(app_model, end_user)
 
-        # marshal_with serializes AppSiteInfo to a dict
-        assert result["app_id"] == "app-1"
+        assert result["app_id"] == app_model.id
         assert result["plan"] == "basic"
         assert result["enable_site"] is True
 
-    @patch("controllers.web.site.db")
-    def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
+    def test_missing_site_raises_forbidden(self, app: Flask, db_session_with_containers: Session) -> None:
         app.config["RESTX_MASK_HEADER"] = "X-Fields"
-        mock_db.session.scalar.return_value = None
-        tenant = _tenant()
-        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
+        tenant = _create_tenant(db_session_with_containers)
+        app_model = _create_app(db_session_with_containers, tenant.id)
         end_user = SimpleNamespace(id="eu-1")
 
         with app.test_request_context("/site"):
             with pytest.raises(Forbidden):
                 AppSiteApi().get(app_model, end_user)
 
-    @patch("controllers.web.site.db")
-    def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
+    @patch("controllers.web.site.FeatureService.get_features")
+    def test_archived_tenant_raises_forbidden(
+        self, mock_features, app: Flask, db_session_with_containers: Session
+    ) -> None:
         app.config["RESTX_MASK_HEADER"] = "X-Fields"
-        from models.account import TenantStatus
-
-        mock_db.session.scalar.return_value = _site()
-        tenant = SimpleNamespace(
-            id="tenant-1",
-            status=TenantStatus.ARCHIVE,
-            plan="basic",
-            custom_config_dict={},
-        )
-        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
+        tenant = _create_tenant(db_session_with_containers, status=TenantStatus.ARCHIVE)
+        app_model = _create_app(db_session_with_containers, tenant.id)
+        _create_site(db_session_with_containers, app_model.id)
        end_user = SimpleNamespace(id="eu-1")
+        mock_features.return_value = SimpleNamespace(can_replace_logo=False)
 
         with app.test_request_context("/site"):
             with pytest.raises(Forbidden):
                 AppSiteApi().get(app_model, end_user)
 
 
-# ---------------------------------------------------------------------------
-# AppSiteInfo
-# ---------------------------------------------------------------------------
 class TestAppSiteInfo:
     def test_basic_fields(self) -> None:
-        tenant = _tenant()
-        site_obj = _site()
+        tenant = SimpleNamespace(id="tenant-1", plan="basic", custom_config_dict={})
+        site_obj = SimpleNamespace()
 
         info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
 
         assert info.app_id == "app-1"
@@ -118,7 +125,7 @@ class TestAppSiteInfo:
             plan="pro",
             custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
         )
-        site_obj = _site()
+        site_obj = SimpleNamespace()
 
         info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
 
         assert info.can_replace_logo is True
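The rename above swaps `SimpleNamespace` stand-ins for rows committed through a real session (`db_session_with_containers`), so column defaults, enum coercion, and queries run against an actual database. A toy version of the same move, using a hypothetical one-column model on in-memory SQLite rather than Dify's models and fixtures:

```python
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Tenant(Base):  # hypothetical mini-model, not Dify's models.account.Tenant
    __tablename__ = "tenants"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String, default="normal")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    tenant = Tenant()
    session.add(tenant)
    session.commit()
    # Unlike a SimpleNamespace, the committed row picked up its column default.
    assert session.get(Tenant, tenant.id).status == "normal"
```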
diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py
index 04ad143103..2c6a990240 100644
--- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py
+++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py
@@ -7,6 +7,7 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
 import pytest
+from flask import Flask
 
 from controllers.web.forgot_password import (
     ForgotPasswordCheckApi,
@@ -24,45 +25,39 @@ def _patch_wraps():
         patch("controllers.console.wraps.dify_config", dify_settings),
         patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features),
     ):
-        mock_db.session.query.return_value.first.return_value = MagicMock()
         yield
 
 
 class TestForgotPasswordSendEmailApi:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
     @patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
     @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
     @patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
     @patch("controllers.web.forgot_password.extract_remote_ip", return_value="127.0.0.1")
-    @patch("controllers.web.forgot_password.sessionmaker")
     def test_should_normalize_email_before_sending(
         self,
-        mock_session_cls,
         mock_extract_ip,
         mock_rate_limit,
         mock_get_account,
         mock_send_mail,
-        app,
+        app: Flask,
     ):
         mock_account = MagicMock()
         mock_get_account.return_value = mock_account
         mock_send_mail.return_value = "token-123"
-        mock_session = MagicMock()
-        mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
 
-        with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
-            with app.test_request_context(
-                "/web/forgot-password",
-                method="POST",
-                json={"email": "User@Example.com", "language": "zh-Hans"},
-            ):
-                response = ForgotPasswordSendEmailApi().post()
+        with app.test_request_context(
+            "/web/forgot-password",
+            method="POST",
+            json={"email": "User@Example.com", "language": "zh-Hans"},
+        ):
+            response = ForgotPasswordSendEmailApi().post()
 
         assert response == {"result": "success", "data": "token-123"}
-        mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
+        mock_get_account.assert_called_once_with("User@Example.com")
         mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
         mock_extract_ip.assert_called_once()
         mock_rate_limit.assert_called_once_with("127.0.0.1")
@@ -70,7 +65,7 @@ class TestForgotPasswordSendEmailApi:
 
 class TestForgotPasswordCheckApi:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
     @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
@@ -87,7 +82,7 @@ class TestForgotPasswordCheckApi:
         mock_revoke_token,
         mock_generate_token,
         mock_reset_rate,
-        app,
+        app: Flask,
     ):
         mock_is_rate_limit.return_value = False
         mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"}
@@ -123,7 +118,7 @@ class TestForgotPasswordCheckApi:
         mock_revoke_token,
         mock_generate_token,
         mock_reset_rate,
-        app,
+        app: Flask,
     ):
         mock_is_rate_limit.return_value = False
         mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"}
@@ -148,49 +143,47 @@ class TestForgotPasswordCheckApi:
 
 class TestForgotPasswordResetApi:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
     @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
     @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
-    @patch("controllers.web.forgot_password.sessionmaker")
+    @patch("controllers.web.forgot_password.db")
     @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
     @patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
     def test_should_fetch_account_with_fallback(
         self,
         mock_get_reset_data,
         mock_revoke_token,
-        mock_session_cls,
+        mock_db,
         mock_get_account,
         mock_update_account,
-        app,
+        app: Flask,
     ):
         mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"}
         mock_account = MagicMock()
         mock_get_account.return_value = mock_account
-        mock_session = MagicMock()
-        mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
+        mock_db.session.merge.return_value = mock_account
 
-        with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
-            with app.test_request_context(
-                "/web/forgot-password/resets",
-                method="POST",
-                json={
-                    "token": "token-123",
-                    "new_password": "ValidPass123!",
-                    "password_confirm": "ValidPass123!",
-                },
-            ):
-                response = ForgotPasswordResetApi().post()
+        with app.test_request_context(
+            "/web/forgot-password/resets",
+            method="POST",
+            json={
+                "token": "token-123",
+                "new_password": "ValidPass123!",
+                "password_confirm": "ValidPass123!",
+            },
+        ):
+            response = ForgotPasswordResetApi().post()
 
         assert response == {"result": "success"}
-        mock_get_account.assert_called_once_with("User@Example.com", session=mock_session)
+        mock_get_account.assert_called_once_with("User@Example.com")
         mock_update_account.assert_called_once()
         mock_revoke_token.assert_called_once_with("token-123")
 
     @patch("controllers.web.forgot_password.hash_password", return_value=b"hashed-value")
     @patch("controllers.web.forgot_password.secrets.token_bytes", return_value=b"0123456789abcdef")
-    @patch("controllers.web.forgot_password.sessionmaker")
+    @patch("controllers.web.forgot_password.db")
    @patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
     @patch("controllers.web.forgot_password.AccountService.get_reset_password_data")
     @patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
@@ -199,28 +192,26 @@ class TestForgotPasswordResetApi:
         mock_get_account,
         mock_get_reset_data,
         mock_revoke_token,
-        mock_session_cls,
+        mock_db,
         mock_token_bytes,
         mock_hash_password,
-        app,
+        app: Flask,
     ):
         mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"}
         account = MagicMock()
         mock_get_account.return_value = account
-        mock_session = MagicMock()
-        mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session
+        mock_db.session.merge.return_value = account
 
-        with patch("controllers.web.forgot_password.db", SimpleNamespace(engine="engine")):
-            with app.test_request_context(
-                "/web/forgot-password/resets",
-                method="POST",
-                json={
-                    "token": "reset-token",
-                    "new_password": "StrongPass123!",
-                    "password_confirm": "StrongPass123!",
-                },
-            ):
-                response = ForgotPasswordResetApi().post()
+        with app.test_request_context(
+            "/web/forgot-password/resets",
+            method="POST",
+            json={
+                "token": "reset-token",
+                "new_password": "StrongPass123!",
+                "password_confirm": "StrongPass123!",
+            },
+        ):
+            response = ForgotPasswordResetApi().post()
 
         assert response == {"result": "success"}
         mock_get_reset_data.assert_called_once_with("reset-token")
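The forgot-password tests no longer stub `sessionmaker` plus a fake engine; they patch the module-level `db` proxy and pin down `db.session.merge` directly. The mocking mechanics in isolation (pure `unittest.mock`, no Flask; `mock_db` is a stand-in for the patched `controllers.web.forgot_password.db` object):

```python
from unittest.mock import MagicMock

mock_db = MagicMock()
account = MagicMock()
mock_db.session.merge.return_value = account

# The code under test would do roughly: merged = db.session.merge(account)
merged = mock_db.session.merge(account)

assert merged is account
mock_db.session.merge.assert_called_once_with(account)
```

Patching the single attribute the controller actually touches keeps the test honest: if the controller stops calling `merge`, the assertion fails rather than silently passing through a fake engine.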
from uuid import uuid4
 
 import pytest
+from flask import Flask
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
 
@@ -182,7 +183,7 @@ class TestValidateUserAccessibility:
 
 class TestDecodeJwtToken:
     @pytest.fixture
-    def app(self, flask_app_with_containers):
+    def app(self, flask_app_with_containers: Flask):
         return flask_app_with_containers
 
     def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True):
@@ -239,7 +240,7 @@ class TestDecodeJwtToken:
         mock_access_mode: MagicMock,
         mock_validate_token: MagicMock,
         mock_validate_user: MagicMock,
-        app,
+        app: Flask,
         db_session_with_containers: Session,
     ) -> None:
         app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
@@ -299,7 +300,7 @@ class TestDecodeJwtToken:
         mock_extract: MagicMock,
         mock_passport_cls: MagicMock,
         mock_features: MagicMock,
-        app,
+        app: Flask,
         db_session_with_containers: Session,
     ) -> None:
         app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
@@ -324,7 +325,7 @@ class TestDecodeJwtToken:
         mock_extract: MagicMock,
         mock_passport_cls: MagicMock,
         mock_features: MagicMock,
-        app,
+        app: Flask,
         db_session_with_containers: Session,
     ) -> None:
         app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
@@ -350,7 +351,7 @@ class TestDecodeJwtToken:
         mock_extract: MagicMock,
         mock_passport_cls: MagicMock,
         mock_features: MagicMock,
-        app,
+        app: Flask,
         db_session_with_containers: Session,
     ) -> None:
         app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
index 2b4c1b59ab..bd13527e14 100644
--- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
+++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py
@@ -22,13 +22,6 @@ import uuid
 from time import time
 
 import pytest
-from graphon.entities.pause_reason import SchedulingPause
-from graphon.enums import WorkflowExecutionStatus
-from graphon.graph_engine.entities.commands import GraphEngineCommand
-from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError
-from graphon.graph_events import GraphRunPausedEvent
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
 from sqlalchemy import Engine, delete, select
 from sqlalchemy.orm import Session
 
@@ -40,6 +33,13 @@ from core.app.layers.pause_state_persist_layer import (
 )
 from core.workflow.system_variables import build_system_variables
 from extensions.ext_storage import storage
+from graphon.entities.pause_reason import SchedulingPause
+from graphon.enums import WorkflowExecutionStatus
+from graphon.graph_engine.entities.commands import GraphEngineCommand
+from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError
+from graphon.graph_events import GraphRunPausedEvent
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
 from libs.datetime_utils import naive_utc_now
 from models import Account
 from models import WorkflowPause as WorkflowPauseModel
@@ -85,14 +85,14 @@ class TestPauseStatePersistenceLayerTestContainers:
         return WorkflowRunService(engine)
 
     @pytest.fixture(autouse=True)
-    def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service):
+    def setup_test_data(self, db_session_with_containers: Session, file_service, workflow_run_service):
         """Set up test data for each test method using TestContainers."""
         # Create test tenant and account
-        from models.account import Tenant, TenantAccountJoin, TenantAccountRole
+        from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
 
         tenant = Tenant(
             name="Test Tenant",
-            status="normal",
+            status=TenantStatus.NORMAL,
         )
         db_session_with_containers.add(tenant)
         db_session_with_containers.commit()
@@ -101,7 +101,7 @@ class TestPauseStatePersistenceLayerTestContainers:
             email="test@example.com",
             name="Test User",
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         db_session_with_containers.add(account)
         db_session_with_containers.commit()
@@ -295,7 +295,7 @@ class TestPauseStatePersistenceLayerTestContainers:
             generate_entity=entity,
         )
 
-    def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers):
+    def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers: Session):
         """Test complete pause flow: event -> state serialization -> database save -> storage save."""
         # Arrange
         layer = self._create_pause_state_persistence_layer()
@@ -352,7 +352,7 @@ class TestPauseStatePersistenceLayerTestContainers:
         assert isinstance(persisted_entity, WorkflowAppGenerateEntity)
         assert persisted_entity.workflow_execution_id == self.test_workflow_run_id
 
-    def test_state_persistence_and_retrieval(self, db_session_with_containers):
+    def test_state_persistence_and_retrieval(self, db_session_with_containers: Session):
         """Test that pause state can be persisted and retrieved correctly."""
         # Arrange
         layer = self._create_pause_state_persistence_layer()
@@ -402,7 +402,7 @@ class TestPauseStatePersistenceLayerTestContainers:
         assert retrieved_state["node_run_steps"] == 10
         assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
 
-    def test_database_transaction_handling(self, db_session_with_containers):
+    def test_database_transaction_handling(self, db_session_with_containers: Session):
         """Test that database transactions are handled correctly."""
         # Arrange
         layer = self._create_pause_state_persistence_layer()
@@ -433,7 +433,7 @@ class TestPauseStatePersistenceLayerTestContainers:
         assert pause_model.resumed_at is None
         assert pause_model.state_object_key != ""
 
-    def test_file_storage_integration(self, db_session_with_containers):
+    def test_file_storage_integration(self, db_session_with_containers: Session):
         """Test integration with file storage system."""
         # Arrange
         layer = self._create_pause_state_persistence_layer()
@@ -467,7 +467,7 @@ class TestPauseStatePersistenceLayerTestContainers:
         assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps()
         assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id
 
-    def test_workflow_with_different_creators(self, db_session_with_containers):
+    def test_workflow_with_different_creators(self, db_session_with_containers: Session):
         """Test pause state with workflows created by different users."""
         # Arrange - Create workflow with different creator
         different_user_id = str(uuid.uuid4())
@@ -532,7 +532,7 @@ class TestPauseStatePersistenceLayerTestContainers:
         resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode())
         assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id
 
-    def test_layer_ignores_non_pause_events(self, db_session_with_containers):
+    def test_layer_ignores_non_pause_events(self, db_session_with_containers: Session):
         """Test that layer ignores non-pause events."""
         # Arrange
         layer = self._create_pause_state_persistence_layer()
@@ -557,14 +557,12 @@ class TestPauseStatePersistenceLayerTestContainers:
         self.session.refresh(self.test_workflow_run)
         assert self.test_workflow_run.status == WorkflowExecutionStatus.RUNNING
 
-        pause_states = (
-            self.session.query(WorkflowPauseModel)
-            .filter(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
-            .all()
-        )
+        pause_states = self.session.scalars(
+            select(WorkflowPauseModel).where(WorkflowPauseModel.workflow_run_id == self.test_workflow_run_id)
+        ).all()
 
         assert len(pause_states) == 0
 
-    def test_layer_requires_initialization(self, db_session_with_containers):
+    def test_layer_requires_initialization(self, db_session_with_containers: Session):
         """Test that layer requires proper initialization before handling events."""
         # Arrange
         layer = self._create_pause_state_persistence_layer()
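One substantive change above replaces the legacy `Query` API with the 2.0-style `session.scalars(select(...))`. Both forms return the same rows; the `select` form is what SQLAlchemy 2.x standardizes on. A runnable comparison against a throwaway model (hypothetical, not `WorkflowPauseModel`):

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Pause(Base):  # hypothetical stand-in for WorkflowPauseModel
    __tablename__ = "pauses"
    id: Mapped[int] = mapped_column(primary_key=True)
    workflow_run_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Pause(workflow_run_id="run-1"))
    session.commit()

    legacy = session.query(Pause).filter(Pause.workflow_run_id == "run-1").all()
    modern = session.scalars(select(Pause).where(Pause.workflow_run_id == "run-1")).all()
    assert len(legacy) == len(modern) == 1
```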
diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py
index a60159c66a..d1af0a56ef 100644
--- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py
+++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py
@@ -15,11 +15,14 @@ from uuid import uuid4
 
 import pytest
 from faker import Faker
+from sqlalchemy.orm import Session
 
 from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
 from extensions.ext_redis import redis_client
 from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
 
+TenantAndAccount = tuple[Tenant, Account]
+
 
 @dataclass
 class TestTask:
@@ -40,7 +43,7 @@ class TestTenantIsolatedTaskQueueIntegration:
         return Faker()
 
     @pytest.fixture
-    def test_tenant_and_account(self, db_session_with_containers, fake):
+    def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker):
         """Create test tenant and account for testing."""
         # Create account
         account = Account(
@@ -73,18 +76,18 @@ class TestTenantIsolatedTaskQueueIntegration:
         return tenant, account
 
     @pytest.fixture
-    def test_queue(self, test_tenant_and_account):
+    def test_queue(self, test_tenant_and_account: TenantAndAccount):
         """Create a generic test queue for testing."""
         tenant, _ = test_tenant_and_account
         return TenantIsolatedTaskQueue(tenant.id, "test_queue")
 
     @pytest.fixture
-    def secondary_queue(self, test_tenant_and_account):
+    def secondary_queue(self, test_tenant_and_account: TenantAndAccount):
         """Create a secondary test queue for testing isolation."""
         tenant, _ = test_tenant_and_account
         return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
 
-    def test_queue_initialization(self, test_tenant_and_account):
+    def test_queue_initialization(self, test_tenant_and_account: TenantAndAccount):
         """Test queue initialization with correct key generation."""
         tenant, _ = test_tenant_and_account
         queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
@@ -94,7 +97,9 @@ class TestTenantIsolatedTaskQueueIntegration:
         assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
         assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
 
-    def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake):
+    def test_tenant_isolation(
+        self, test_tenant_and_account: TenantAndAccount, db_session_with_containers: Session, fake: Faker
+    ):
         """Test that different tenants have isolated queues."""
         tenant1, _ = test_tenant_and_account
 
@@ -114,7 +119,7 @@ class TestTenantIsolatedTaskQueueIntegration:
         assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
         assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
 
-    def test_key_isolation(self, test_tenant_and_account):
+    def test_key_isolation(self, test_tenant_and_account: TenantAndAccount):
         """Test that different keys have isolated queues."""
         tenant, _ = test_tenant_and_account
         queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
@@ -176,7 +181,7 @@ class TestTenantIsolatedTaskQueueIntegration:
         assert len(remaining_tasks) == 2
         assert remaining_tasks == ["task4", "task5"]
 
-    def test_push_and_pull_complex_objects(self, test_queue, fake):
+    def test_push_and_pull_complex_objects(self, test_queue, fake: Faker):
         """Test pushing and pulling complex object tasks."""
         # Create complex task objects as dictionaries (not dataclass instances)
         tasks = [
@@ -218,7 +223,7 @@ class TestTenantIsolatedTaskQueueIntegration:
             assert pulled_task["data"] == original_task["data"]
             assert pulled_task["metadata"] == original_task["metadata"]
 
-    def test_mixed_task_types(self, test_queue, fake):
+    def test_mixed_task_types(self, test_queue, fake: Faker):
         """Test pushing and pulling mixed string and object tasks."""
         string_task = "simple_string_task"
         object_task = {
@@ -267,7 +272,7 @@ class TestTenantIsolatedTaskQueueIntegration:
         # Verify task key has expired
         assert test_queue.get_task_key() is None
 
-    def test_large_task_batch(self, test_queue, fake):
+    def test_large_task_batch(self, test_queue, fake: Faker):
         """Test handling large batches of tasks."""
         # Create large batch of tasks
         large_batch = []
@@ -292,7 +297,7 @@ class TestTenantIsolatedTaskQueueIntegration:
             assert isinstance(task, dict)
             assert task["index"] == i  # FIFO order
 
-    def test_queue_operations_isolation(self, test_tenant_and_account, fake):
+    def test_queue_operations_isolation(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
         """Test concurrent operations on different queues."""
         tenant, _ = test_tenant_and_account
 
@@ -312,7 +317,7 @@ class TestTenantIsolatedTaskQueueIntegration:
         assert tasks2 == ["task1_queue2", "task2_queue2"]
         assert tasks1 != tasks2
 
-    def test_task_wrapper_serialization_roundtrip(self, test_queue, fake):
+    def test_task_wrapper_serialization_roundtrip(self, test_queue, fake: Faker):
         """Test TaskWrapper serialization and deserialization roundtrip."""
         # Create complex nested data
         complex_data = {
@@ -346,7 +351,7 @@ class TestTenantIsolatedTaskQueueIntegration:
         task = test_queue.pull_tasks(1)
         assert task[0] == invalid_json_task
 
-    def test_real_world_batch_processing_scenario(self, test_queue, fake):
+    def test_real_world_batch_processing_scenario(self, test_queue, fake: Faker):
         """Test realistic batch processing scenario."""
         # Simulate batch processing tasks
         batch_tasks = []
@@ -403,7 +408,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
         return Faker()
 
     @pytest.fixture
-    def test_tenant_and_account(self, db_session_with_containers, fake):
+    def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker):
         """Create test tenant and account for testing."""
         # Create account
         account = Account(
@@ -435,7 +440,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
 
         return tenant, account
 
-    def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake):
+    def test_legacy_string_queue_compatibility(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
         """
         Test compatibility with legacy queues containing only string data.
 
@@ -465,7 +470,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
         expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
         assert pulled_tasks == expected_order
 
-    def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake):
+    def test_legacy_queue_migration_scenario(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
         """
         Test complete migration scenario from legacy to new system.
 
@@ -546,7 +551,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
             assert task["tenant_id"] == tenant.id
             assert task["processing_type"] == "new_system"
 
-    def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake):
+    def test_legacy_queue_error_recovery(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
         """
         Test error recovery when legacy queue contains malformed data.
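The queue tests pin the Redis key layout down to exact strings. A pure-Python sketch of that naming scheme, derived only from the assertions above (the real `TenantIsolatedTaskQueue` also manages serialization and key expiry on top of this):

```python
def queue_key(tenant_id: str, key: str) -> str:
    # List that holds the pending tasks for one (tenant, key) pair.
    return f"tenant_self_{key}_task_queue:{tenant_id}"


def task_key(tenant_id: str, key: str) -> str:
    # Marker entry used to signal that work for this pair is in flight.
    return f"tenant_{key}_task:{tenant_id}"


# Same logical queue name, different tenants: disjoint Redis keys.
assert queue_key("t1", "same-key") != queue_key("t2", "same-key")
# Same tenant, different queue names: also disjoint.
assert queue_key("t1", "key1") != queue_key("t1", "key2")
```

Baking the tenant id into every key is what makes the isolation tests above pass without any locking: two tenants can never observe each other's lists.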
diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py
index 00d7496a40..9da6b04a2c 100644
--- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py
+++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py
@@ -3,6 +3,7 @@ from unittest.mock import patch
 
 import pytest
 from faker import Faker
+from sqlalchemy.orm import Session
 
 from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@@ -15,7 +16,7 @@ from tests.test_containers_integration_tests.helpers import generate_valid_passw
 
 class TestGetAvailableDatasetsIntegration:
     def test_returns_datasets_with_available_documents(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         # Arrange
         fake = Faker()
@@ -77,7 +78,7 @@ class TestGetAvailableDatasetsIntegration:
         assert result[0].name == dataset.name
 
     def test_filters_out_datasets_with_only_archived_documents(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         # Arrange
         fake = Faker()
@@ -130,7 +131,7 @@ class TestGetAvailableDatasetsIntegration:
         assert len(result) == 0
 
     def test_filters_out_datasets_with_only_disabled_documents(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         # Arrange
         fake = Faker()
@@ -183,7 +184,7 @@ class TestGetAvailableDatasetsIntegration:
         assert len(result) == 0
 
     def test_filters_out_datasets_with_non_completed_documents(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         # Arrange
         fake = Faker()
@@ -236,7 +237,7 @@ class TestGetAvailableDatasetsIntegration:
         assert len(result) == 0
 
     def test_includes_external_datasets_without_documents(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test that external datasets are returned even with no available documents.
@@ -280,7 +281,7 @@ class TestGetAvailableDatasetsIntegration:
         assert result[0].id == dataset.id
         assert result[0].provider == "external"
 
-    def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_filters_by_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies):
         # Arrange
         fake = Faker()
 
@@ -356,7 +357,7 @@ class TestGetAvailableDatasetsIntegration:
         assert result[0].tenant_id == tenant1.id
 
     def test_returns_empty_list_when_no_datasets_found(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         # Arrange
         fake = Faker()
@@ -379,7 +380,9 @@ class TestGetAvailableDatasetsIntegration:
         # Assert
         assert result == []
 
-    def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_returns_only_requested_dataset_ids(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         # Arrange
         fake = Faker()
 
@@ -439,7 +442,7 @@ class TestGetAvailableDatasetsIntegration:
 
 class TestKnowledgeRetrievalIntegration:
     def test_knowledge_retrieval_with_available_datasets(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         # Arrange
         fake = Faker()
@@ -507,7 +510,7 @@ class TestKnowledgeRetrievalIntegration:
         assert isinstance(result, list)
 
     def test_knowledge_retrieval_no_available_datasets(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         # Arrange
         fake = Faker()
@@ -555,7 +558,7 @@ class TestKnowledgeRetrievalIntegration:
         assert result == []
 
     def test_knowledge_retrieval_rate_limit_exceeded(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         # Arrange
         fake = Faker()
diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
index 13caad799e..6524d6ce61 100644
--- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
+++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py
@@ -4,12 +4,11 @@ from __future__ import annotations
 
 from uuid import uuid4
 
-from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction
 from sqlalchemy import Engine, select
 from sqlalchemy.orm import Session
 
 from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl
-from core.workflow.human_input_compat import (
+from core.workflow.human_input_adapter import (
     DeliveryChannelConfig,
     EmailDeliveryConfig,
     EmailDeliveryMethod,
@@ -18,7 +17,15 @@ from core.workflow.human_input_compat import (
     MemberRecipient,
     WebAppDeliveryMethod,
 )
-from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction
+from models.account import (
+    Account,
+    AccountStatus,
+    Tenant,
+    TenantAccountJoin,
+    TenantAccountRole,
+    TenantStatus,
+)
 from models.human_input import (
     EmailExternalRecipientPayload,
     EmailMemberRecipientPayload,
@@ -29,7 +36,7 @@ from models.human_input import (
 
 
 def _create_tenant_with_members(session: Session, member_emails: list[str]) -> tuple[Tenant, list[Account]]:
-    tenant = Tenant(name="Test Tenant", status="normal")
+    tenant = Tenant(name="Test Tenant", status=TenantStatus.NORMAL)
     session.add(tenant)
     session.flush()
 
@@ -39,7 +46,7 @@ def _create_tenant_with_members(session: Session, member_emails: list[str]) -> t
             email=email,
             name=f"Member {index}",
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         session.add(account)
         session.flush()
diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
index 0a9b476afc..5aed230cd4 100644
--- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
+++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py
@@ -4,6 +4,17 @@ from datetime import timedelta
 from unittest.mock import MagicMock
 
 import pytest
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session
+
+from core.app.app_config.entities import WorkflowUIBasedAppConfig
+from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository
+from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.node_runtime import DifyHumanInputNodeRuntime
+from core.workflow.system_variables import build_system_variables
 from graphon.enums import WorkflowType
 from graphon.graph import Graph
 from graphon.graph_engine import GraphEngine
@@ -16,20 +27,9 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode
 from graphon.nodes.start.entities import StartNodeData
 from graphon.nodes.start.start_node import StartNode
 from graphon.runtime import GraphRuntimeState, VariablePool
-from sqlalchemy import delete, select
-from sqlalchemy.orm import Session
-
-from core.app.app_config.entities import WorkflowUIBasedAppConfig
-from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
-from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer
-from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
-from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
-from core.workflow.node_runtime import DifyHumanInputNodeRuntime
-from core.workflow.system_variables import build_system_variables
 from libs.datetime_utils import naive_utc_now
 from models import Account
-from models.account import Tenant, TenantAccountJoin, TenantAccountRole
+from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.model import App, AppMode, IconType
 from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun
@@ -101,8 +101,8 @@ def _build_graph(
 
     start_data = StartNodeData(title="start", variables=[])
     start_node = StartNode(
-        id="start",
-        config={"id": "start", "data": start_data.model_dump()},
+        node_id="start",
+        config=start_data,
         graph_init_params=params,
         graph_runtime_state=runtime_state,
     )
@@ -116,8 +116,8 @@ def _build_graph(
         ],
     )
     human_node = HumanInputNode(
-        id="human",
-        config={"id": "human", "data": human_data.model_dump()},
+        node_id="human",
+        config=human_data,
         graph_init_params=params,
         graph_runtime_state=runtime_state,
         form_repository=form_repository,
@@ -130,8 +130,8 @@ def _build_graph(
         desc=None,
     )
     end_node = EndNode(
-        id="end",
-        config={"id": "end", "data": end_data.model_dump()},
+        node_id="end",
+        config=end_data,
         graph_init_params=params,
         graph_runtime_state=runtime_state,
     )
@@ -175,7 +175,7 @@ class TestHumanInputResumeNodeExecutionIntegration:
     def setup_test_data(self, db_session_with_containers: Session):
         tenant = Tenant(
             name="Test Tenant",
-            status="normal",
+            status=TenantStatus.NORMAL,
         )
         db_session_with_containers.add(tenant)
         db_session_with_containers.commit()
@@ -184,7 +184,7 @@ class TestHumanInputResumeNodeExecutionIntegration:
             email="test@example.com",
             name="Test User",
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         db_session_with_containers.add(account)
         db_session_with_containers.commit()
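Several files above replace raw strings like `"normal"` and `"active"` with `TenantStatus.NORMAL` and `AccountStatus.ACTIVE`. Assuming these are string-valued enums (a sketch; the real members live in `models.account`, and `StrEnum` requires Python 3.11+), the change keeps the stored value identical while making invalid states fail loudly:

```python
from enum import StrEnum


class TenantStatus(StrEnum):  # hypothetical mirror of models.account.TenantStatus
    NORMAL = "normal"
    ARCHIVE = "archive"


# A StrEnum member compares and serializes like the plain string...
assert TenantStatus.NORMAL == "normal"

# ...but a misspelled state fails at lookup time instead of silently persisting.
try:
    TenantStatus("nromal")
except ValueError:
    pass
```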
""" - def setUp(self): - """Set up test data before each test method.""" - self.session = db.session() - self.tenant_id = str(uuid4()) - self.user_id = str(uuid4()) - self.conversation_id = str(uuid4()) - - # Create test data that will be cleaned up after each test - self.test_upload_files = [] - self.test_tool_files = [] - - # Create StorageKeyLoader instance - self.loader = StorageKeyLoader( - self.session, - self.tenant_id, - access_controller=DatabaseFileAccessController(), - ) - - def tearDown(self): - """Clean up test data after each test method.""" - self.session.rollback() + # ------------------------------------------------------------------ + # Per-test helpers (use db_session_with_containers as parameter) + # ------------------------------------------------------------------ + @staticmethod def _create_upload_file( - self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None + session: Session, + tenant_id: str, + user_id: str, + *, + file_id: str | None = None, + storage_key: str | None = None, + override_tenant_id: str | None = None, ) -> UploadFile: - """Helper method to create an UploadFile record for testing.""" - if file_id is None: - file_id = str(uuid4()) - if storage_key is None: - storage_key = f"test_storage_key_{uuid4()}" - if tenant_id is None: - tenant_id = self.tenant_id - + """Create and flush an UploadFile record for testing.""" upload_file = UploadFile( - tenant_id=tenant_id, + tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id, storage_type=StorageType.LOCAL, - key=storage_key, + key=storage_key or f"test_storage_key_{uuid4()}", name="test_file.txt", size=1024, extension=".txt", mime_type="text/plain", created_by_role=CreatorUserRole.ACCOUNT, - created_by=self.user_id, + created_by=user_id, created_at=datetime.now(UTC), used=False, ) - upload_file.id = file_id - - self.session.add(upload_file) - self.session.flush() - self.test_upload_files.append(upload_file) - + upload_file.id = file_id or str(uuid4()) + session.add(upload_file) + session.flush() return upload_file + @staticmethod def _create_tool_file( - self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None + session: Session, + tenant_id: str, + user_id: str, + conversation_id: str, + *, + file_id: str | None = None, + file_key: str | None = None, + override_tenant_id: str | None = None, ) -> ToolFile: - """Helper method to create a ToolFile record for testing.""" - if file_id is None: - file_id = str(uuid4()) - if file_key is None: - file_key = f"test_file_key_{uuid4()}" - if tenant_id is None: - tenant_id = self.tenant_id - + """Create and flush a ToolFile record for testing.""" tool_file = ToolFile( - user_id=self.user_id, - tenant_id=tenant_id, - conversation_id=self.conversation_id, - file_key=file_key, + user_id=user_id, + tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id, + conversation_id=conversation_id, + file_key=file_key or f"test_file_key_{uuid4()}", mimetype="text/plain", original_url="http://example.com/file.txt", name="test_tool_file.txt", size=2048, ) - tool_file.id = file_id - - self.session.add(tool_file) - self.session.flush() - self.test_tool_files.append(tool_file) - + tool_file.id = file_id or str(uuid4()) + session.add(tool_file) + session.flush() return tool_file - def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File: - """Helper method to create a File object for testing.""" - if tenant_id is 
None: - tenant_id = self.tenant_id - - # Set related_id for LOCAL_FILE and TOOL_FILE transfer methods - file_related_id = None - remote_url = None - - if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE): - file_related_id = related_id - elif transfer_method == FileTransferMethod.REMOTE_URL: - remote_url = "https://example.com/test_file.txt" - file_related_id = related_id - + @staticmethod + def _create_file( + tenant_id: str, + related_id: str, + transfer_method: FileTransferMethod, + *, + override_tenant_id: str | None = None, + ) -> File: + """Build a File value-object for testing.""" + remote_url = "https://example.com/test_file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None return File( - id=str(uuid4()), # Generate new UUID for File.id - tenant_id=tenant_id, - type=FileType.DOCUMENT, + file_id=str(uuid4()), + tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id, + file_type=FileType.DOCUMENT, transfer_method=transfer_method, - related_id=file_related_id, + related_id=related_id, remote_url=remote_url, filename="test_file.txt", extension=".txt", @@ -136,240 +109,280 @@ class TestStorageKeyLoader(unittest.TestCase): storage_key="initial_key", ) - def test_load_storage_keys_local_file(self): + # ------------------------------------------------------------------ + # Tests + # ------------------------------------------------------------------ + + def test_load_storage_keys_local_file(self, db_session_with_containers: Session): """Test loading storage keys for LOCAL_FILE transfer method.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + tenant_id = str(uuid4()) + user_id = str(uuid4()) - # Load storage keys - self.loader.load_storage_keys([file]) + upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id) + file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) + + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) + loader.load_storage_keys([file]) - # Verify storage key was loaded correctly assert file._storage_key == upload_file.key - def test_load_storage_keys_remote_url(self): + def test_load_storage_keys_remote_url(self, db_session_with_containers: Session): """Test loading storage keys for REMOTE_URL transfer method.""" - # Create test data - upload_file = self._create_upload_file() - file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) + tenant_id = str(uuid4()) + user_id = str(uuid4()) - # Load storage keys - self.loader.load_storage_keys([file]) + upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id) + file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL) + + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) + loader.load_storage_keys([file]) - # Verify storage key was loaded correctly assert file._storage_key == upload_file.key - def test_load_storage_keys_tool_file(self): + def test_load_storage_keys_tool_file(self, db_session_with_containers: Session): """Test loading storage keys for TOOL_FILE transfer method.""" - # Create test data - tool_file = self._create_tool_file() - file = self._create_file(related_id=tool_file.id, 
transfer_method=FileTransferMethod.TOOL_FILE) + tenant_id = str(uuid4()) + user_id = str(uuid4()) + conversation_id = str(uuid4()) - # Load storage keys - self.loader.load_storage_keys([file]) + tool_file = self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id) + file = self._create_file(tenant_id, related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) + loader.load_storage_keys([file]) - # Verify storage key was loaded correctly assert file._storage_key == tool_file.file_key - def test_load_storage_keys_mixed_methods(self): + def test_load_storage_keys_mixed_methods(self, db_session_with_containers: Session): """Test batch loading with mixed transfer methods.""" - # Create test data for different transfer methods - upload_file1 = self._create_upload_file() - upload_file2 = self._create_upload_file() - tool_file = self._create_tool_file() + tenant_id = str(uuid4()) + user_id = str(uuid4()) + conversation_id = str(uuid4()) - file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) - file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) - file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) + upload_file1 = self._create_upload_file(db_session_with_containers, tenant_id, user_id) + upload_file2 = self._create_upload_file(db_session_with_containers, tenant_id, user_id) + tool_file = self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id) - files = [file1, file2, file3] + file1 = self._create_file(tenant_id, related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE) + file2 = self._create_file(tenant_id, related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL) + file3 = self._create_file(tenant_id, related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE) - # Load storage keys - self.loader.load_storage_keys(files) + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) + loader.load_storage_keys([file1, file2, file3]) - # Verify all storage keys were loaded correctly assert file1._storage_key == upload_file1.key assert file2._storage_key == upload_file2.key assert file3._storage_key == tool_file.file_key - def test_load_storage_keys_empty_list(self): - """Test with empty file list.""" - # Should not raise any exceptions - self.loader.load_storage_keys([]) + def test_load_storage_keys_empty_list(self, db_session_with_containers: Session): + """Test with empty file list — should not raise.""" + tenant_id = str(uuid4()) + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) + loader.load_storage_keys([]) - def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + def test_load_storage_keys_ignores_legacy_file_tenant_id(self, db_session_with_containers: Session): """Legacy file tenant_id should not override the loader tenant scope.""" - upload_file = self._create_upload_file() + tenant_id = str(uuid4()) + user_id = str(uuid4()) + + upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id) file = self._create_file( - related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) + tenant_id, + 
related_id=upload_file.id, + transfer_method=FileTransferMethod.LOCAL_FILE, + override_tenant_id=str(uuid4()), ) - self.loader.load_storage_keys([file]) + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) + loader.load_storage_keys([file]) assert file._storage_key == upload_file.key - def test_load_storage_keys_missing_file_id(self): - """Test with None file.related_id.""" - # Create a file with valid parameters first, then manually set related_id to None - file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) + def test_load_storage_keys_missing_file_id(self, db_session_with_containers: Session): + """Test with None file.related_id — should raise ValueError.""" + tenant_id = str(uuid4()) + user_id = str(uuid4()) + + upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id) + file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) file.related_id = None - # Should raise ValueError for None file related_id - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) + with pytest.raises(ValueError, match="file id should not be None."): + loader.load_storage_keys([file]) - assert str(context.value) == "file id should not be None." + def test_load_storage_keys_nonexistent_upload_file_records(self, db_session_with_containers: Session): + """Test with missing UploadFile database records — should raise ValueError.""" + tenant_id = str(uuid4()) + file = self._create_file(tenant_id, related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE) - def test_load_storage_keys_nonexistent_upload_file_records(self): - """Test with missing UploadFile database records.""" - # Create file with non-existent upload file id - non_existent_id = str(uuid4()) - file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE) - - # Should raise ValueError for missing record + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) + loader.load_storage_keys([file]) - def test_load_storage_keys_nonexistent_tool_file_records(self): - """Test with missing ToolFile database records.""" - # Create file with non-existent tool file id - non_existent_id = str(uuid4()) - file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE) + def test_load_storage_keys_nonexistent_tool_file_records(self, db_session_with_containers: Session): + """Test with missing ToolFile database records — should raise ValueError.""" + tenant_id = str(uuid4()) + file = self._create_file(tenant_id, related_id=str(uuid4()), transfer_method=FileTransferMethod.TOOL_FILE) - # Should raise ValueError for missing record + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) + loader.load_storage_keys([file]) - def test_load_storage_keys_invalid_uuid(self): - """Test with invalid UUID format.""" - # Create a file with valid parameters first, then manually set invalid related_id - file = self._create_file(related_id=str(uuid4()), 
transfer_method=FileTransferMethod.LOCAL_FILE) + def test_load_storage_keys_invalid_uuid(self, db_session_with_containers: Session): + """Test with invalid UUID format — should raise ValueError.""" + tenant_id = str(uuid4()) + user_id = str(uuid4()) + + upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id) + file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE) file.related_id = "invalid-uuid-format" - # Should raise ValueError for invalid UUID + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) with pytest.raises(ValueError): - self.loader.load_storage_keys([file]) + loader.load_storage_keys([file]) - def test_load_storage_keys_batch_efficiency(self): - """Test batched operations use efficient queries.""" - # Create multiple files of different types - upload_files = [self._create_upload_file() for _ in range(3)] - tool_files = [self._create_tool_file() for _ in range(2)] + def test_load_storage_keys_batch_efficiency(self, db_session_with_containers: Session): + """Batched operations should issue exactly 2 queries for mixed file types.""" + tenant_id = str(uuid4()) + user_id = str(uuid4()) + conversation_id = str(uuid4()) - files = [] - files.extend( - [self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files] + upload_files = [self._create_upload_file(db_session_with_containers, tenant_id, user_id) for _ in range(3)] + tool_files = [ + self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id) for _ in range(2) + ] + + files = [ + self._create_file(tenant_id, related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) + for uf in upload_files + ] + [ + self._create_file(tenant_id, related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) + for tf in tool_files + ] + + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() ) - files.extend( - [self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files] - ) - - # Mock the session to count queries - with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars: - self.loader.load_storage_keys(files) - - # Should make exactly 2 queries (one for upload_files, one for tool_files) + with patch.object( + db_session_with_containers, "scalars", wraps=db_session_with_containers.scalars + ) as mock_scalars: + loader.load_storage_keys(files) + # Exactly 2 DB round-trips: one for UploadFile, one for ToolFile assert mock_scalars.call_count == 2 - # Verify all storage keys were loaded correctly for i, file in enumerate(files[:3]): assert file._storage_key == upload_files[i].key for i, file in enumerate(files[3:]): assert file._storage_key == tool_files[i].file_key - def test_load_storage_keys_tenant_isolation(self): - """Test that tenant isolation works correctly.""" - # Create files for different tenants + def test_load_storage_keys_tenant_isolation(self, db_session_with_containers: Session): + """Loader should not surface records belonging to a different tenant.""" + tenant_id = str(uuid4()) other_tenant_id = str(uuid4()) + user_id = str(uuid4()) - # Create upload file for current tenant - upload_file_current = self._create_upload_file() + upload_file_current = self._create_upload_file(db_session_with_containers, tenant_id, user_id) file_current = self._create_file( - 
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + tenant_id, related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE ) - # Create upload file for other tenant (but don't add to cleanup list) - upload_file_other = UploadFile( - tenant_id=other_tenant_id, - storage_type=StorageType.LOCAL, - key="other_tenant_key", - name="other_file.txt", - size=1024, - extension=".txt", - mime_type="text/plain", - created_by_role=CreatorUserRole.ACCOUNT, - created_by=self.user_id, - created_at=datetime.now(UTC), - used=False, + upload_file_other = self._create_upload_file( + db_session_with_containers, + tenant_id, + user_id, + override_tenant_id=other_tenant_id, ) - upload_file_other.id = str(uuid4()) - self.session.add(upload_file_other) - self.session.flush() - - # Create file for other tenant but try to load with current tenant's loader file_other = self._create_file( - related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + tenant_id, + related_id=upload_file_other.id, + transfer_method=FileTransferMethod.LOCAL_FILE, + override_tenant_id=other_tenant_id, ) - # Should raise ValueError due to tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file_other]) + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) - assert "Upload file not found for id:" in str(context.value) + with pytest.raises(ValueError, match="Upload file not found for id:"): + loader.load_storage_keys([file_other]) - # Current tenant's file should still work - self.loader.load_storage_keys([file_current]) + # Current-tenant file still resolves correctly + loader.load_storage_keys([file_current]) assert file_current._storage_key == upload_file_current.key - def test_load_storage_keys_mixed_tenant_batch(self): - """Test batch with mixed tenant files (should fail on first mismatch).""" - # Create files for current tenant - upload_file_current = self._create_upload_file() + def test_load_storage_keys_mixed_tenant_batch(self, db_session_with_containers: Session): + """A batch containing a foreign-tenant file should fail on the mismatch.""" + tenant_id = str(uuid4()) + user_id = str(uuid4()) + + upload_file_current = self._create_upload_file(db_session_with_containers, tenant_id, user_id) file_current = self._create_file( - related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE + tenant_id, related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE ) - - # Create file for different tenant - other_tenant_id = str(uuid4()) file_other = self._create_file( - related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id + tenant_id, + related_id=str(uuid4()), + transfer_method=FileTransferMethod.LOCAL_FILE, + override_tenant_id=str(uuid4()), ) - # Should raise ValueError on tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file_current, file_other]) + loader = StorageKeyLoader( + db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController() + ) + with pytest.raises(ValueError, match="Upload file not found for id:"): + loader.load_storage_keys([file_current, file_other]) - assert "Upload file not found for id:" in str(context.value) + def test_load_storage_keys_duplicate_file_ids(self, db_session_with_containers: Session): + """Duplicate file IDs in the batch should be handled 
gracefully."""
+        tenant_id = str(uuid4())
+        user_id = str(uuid4())

-    def test_load_storage_keys_duplicate_file_ids(self):
-        """Test handling of duplicate file IDs in the batch."""
-        # Create upload file
-        upload_file = self._create_upload_file()
+        upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
+        file1 = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
+        file2 = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)

-        # Create two File objects with same related_id
-        file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
-        file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
+        loader = StorageKeyLoader(
+            db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
+        )
+        loader.load_storage_keys([file1, file2])

-        # Should handle duplicates gracefully
-        self.loader.load_storage_keys([file1, file2])
-
-        # Both files should have the same storage key
         assert file1._storage_key == upload_file.key
         assert file2._storage_key == upload_file.key

-    def test_load_storage_keys_session_isolation(self):
-        """Test that the loader uses the provided session correctly."""
-        # Create test data
-        upload_file = self._create_upload_file()
-        file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
+    def test_load_storage_keys_session_isolation(self, db_session_with_containers: Session):
+        """A loader backed by a separate session must not see uncommitted rows from another session."""
+        tenant_id = str(uuid4())
+        user_id = str(uuid4())

-        # Create loader with different session (same underlying connection)
+        upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
+        file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
+        # A loader with a fresh, separate session cannot see uncommitted rows from db_session_with_containers
         with Session(bind=db.engine) as other_session:
             other_loader = StorageKeyLoader(
                 other_session,
-                self.tenant_id,
+                tenant_id,
                 access_controller=DatabaseFileAccessController(),
             )
             with pytest.raises(ValueError):
diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
index b745aed141..2fd289dfbc 100644
--- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
+++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py
@@ -6,7 +6,6 @@ from decimal import Decimal
 from uuid import uuid4

 from graphon.nodes.human_input.entities import FormDefinition, UserAction
-
 from libs.datetime_utils import naive_utc_now
 from models.account import Account, Tenant, TenantAccountJoin
 from models.enums import ConversationFromSource, InvokeFrom
diff --git a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py
index 43915a204d..84c1d0ca41 100644
--- a/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py
+++ b/api/tests/test_containers_integration_tests/libs/broadcast_channel/redis/test_sharded_channel.py
@@ -8,6 +8,7 @@ Covers real Redis 7+ sharded pub/sub interactions including:
- Resource cleanup
accounting via PUBSUB SHARDNUMSUB """ +import socket import threading import time import uuid @@ -356,10 +357,17 @@ class TestShardedRedisBroadcastChannelClusterIntegration: def _get_test_topic_name(cls) -> str: return f"test_sharded_cluster_topic_{uuid.uuid4()}" + @staticmethod + def _resolve_announced_ip(host: str) -> str: + """Resolve the container host name to a literal IP accepted by Redis cluster config.""" + return socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)[0][4][0] + @staticmethod def _ensure_single_node_cluster(host: str, port: int) -> None: + """Bootstrap a single-node cluster using a literal IP for Redis node advertisement.""" client = redis.Redis(host=host, port=port, decode_responses=False) - client.config_set("cluster-announce-ip", host) + announced_ip = TestShardedRedisBroadcastChannelClusterIntegration._resolve_announced_ip(host) + client.config_set("cluster-announce-ip", announced_ip) client.config_set("cluster-announce-port", port) slots = client.execute_command("CLUSTER", "SLOTS") if not slots: diff --git a/api/tests/test_containers_integration_tests/models/test_account.py b/api/tests/test_containers_integration_tests/models/test_account.py index 078dc0e8de..6fd6716cbb 100644 --- a/api/tests/test_containers_integration_tests/models/test_account.py +++ b/api/tests/test_containers_integration_tests/models/test_account.py @@ -1,79 +1,202 @@ -# import secrets +""" +Integration tests for Account and Tenant model methods that interact with the database. -# import pytest -# from sqlalchemy import select -# from sqlalchemy.orm import Session -# from sqlalchemy.orm.exc import DetachedInstanceError +Migrated from unit_tests/models/test_account_models.py, replacing +@patch("models.account.db") mock patches with real PostgreSQL operations. -# from libs.datetime_utils import naive_utc_now -# from models.account import Account, Tenant, TenantAccountJoin +Covers: +- Account.current_tenant setter (sets _current_tenant and role from TenantAccountJoin) +- Account.set_tenant_id (resolves tenant + role from real join row) +- Account.get_by_openid (AccountIntegrate lookup then Account fetch) +- Tenant.get_accounts (returns accounts linked via TenantAccountJoin) +""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy import delete +from sqlalchemy.orm import Session + +from models.account import Account, AccountIntegrate, Tenant, TenantAccountJoin, TenantAccountRole -# @pytest.fixture -# def session(db_session_with_containers): -# with Session(db_session_with_containers.get_bind()) as session: -# yield session +def _cleanup_tracked_rows(db_session: Session, tracked: list) -> None: + """Delete rows tracked during the test so committed state does not leak into the DB. + + Rolls back any pending (uncommitted) session state first, then issues DELETE + statements by primary key for each tracked entity (in reverse creation order) + and commits. This cleans up rows created via either flush() or commit(). 
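+
+    Illustrative sketch of the intended call pattern (names and values here are
+    hypothetical, not taken from a real test):
+
+        tracked: list = []
+        tenant = Tenant(name="example")
+        db_session.add(tenant)
+        db_session.flush()
+        tracked.append(tenant)  # register for teardown
+        ...
+        _cleanup_tracked_rows(db_session, tracked)  # deletes the tenant row again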
+ """ + db_session.rollback() + for entity in reversed(tracked): + db_session.execute(delete(type(entity)).where(type(entity).id == entity.id)) + db_session.commit() -# @pytest.fixture -# def account(session): -# account = Account( -# name="test account", -# email=f"test_{secrets.token_hex(8)}@example.com", -# ) -# session.add(account) -# session.commit() -# return account +def _build_tenant() -> Tenant: + return Tenant(name=f"Tenant {uuid4()}") -# @pytest.fixture -# def tenant(session): -# tenant = Tenant(name="test tenant") -# session.add(tenant) -# session.commit() -# return tenant +def _build_account(email_prefix: str = "account") -> Account: + return Account( + name=f"Account {uuid4()}", + email=f"{email_prefix}_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) -# @pytest.fixture -# def tenant_account_join(session, account, tenant): -# tenant_join = TenantAccountJoin(account_id=account.id, tenant_id=tenant.id) -# session.add(tenant_join) -# session.commit() -# yield tenant_join -# session.delete(tenant_join) -# session.commit() +class _DBTrackingTestBase: + """Base class providing a tracker list and shared row factories for account/tenant tests.""" + + _tracked: list + + @pytest.fixture(autouse=True) + def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]: + self._tracked = [] + yield + _cleanup_tracked_rows(db_session_with_containers, self._tracked) + + def _create_tenant(self, db_session: Session) -> Tenant: + tenant = _build_tenant() + db_session.add(tenant) + db_session.flush() + self._tracked.append(tenant) + return tenant + + def _create_account(self, db_session: Session, email_prefix: str = "account") -> Account: + account = _build_account(email_prefix) + db_session.add(account) + db_session.flush() + self._tracked.append(account) + return account + + def _create_join( + self, db_session: Session, tenant_id: str, account_id: str, role: TenantAccountRole, current: bool = True + ) -> TenantAccountJoin: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id, role=role, current=current) + db_session.add(join) + db_session.flush() + self._tracked.append(join) + return join -# class TestAccountTenant: -# def test_set_current_tenant_should_reload_tenant( -# self, -# db_session_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with Session(db_session_with_containers.get_bind(), expire_on_commit=True) as session: -# scoped_tenant = session.scalars(select(Tenant).where(Tenant.id == tenant.id)).one() -# account.current_tenant = scoped_tenant -# scoped_tenant.created_at = naive_utc_now() -# # session.commit() +class TestAccountCurrentTenantSetter(_DBTrackingTestBase): + """Integration tests for Account.current_tenant property setter.""" -# # Ensure the tenant used in assignment is detached. 
-# with pytest.raises(DetachedInstanceError): -# _ = scoped_tenant.name + def test_current_tenant_property_returns_cached_tenant(self, db_session_with_containers: Session) -> None: + """current_tenant getter returns the in-memory _current_tenant without DB access.""" + account = self._create_account(db_session_with_containers) + tenant = self._create_tenant(db_session_with_containers) + account._current_tenant = tenant -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id + assert account.current_tenant is tenant -# def test_set_tenant_id_should_load_tenant_as_not_expire( -# self, -# flask_app_with_containers, -# account, -# tenant, -# tenant_account_join, -# ): -# with flask_app_with_containers.test_request_context(): -# account.set_tenant_id(tenant.id) + def test_current_tenant_setter_sets_tenant_and_role_when_join_exists( + self, db_session_with_containers: Session + ) -> None: + """Setting current_tenant loads the join row and assigns role when relationship exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.OWNER) + db_session_with_containers.commit() -# assert account._current_tenant.id == tenant.id -# assert account._current_tenant.id == tenant.id + account.current_tenant = tenant + + assert account._current_tenant is not None + assert account._current_tenant.id == tenant.id + assert account.role == TenantAccountRole.OWNER + + def test_current_tenant_setter_sets_none_when_no_join_exists(self, db_session_with_containers: Session) -> None: + """Setting current_tenant results in _current_tenant=None when no join row exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + db_session_with_containers.commit() + + account.current_tenant = tenant + + assert account._current_tenant is None + + +class TestAccountSetTenantId(_DBTrackingTestBase): + """Integration tests for Account.set_tenant_id method.""" + + def test_set_tenant_id_sets_tenant_and_role_when_relationship_exists( + self, db_session_with_containers: Session + ) -> None: + """set_tenant_id loads the tenant and assigns role when a join row exists.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + self._create_join(db_session_with_containers, tenant.id, account.id, TenantAccountRole.ADMIN) + db_session_with_containers.commit() + + account.set_tenant_id(tenant.id) + + assert account._current_tenant is not None + assert account._current_tenant.id == tenant.id + assert account.role == TenantAccountRole.ADMIN + + def test_set_tenant_id_does_not_set_tenant_when_no_relationship_exists( + self, db_session_with_containers: Session + ) -> None: + """set_tenant_id does nothing when no join row matches the tenant.""" + tenant = self._create_tenant(db_session_with_containers) + account = self._create_account(db_session_with_containers) + db_session_with_containers.commit() + + account.set_tenant_id(tenant.id) + + assert account._current_tenant is None + + +class TestAccountGetByOpenId(_DBTrackingTestBase): + """Integration tests for Account.get_by_openid class method.""" + + def test_get_by_openid_returns_account_when_integrate_exists(self, db_session_with_containers: Session) -> None: + """get_by_openid returns the Account when a matching AccountIntegrate row exists.""" + account = 
self._create_account(db_session_with_containers, email_prefix="openid") + provider = "google" + open_id = f"google_{uuid4()}" + + integrate = AccountIntegrate( + account_id=account.id, + provider=provider, + open_id=open_id, + encrypted_token="token", + ) + db_session_with_containers.add(integrate) + db_session_with_containers.flush() + self._tracked.append(integrate) + + result = Account.get_by_openid(provider, open_id) + + assert result is not None + assert result.id == account.id + + def test_get_by_openid_returns_none_when_no_integrate_exists(self, db_session_with_containers: Session) -> None: + """get_by_openid returns None when no AccountIntegrate row matches.""" + result = Account.get_by_openid("github", f"github_{uuid4()}") + + assert result is None + + +class TestTenantGetAccounts(_DBTrackingTestBase): + """Integration tests for Tenant.get_accounts method.""" + + def test_get_accounts_returns_linked_accounts(self, db_session_with_containers: Session) -> None: + """get_accounts returns all accounts linked to the tenant via TenantAccountJoin.""" + tenant = self._create_tenant(db_session_with_containers) + account1 = self._create_account(db_session_with_containers, email_prefix="tenant_member") + account2 = self._create_account(db_session_with_containers, email_prefix="tenant_member") + self._create_join(db_session_with_containers, tenant.id, account1.id, TenantAccountRole.OWNER, current=False) + self._create_join(db_session_with_containers, tenant.id, account2.id, TenantAccountRole.NORMAL, current=False) + + accounts = tenant.get_accounts() + + assert len(accounts) == 2 + account_ids = {a.id for a in accounts} + assert account1.id in account_ids + assert account2.id in account_ids diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py new file mode 100644 index 0000000000..f10f519e25 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_conversation_message_inputs.py @@ -0,0 +1,149 @@ +""" +Integration tests for Conversation.inputs and Message.inputs tenant resolution. + +Migrated from unit_tests/models/test_model.py, replacing db.session.scalar monkeypatching +with a real App in PostgreSQL so the _resolve_app_tenant_id lookup executes against the DB. 
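+
+Sketch of the resolution being exercised (keys abbreviated, values illustrative;
+the real mapping shape comes from _build_local_file_mapping below):
+
+    owner.inputs = {"file": {"transfer_method": "local_file", ...}}  # no tenant_id in payload
+    owner.inputs  # -> tenant_id backfilled from the owning App row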
+""" + +from collections.abc import Generator +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from core.workflow.file_reference import build_file_reference +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod +from models.model import App, AppMode, Conversation, Message + + +def _build_local_file_mapping(record_id: str, *, tenant_id: str | None = None) -> dict: + mapping: dict = { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "reference": build_file_reference(record_id=record_id), + "type": "document", + "filename": "example.txt", + "extension": ".txt", + "mime_type": "text/plain", + "size": 1, + } + if tenant_id is not None: + mapping["tenant_id"] = tenant_id + return mapping + + +class TestConversationMessageInputsTenantResolution: + """Integration tests for Conversation/Message.inputs tenant resolution via real DB lookup.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_app(self, db_session: Session) -> App: + tenant_id = str(uuid4()) + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.CHAT, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=str(uuid4()), + updated_by=str(uuid4()), + ) + db_session.add(app) + db_session.flush() + return app + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_resolves_tenant_via_db_for_local_file( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs resolves tenant_id from real App row when file mapping has no tenant_id.""" + app = self._create_app(db_session_with_containers) + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = {"file": _build_local_file_mapping("upload-1")} + + restored_inputs = owner.inputs + + # The tenant_id should come from the real App row in the DB + assert restored_inputs["file"] == {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"} + assert len(build_calls) == 1 + assert build_calls[0][1] == app.tenant_id + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_uses_serialized_tenant_id_skipping_db_lookup( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs uses tenant_id from the file mapping payload without hitting the DB.""" + app = self._create_app(db_session_with_containers) + payload_tenant_id = "tenant-from-payload" + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = {"file": 
_build_local_file_mapping("upload-1", tenant_id=payload_tenant_id)} + + restored_inputs = owner.inputs + + assert restored_inputs["file"] == {"tenant_id": payload_tenant_id, "upload_file_id": "upload-1"} + assert len(build_calls) == 1 + assert build_calls[0][1] == payload_tenant_id + + @pytest.mark.parametrize("owner_cls", [Conversation, Message]) + def test_inputs_resolves_tenant_for_file_list( + self, + db_session_with_containers: Session, + owner_cls: type, + ) -> None: + """Inputs resolves tenant_id for a list of file mappings.""" + app = self._create_app(db_session_with_containers) + build_calls: list[tuple[dict, str]] = [] + + def fake_build_from_mapping( + *, mapping, tenant_id, config=None, strict_type_validation=False, access_controller + ): + build_calls.append((dict(mapping), tenant_id)) + return {"tenant_id": tenant_id, "upload_file_id": mapping.get("upload_file_id")} + + with patch("factories.file_factory.build_from_mapping", fake_build_from_mapping): + owner = owner_cls(app_id=app.id) + owner.inputs = { + "files": [ + _build_local_file_mapping("upload-1"), + _build_local_file_mapping("upload-2"), + ] + } + + restored_inputs = owner.inputs + + assert len(build_calls) == 2 + assert all(call[1] == app.tenant_id for call in build_calls) + assert restored_inputs["files"] == [ + {"tenant_id": app.tenant_id, "upload_file_id": "upload-1"}, + {"tenant_id": app.tenant_id, "upload_file_id": "upload-2"}, + ] diff --git a/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py new file mode 100644 index 0000000000..6352f815df --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_conversation_status_count.py @@ -0,0 +1,314 @@ +""" +Integration tests for Conversation.status_count and Site.generate_code model properties. + +Migrated from unit_tests/models/test_app_models.py TestConversationStatusCount and +test_site_generate_code, replacing db.session.scalars mocks with real PostgreSQL queries. 
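+
+Result shape asserted below (counts illustrative):
+
+    conversation.status_count
+    # -> {"success": 1, "failed": 0, "partial_success": 0, "paused": 0}
+    # -> None when no message in the conversation carries a workflow_run_id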
+""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from graphon.enums import WorkflowExecutionStatus +from models.enums import ConversationFromSource, InvokeFrom +from models.model import App, AppMode, Conversation, Message, Site +from models.workflow import Workflow, WorkflowRun, WorkflowRunTriggeredFrom, WorkflowType + + +class TestConversationStatusCount: + """Integration tests for Conversation.status_count property.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App: + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.ADVANCED_CHAT, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=created_by, + updated_by=created_by, + ) + db_session.add(app) + db_session.flush() + return app + + def _create_conversation(self, db_session: Session, app: App) -> Conversation: + conversation = Conversation( + app_id=app.id, + mode=app.mode, + name=f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP, + from_source=ConversationFromSource.API, + dialogue_count=0, + is_deleted=False, + ) + conversation.inputs = {} + db_session.add(conversation) + db_session.flush() + return conversation + + def _create_workflow(self, db_session: Session, app: App, created_by: str) -> Workflow: + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.CHAT, + version="draft", + graph="{}", + created_by=created_by, + ) + workflow._features = "{}" + db_session.add(workflow) + db_session.flush() + return workflow + + def _create_workflow_run( + self, db_session: Session, app: App, workflow: Workflow, status: WorkflowExecutionStatus, created_by: str + ) -> WorkflowRun: + run = WorkflowRun( + tenant_id=app.tenant_id, + app_id=app.id, + workflow_id=workflow.id, + type=WorkflowType.CHAT, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + version="draft", + status=status, + created_by_role="account", + created_by=created_by, + ) + db_session.add(run) + db_session.flush() + return run + + def _create_message( + self, db_session: Session, app: App, conversation: Conversation, workflow_run_id: str | None = None + ) -> Message: + message = Message( + app_id=app.id, + conversation_id=conversation.id, + _inputs={}, + query="Test query", + message={"role": "user", "content": "Test query"}, + answer="Test answer", + model_provider="openai", + model_id="gpt-4", + message_tokens=10, + message_unit_price=0, + answer_tokens=10, + answer_unit_price=0, + total_price=0, + currency="USD", + from_source=ConversationFromSource.API, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id=workflow_run_id, + ) + db_session.add(message) + db_session.flush() + return message + + def test_status_count_returns_none_when_no_messages(self, db_session_with_containers: Session) -> None: + """status_count returns None when conversation has no messages with workflow_run_id.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, 
app) + + result = conversation.status_count + + assert result is None + + def test_status_count_returns_none_when_messages_have_no_workflow_run_id( + self, db_session_with_containers: Session + ) -> None: + """status_count returns None when messages exist but none have workflow_run_id.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=None) + + result = conversation.status_count + + assert result is None + + def test_status_count_counts_succeeded_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts succeeded workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 1 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_counts_failed_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts failed workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.FAILED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 0 + assert result["failed"] == 1 + assert result["partial_success"] == 0 + assert result["paused"] == 0 + + def test_status_count_counts_paused_workflow_run(self, db_session_with_containers: Session) -> None: + """status_count correctly counts paused workflow runs.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + run = self._create_workflow_run( + db_session_with_containers, app, workflow, WorkflowExecutionStatus.PAUSED, created_by + ) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 0 + assert result["failed"] == 0 + assert result["partial_success"] == 0 + assert result["paused"] == 1 + + def test_status_count_multiple_statuses(self, db_session_with_containers: Session) -> None: + """status_count counts multiple workflow runs with different statuses.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = 
self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, app, created_by) + + for status in [ + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + WorkflowExecutionStatus.PAUSED, + ]: + run = self._create_workflow_run(db_session_with_containers, app, workflow, status, created_by) + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=run.id) + + result = conversation.status_count + + assert result is not None + assert result["success"] == 1 + assert result["failed"] == 1 + assert result["partial_success"] == 1 + assert result["paused"] == 1 + + def test_status_count_filters_workflow_runs_by_app_id(self, db_session_with_containers: Session) -> None: + """status_count excludes workflow runs belonging to a different app.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + app = self._create_app(db_session_with_containers, tenant_id, created_by) + other_app = self._create_app(db_session_with_containers, tenant_id, created_by) + conversation = self._create_conversation(db_session_with_containers, app) + workflow = self._create_workflow(db_session_with_containers, other_app, created_by) + + # Workflow run belongs to other_app, not app + other_run = self._create_workflow_run( + db_session_with_containers, other_app, workflow, WorkflowExecutionStatus.SUCCEEDED, created_by + ) + # Message references that run but is in a conversation under app + self._create_message(db_session_with_containers, app, conversation, workflow_run_id=other_run.id) + + result = conversation.status_count + + # The run should be excluded because app_id filter doesn't match + assert result is not None + assert result["success"] == 0 + + +class TestSiteGenerateCode: + """Integration tests for Site.generate_code static method.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def test_generate_code_returns_string_of_correct_length(self, db_session_with_containers: Session) -> None: + """Site.generate_code returns a code string of the requested length.""" + code = Site.generate_code(8) + + assert isinstance(code, str) + assert len(code) == 8 + + def test_generate_code_avoids_duplicates(self, db_session_with_containers: Session) -> None: + """Site.generate_code returns a code not already in use.""" + tenant_id = str(uuid4()) + app = App( + tenant_id=tenant_id, + name="Test App", + mode=AppMode.CHAT, + enable_site=True, + enable_api=False, + is_demo=False, + is_public=False, + is_universal=False, + created_by=str(uuid4()), + updated_by=str(uuid4()), + ) + db_session_with_containers.add(app) + db_session_with_containers.flush() + + site = Site( + app_id=app.id, + title="Test Site", + default_language="en-US", + customize_token_strategy="not_allow", + ) + # Set an explicit code so generate_code must avoid it + site.code = "AAAAAAAA" + db_session_with_containers.add(site) + db_session_with_containers.flush() + + code = Site.generate_code(8) + + assert isinstance(code, str) + assert len(code) == 8 + assert code != site.code diff --git a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py 
index 8aec6b6acc..b325c97f7d 100644 --- a/api/tests/test_containers_integration_tests/models/test_types_enum_text.py +++ b/api/tests/test_containers_integration_tests/models/test_types_enum_text.py @@ -4,13 +4,13 @@ from typing import Any, NamedTuple import pytest import sqlalchemy as sa -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import exc as sa_exc -from sqlalchemy import insert +from sqlalchemy import insert, select from sqlalchemy.engine import Connection, Engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column from sqlalchemy.sql.sqltypes import VARCHAR +from graphon.model_runtime.entities.model_entities import ModelType from models.types import EnumText _USER_TABLE = "enum_text_users" @@ -137,12 +137,12 @@ class TestEnumText: session.commit() with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == admin_user_id).first() + user = session.scalar(select(_User).where(_User.id == admin_user_id).limit(1)) assert user.user_type == _UserType.admin assert user.user_type_nullable is None with Session(engine_with_containers) as session: - user = session.query(_User).where(_User.id == normal_user_id).first() + user = session.scalar(select(_User).where(_User.id == normal_user_id).limit(1)) assert user.user_type == _UserType.normal assert user.user_type_nullable == _UserType.normal @@ -206,7 +206,7 @@ class TestEnumText: with pytest.raises(ValueError) as exc: with Session(engine_with_containers) as session: - _user = session.query(_User).where(_User.id == 1).first() + _user = session.scalar(select(_User).where(_User.id == 1).limit(1)) assert str(exc.value) == "'invalid' is not a valid _UserType" @@ -222,7 +222,7 @@ class TestEnumText: session.commit() with Session(engine_with_containers) as session: - records = session.query(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id).all() + records = session.scalars(select(_LegacyModelTypeRecord).order_by(_LegacyModelTypeRecord.id)).all() assert [record.model_type for record in records] == [ ModelType.LLM, diff --git a/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py b/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py new file mode 100644 index 0000000000..14c2263110 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_workflow_node_execution_model.py @@ -0,0 +1,170 @@ +""" +Integration tests for WorkflowNodeExecutionModel.created_by_account and .created_by_end_user. + +Migrated from unit_tests/models/test_workflow_trigger_log.py, replacing +monkeypatch.setattr(db.session, "scalar", ...) with real Account/EndUser rows +persisted in PostgreSQL so the db.session.get() call executes against the DB. 
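+
+Role-gated lookup contract exercised below (a sketch of the assertions, not of
+the implementation):
+
+    created_by_role == CreatorUserRole.ACCOUNT   -> created_by_account is the Account, created_by_end_user is None
+    created_by_role == CreatorUserRole.END_USER  -> created_by_end_user is the EndUser, created_by_account is None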
+""" + +from collections.abc import Generator +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account +from models.enums import CreatorUserRole +from models.model import App, AppMode, EndUser +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom + + +class TestWorkflowNodeExecutionModelCreatedBy: + """Integration tests for WorkflowNodeExecutionModel creator lookup properties.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Automatically rollback session changes after each test.""" + yield + db_session_with_containers.rollback() + + def _create_account(self, db_session: Session) -> Account: + account = Account( + name="Test Account", + email=f"test_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + return account + + def _create_end_user(self, db_session: Session, tenant_id: str, app_id: str) -> EndUser: + end_user = EndUser( + tenant_id=tenant_id, + app_id=app_id, + type="service_api", + external_user_id=f"ext-{uuid4()}", + name="End User", + session_id=f"session-{uuid4()}", + ) + end_user.is_anonymous = False + db_session.add(end_user) + db_session.flush() + return end_user + + def _create_app(self, db_session: Session, tenant_id: str, created_by: str) -> App: + app = App( + tenant_id=tenant_id, + name=f"App {uuid4()}", + mode=AppMode.WORKFLOW, + enable_site=False, + enable_api=True, + is_demo=False, + is_public=False, + is_universal=False, + created_by=created_by, + updated_by=created_by, + ) + db_session.add(app) + db_session.flush() + return app + + def _make_execution( + self, tenant_id: str, app_id: str, created_by_role: str, created_by: str + ) -> WorkflowNodeExecutionModel: + return WorkflowNodeExecutionModel( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=str(uuid4()), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=None, + index=1, + predecessor_node_id=None, + node_execution_id=None, + node_id="n1", + node_type="start", + title="Start", + inputs=None, + process_data=None, + outputs=None, + status="succeeded", + error=None, + elapsed_time=0.0, + execution_metadata=None, + created_by_role=created_by_role, + created_by=created_by, + ) + + def test_created_by_account_returns_account_when_role_is_account(self, db_session_with_containers: Session) -> None: + """created_by_account returns the Account row when role is ACCOUNT.""" + account = self._create_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, str(uuid4()), account.id) + + execution = self._make_execution( + tenant_id=app.tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=account.id, + ) + + result = execution.created_by_account + + assert result is not None + assert result.id == account.id + + def test_created_by_account_returns_none_when_role_is_end_user(self, db_session_with_containers: Session) -> None: + """created_by_account returns None when role is END_USER, even if an Account exists.""" + account = self._create_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, str(uuid4()), account.id) + + execution = self._make_execution( + tenant_id=app.tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.END_USER.value, + created_by=account.id, + ) + + 
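+        # Role gate under test: created_by_role is END_USER, so created_by_account
+        # is expected to be None even though created_by points at a real Account row.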
result = execution.created_by_account + + assert result is None + + def test_created_by_end_user_returns_end_user_when_role_is_end_user( + self, db_session_with_containers: Session + ) -> None: + """created_by_end_user returns the EndUser row when role is END_USER.""" + account = self._create_account(db_session_with_containers) + tenant_id = str(uuid4()) + app = self._create_app(db_session_with_containers, tenant_id, account.id) + end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id) + + execution = self._make_execution( + tenant_id=tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.END_USER.value, + created_by=end_user.id, + ) + + result = execution.created_by_end_user + + assert result is not None + assert result.id == end_user.id + + def test_created_by_end_user_returns_none_when_role_is_account(self, db_session_with_containers: Session) -> None: + """created_by_end_user returns None when role is ACCOUNT, even if an EndUser exists.""" + account = self._create_account(db_session_with_containers) + tenant_id = str(uuid4()) + app = self._create_app(db_session_with_containers, tenant_id, account.id) + end_user = self._create_end_user(db_session_with_containers, tenant_id, app.id) + + execution = self._make_execution( + tenant_id=tenant_id, + app_id=app.id, + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by=end_user.id, + ) + + result = execution.created_by_end_user + + assert result is None diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index a68b3a08c7..641399c7f9 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index d28cfda159..d9828e19c5 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,27 +2,31 @@ from __future__ import annotations +import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock from uuid import uuid4 import pytest +from sqlalchemy import Engine, delete, select +from sqlalchemy.orm import Session, sessionmaker + +from core.workflow.human_input_adapter import DeliveryMethodType +from extensions.ext_storage import storage from graphon.entities import WorkflowExecution from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType from graphon.enums import WorkflowExecutionStatus from graphon.nodes.human_input.entities 
import FormDefinition, FormInput, UserAction from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from sqlalchemy import Engine, delete, select -from sqlalchemy.orm import Session, sessionmaker - -from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( + BackstageRecipientPayload, HumanInputDelivery, HumanInputForm, HumanInputFormRecipient, + RecipientType, ) from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -220,7 +224,6 @@ class TestDeleteRunsWithRelated: created_by=test_scope.user_id, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -280,7 +283,6 @@ class TestCountRunsWithRelated: created_by=test_scope.user_id, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -544,7 +546,6 @@ class TestPrivateWorkflowPauseEntity: status=WorkflowExecutionStatus.RUNNING, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -574,7 +575,6 @@ class TestPrivateWorkflowPauseEntity: ) state_key = f"workflow-state-{uuid4()}.json" pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=state_key, @@ -606,7 +606,6 @@ class TestPrivateWorkflowPauseEntity: ) state_key = f"workflow-state-{uuid4()}.json" pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=state_key, @@ -633,12 +632,12 @@ class TestPrivateWorkflowPauseEntity: class TestBuildHumanInputRequiredReason: """Integration tests for _build_human_input_required_reason using real DB models.""" - def test_builds_reason_from_form_definition( + def test_prefers_standalone_web_app_token_when_available( self, db_session_with_containers: Session, test_scope: _TestScope, ) -> None: - """Build the graph pause reason from the stored form definition.""" + """Use the public standalone web-app token for service API payloads.""" expiration_time = naive_utc_now() form_definition = FormDefinition( @@ -665,6 +664,40 @@ class TestBuildHumanInputRequiredReason: db_session_with_containers.add(form_model) db_session_with_containers.flush() + delivery = HumanInputDelivery( + form_id=form_model.id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload="{}", + ) + db_session_with_containers.add(delivery) + db_session_with_containers.flush() + + backstage_access_token = secrets.token_urlsafe(8) + backstage_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=backstage_access_token, + ) + console_access_token = secrets.token_urlsafe(8) + console_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.CONSOLE, + recipient_payload="{}", + access_token=console_access_token, + ) + web_app_access_token = secrets.token_urlsafe(8) + web_app_recipient = HumanInputFormRecipient( 
+ form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + recipient_payload="{}", + access_token=web_app_access_token, + ) + db_session_with_containers.add_all([backstage_recipient, console_recipient, web_app_recipient]) + db_session_with_containers.flush() # Create a pause so the reason has a valid pause_id workflow_run = _create_workflow_run( db_session_with_containers, @@ -672,7 +705,6 @@ class TestBuildHumanInputRequiredReason: status=WorkflowExecutionStatus.RUNNING, ) pause = WorkflowPause( - id=str(uuid4()), workflow_id=test_scope.workflow_id, workflow_run_id=workflow_run.id, state_object_key=f"workflow-state-{uuid4()}.json", @@ -694,8 +726,15 @@ class TestBuildHumanInputRequiredReason: # Refresh to ensure we have DB-round-tripped objects db_session_with_containers.refresh(form_model) db_session_with_containers.refresh(reason_model) + db_session_with_containers.refresh(backstage_recipient) + db_session_with_containers.refresh(console_recipient) + db_session_with_containers.refresh(web_app_recipient) - reason = _build_human_input_required_reason(reason_model, form_model) + reason = _build_human_input_required_reason( + reason_model, + form_model, + [backstage_recipient, console_recipient, web_app_recipient], + ) assert isinstance(reason, HumanInputRequired) assert reason.node_title == "Ask Name" @@ -703,3 +742,92 @@ class TestBuildHumanInputRequiredReason: assert reason.inputs[0].output_variable_name == "name" assert reason.actions[0].id == "approve" assert reason.resolved_default_values == {"name": "Alice"} + assert not hasattr(reason, "form_token") + + def test_falls_back_to_console_token_when_web_app_token_missing( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Use the console token only when no standalone web-app token exists.""" + + expiration_time = naive_utc_now() + form_definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + + form_model = HumanInputForm( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + workflow_run_id=str(uuid4()), + node_id="node-1", + form_definition=form_definition.model_dump_json(), + rendered_content="rendered", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + db_session_with_containers.add(form_model) + db_session_with_containers.flush() + + delivery = HumanInputDelivery( + form_id=form_model.id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload="{}", + ) + db_session_with_containers.add(delivery) + db_session_with_containers.flush() + + backstage_access_token = secrets.token_urlsafe(8) + backstage_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=backstage_access_token, + ) + console_access_token = secrets.token_urlsafe(8) + console_recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.CONSOLE, + recipient_payload="{}", + access_token=console_access_token, + ) + db_session_with_containers.add_all([backstage_recipient, console_recipient]) + db_session_with_containers.flush() + + 
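+        # Deliberately no STANDALONE_WEB_APP recipient is created here, so the
+        # reason builder is expected to fall back to the console recipient token.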
workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.flush() + test_scope.state_keys.add(pause.state_object_key) + + reason_model = WorkflowPauseReason( + pause_id=pause.id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=form_model.id, + node_id="node-1", + message="", + ) + db_session_with_containers.add(reason_model) + db_session_with_containers.commit() + + reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient, console_recipient]) + + assert isinstance(reason, HumanInputRequired) + assert not hasattr(reason, "form_token") diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index 7f44eb6ca3..54b7afc018 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -12,11 +12,11 @@ from decimal import Decimal from uuid import uuid4 import pytest -from graphon.nodes.human_input.entities import FormDefinition, UserAction -from graphon.nodes.human_input.enums import HumanInputFormStatus from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ConversationFromSource, InvokeFrom @@ -271,7 +271,7 @@ def _create_recipient( def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: - from core.workflow.human_input_compat import DeliveryMethodType + from core.workflow.human_input_adapter import DeliveryMethodType from models.human_input import ConsoleDeliveryPayload delivery = HumanInputDelivery( diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 0000000000..fa78f1c28b --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,395 @@ +"""Testcontainers integration tests for SQLAlchemyWorkflowNodeExecutionRepository.""" + +from __future__ import annotations + +import json +from datetime import datetime +from decimal import Decimal +from uuid import uuid4 + +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker + +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.model_runtime.utils.encoders import jsonable_encoder +from models.account import Account, Tenant +from models.enums import 
CreatorUserRole +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom + + +def _create_account_with_tenant(session: Session) -> Account: + tenant = Tenant(name="Test Workspace") + session.add(tenant) + session.flush() + + account = Account(name="test", email=f"test-{uuid4()}@example.com") + session.add(account) + session.flush() + + account._current_tenant = tenant + return account + + +def _make_repo(session: Session, account: Account, app_id: str) -> SQLAlchemyWorkflowNodeExecutionRepository: + engine = session.get_bind() + assert isinstance(engine, Engine) + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=sessionmaker(bind=engine, expire_on_commit=False), + user=account, + app_id=app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def _create_node_execution_model( + session: Session, + *, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_run_id: str, + index: int = 1, + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING, +) -> WorkflowNodeExecutionModel: + model = WorkflowNodeExecutionModel( + id=str(uuid4()), + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=workflow_run_id, + index=index, + predecessor_node_id=None, + node_execution_id=str(uuid4()), + node_id=f"node-{index}", + node_type=BuiltinNodeTypes.START, + title=f"Test Node {index}", + inputs='{"input_key": "input_value"}', + process_data='{"process_key": "process_value"}', + outputs='{"output_key": "output_value"}', + status=status, + error=None, + elapsed_time=1.5, + execution_metadata="{}", + created_at=datetime.now(), + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + finished_at=None, + ) + session.add(model) + session.flush() + return model + + +class TestSave: + def test_save_new_record(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + predecessor_node_id=None, + node_id="node-1", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs={"input_key": "input_value"}, + process_data={"process_key": "process_value"}, + outputs={"result": "success"}, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=1.5, + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100}, + created_at=datetime.now(), + finished_at=None, + ) + + repo.save(execution) + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session: + saved = verify_session.get(WorkflowNodeExecutionModel, execution.id) + assert saved is not None + assert saved.tenant_id == account.current_tenant_id + assert saved.app_id == app_id + assert saved.node_id == "node-1" + assert saved.status == WorkflowNodeExecutionStatus.RUNNING + + def test_save_updates_existing_record(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + repo = _make_repo(db_session_with_containers, account, str(uuid4())) + + execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=str(uuid4()), + 
node_execution_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + predecessor_node_id=None, + node_id="node-1", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs=None, + process_data=None, + outputs=None, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=0.0, + metadata=None, + created_at=datetime.now(), + finished_at=None, + ) + + repo.save(execution) + + execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + execution.elapsed_time = 2.5 + repo.save(execution) + + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + with sessionmaker(bind=engine, expire_on_commit=False)() as verify_session: + saved = verify_session.get(WorkflowNodeExecutionModel, execution.id) + assert saved is not None + assert saved.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert saved.elapsed_time == 2.5 + + +class TestGetByWorkflowExecution: + def test_returns_executions_ordered(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + tenant_id = account.current_tenant_id + app_id = str(uuid4()) + workflow_id = str(uuid4()) + workflow_run_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=2, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + ) + db_session_with_containers.commit() + + order_config = OrderConfig(order_by=["index"], order_direction="desc") + result = repo.get_by_workflow_execution( + workflow_execution_id=workflow_run_id, + order_config=order_config, + ) + + assert len(result) == 2 + assert result[0].index == 2 + assert result[1].index == 1 + assert all(isinstance(r, WorkflowNodeExecution) for r in result) + + def test_excludes_paused_executions(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + tenant_id = account.current_tenant_id + app_id = str(uuid4()) + workflow_id = str(uuid4()) + workflow_run_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=1, + status=WorkflowNodeExecutionStatus.RUNNING, + ) + _create_node_execution_model( + db_session_with_containers, + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow_id, + workflow_run_id=workflow_run_id, + index=2, + status=WorkflowNodeExecutionStatus.PAUSED, + ) + db_session_with_containers.commit() + + result = repo.get_by_workflow_execution(workflow_execution_id=workflow_run_id) + + assert len(result) == 1 + assert result[0].index == 1 + + +class TestToDbModel: + def test_converts_domain_to_db_model(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + domain_model = WorkflowNodeExecution( + id="test-id", + workflow_id="test-workflow-id", + node_execution_id="test-node-execution-id", + 
workflow_execution_id="test-workflow-run-id", + index=1, + predecessor_node_id="test-predecessor-id", + node_id="test-node-id", + node_type=BuiltinNodeTypes.START, + title="Test Node", + inputs={"input_key": "input_value"}, + process_data={"process_key": "process_value"}, + outputs={"output_key": "output_value"}, + status=WorkflowNodeExecutionStatus.RUNNING, + error=None, + elapsed_time=1.5, + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 100, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: Decimal("0.0"), + }, + created_at=datetime.now(), + finished_at=None, + ) + + db_model = repo._to_db_model(domain_model) + + assert isinstance(db_model, WorkflowNodeExecutionModel) + assert db_model.id == domain_model.id + assert db_model.tenant_id == account.current_tenant_id + assert db_model.app_id == app_id + assert db_model.workflow_id == domain_model.workflow_id + assert db_model.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + assert db_model.workflow_run_id == domain_model.workflow_execution_id + assert db_model.index == domain_model.index + assert db_model.predecessor_node_id == domain_model.predecessor_node_id + assert db_model.node_execution_id == domain_model.node_execution_id + assert db_model.node_id == domain_model.node_id + assert db_model.node_type == domain_model.node_type + assert db_model.title == domain_model.title + assert db_model.inputs_dict == domain_model.inputs + assert db_model.process_data_dict == domain_model.process_data + assert db_model.outputs_dict == domain_model.outputs + assert db_model.execution_metadata_dict == jsonable_encoder(domain_model.metadata) + assert db_model.status == domain_model.status + assert db_model.error == domain_model.error + assert db_model.elapsed_time == domain_model.elapsed_time + assert db_model.created_at == domain_model.created_at + assert db_model.created_by_role == CreatorUserRole.ACCOUNT + assert db_model.created_by == account.id + assert db_model.finished_at == domain_model.finished_at + + +class TestToDomainModel: + def test_converts_db_to_domain_model(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + repo = _make_repo(db_session_with_containers, account, app_id) + + inputs_dict = {"input_key": "input_value"} + process_data_dict = {"process_key": "process_value"} + outputs_dict = {"output_key": "output_value"} + metadata_dict = {str(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS): 100} + now = datetime.now() + + db_model = WorkflowNodeExecutionModel() + db_model.id = "test-id" + db_model.tenant_id = account.current_tenant_id + db_model.app_id = app_id + db_model.workflow_id = "test-workflow-id" + db_model.triggered_from = "workflow-run" + db_model.workflow_run_id = "test-workflow-run-id" + db_model.index = 1 + db_model.predecessor_node_id = "test-predecessor-id" + db_model.node_execution_id = "test-node-execution-id" + db_model.node_id = "test-node-id" + db_model.node_type = BuiltinNodeTypes.START + db_model.title = "Test Node" + db_model.inputs = json.dumps(inputs_dict) + db_model.process_data = json.dumps(process_data_dict) + db_model.outputs = json.dumps(outputs_dict) + db_model.status = WorkflowNodeExecutionStatus.RUNNING + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.execution_metadata = json.dumps(metadata_dict) + db_model.created_at = now + db_model.created_by_role = "account" + db_model.created_by = account.id + db_model.finished_at = None + + domain_model = 
repo._to_domain_model(db_model) + + assert isinstance(domain_model, WorkflowNodeExecution) + assert domain_model.id == "test-id" + assert domain_model.workflow_id == "test-workflow-id" + assert domain_model.workflow_execution_id == "test-workflow-run-id" + assert domain_model.index == 1 + assert domain_model.predecessor_node_id == "test-predecessor-id" + assert domain_model.node_execution_id == "test-node-execution-id" + assert domain_model.node_id == "test-node-id" + assert domain_model.node_type == BuiltinNodeTypes.START + assert domain_model.title == "Test Node" + assert domain_model.inputs == inputs_dict + assert domain_model.process_data == process_data_dict + assert domain_model.outputs == outputs_dict + assert domain_model.status == WorkflowNodeExecutionStatus.RUNNING + assert domain_model.error is None + assert domain_model.elapsed_time == 1.5 + assert domain_model.metadata == {WorkflowNodeExecutionMetadataKey(k): v for k, v in metadata_dict.items()} + assert domain_model.created_at == now + assert domain_model.finished_at is None + + def test_domain_model_without_offload_data(self, db_session_with_containers: Session) -> None: + account = _create_account_with_tenant(db_session_with_containers) + repo = _make_repo(db_session_with_containers, account, str(uuid4())) + + process_data = {"normal": "data"} + db_model = WorkflowNodeExecutionModel() + db_model.id = str(uuid4()) + db_model.tenant_id = account.current_tenant_id + db_model.app_id = str(uuid4()) + db_model.workflow_id = str(uuid4()) + db_model.triggered_from = "workflow-run" + db_model.workflow_run_id = None + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_execution_id = str(uuid4()) + db_model.node_id = "test-node-id" + db_model.node_type = "llm" + db_model.title = "Test Node" + db_model.inputs = None + db_model.process_data = json.dumps(process_data) + db_model.outputs = None + db_model.status = "succeeded" + db_model.error = None + db_model.elapsed_time = 1.5 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now() + db_model.created_by_role = "account" + db_model.created_by = account.id + db_model.finished_at = None + + domain_model = repo._to_domain_model(db_model) + + assert domain_model.process_data == process_data + assert domain_model.process_data_truncated is False + assert domain_model.get_truncated_process_data() is None diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py index c5e9201ee3..d6f0657380 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -7,12 +7,12 @@ from datetime import timedelta from uuid import uuid4 import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy import exc as sa_exc from sqlalchemy.orm import Session, sessionmaker +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py 
b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py index 177fb95ff3..e71079829f 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_service import ApiKeyAuthService @@ -31,7 +32,7 @@ class TestApiKeyAuthService: def mock_args(self, category, provider, mock_credentials) -> dict: return {"category": category, "provider": provider, "credentials": mock_credentials} - def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False): + def _create_binding(self, db_session: Session, *, tenant_id, category, provider, credentials=None, disabled=False): binding = DataSourceApiKeyAuthBinding( tenant_id=tenant_id, category=category, @@ -44,7 +45,7 @@ class TestApiKeyAuthService: return binding def test_get_provider_auth_list_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) db_session_with_containers.expire_all() @@ -56,14 +57,16 @@ class TestApiKeyAuthService: assert len(tenant_results) == 1 assert tenant_results[0].provider == provider - def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id): + def test_get_provider_auth_list_empty( + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id + ): result = ApiKeyAuthService.get_provider_auth_list(tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] def test_get_provider_auth_list_filters_disabled( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): self._create_binding( db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True @@ -78,7 +81,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_success( - self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, + mock_encrypter, + mock_factory, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = True @@ -97,7 +106,7 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") def test_create_provider_auth_validation_failed( - self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, mock_factory, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_args ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = False @@ -112,7 +121,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") 
@patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_encrypts_api_key( - self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, + mock_encrypter, + mock_factory, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = True @@ -128,7 +143,13 @@ class TestApiKeyAuthService: mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key) def test_get_auth_credentials_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials + self, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + category, + provider, + mock_credentials, ): self._create_binding( db_session_with_containers, @@ -144,14 +165,14 @@ class TestApiKeyAuthService: assert result == mock_credentials def test_get_auth_credentials_not_found( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) assert result is None def test_get_auth_credentials_json_parsing( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} self._create_binding( @@ -169,7 +190,7 @@ class TestApiKeyAuthService: assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" def test_delete_provider_auth_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): binding = self._create_binding( db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider @@ -183,7 +204,9 @@ class TestApiKeyAuthService: remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() assert remaining is None - def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id): + def test_delete_provider_auth_not_found( + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id + ): # Should not raise when binding not found ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index f48c6da690..e78fa27976 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -10,6 +10,7 @@ from uuid import uuid4 import httpx import pytest +from sqlalchemy.orm import Session from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory @@ -114,7 +115,7 @@ class TestAuthIntegration: assert result2[0].tenant_id == tenant_id_2 def test_cross_tenant_access_prevention( - self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category + 
self, flask_app_with_containers, db_session_with_containers: Session, tenant_id_2, category ): result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index 42d587b7f7..327f14ddfe 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -12,6 +12,7 @@ from unittest.mock import create_autospec, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType @@ -273,7 +274,9 @@ class TestDocumentServicePauseDocument: "user_id": user_id, } - def test_pause_document_waiting_state_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_waiting_state_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful pause of document in waiting state. @@ -310,7 +313,7 @@ class TestDocumentServicePauseDocument: mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True") def test_pause_document_indexing_state_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful pause of document in indexing state. @@ -340,7 +343,9 @@ class TestDocumentServicePauseDocument: assert document.is_paused is True assert document.paused_by == mock_document_service_dependencies["user_id"] - def test_pause_document_parsing_state_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_parsing_state_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful pause of document in parsing state. @@ -367,7 +372,9 @@ class TestDocumentServicePauseDocument: db_session_with_containers.refresh(document) assert document.is_paused is True - def test_pause_document_completed_state_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_completed_state_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to pause completed document. @@ -396,7 +403,9 @@ class TestDocumentServicePauseDocument: db_session_with_containers.refresh(document) assert document.is_paused is False - def test_pause_document_error_state_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_error_state_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to pause document in error state. @@ -467,7 +476,9 @@ class TestDocumentServiceRecoverDocument: "recover_task": mock_task, } - def test_recover_document_paused_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_recover_document_paused_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful recovery of paused document. 
@@ -510,7 +521,9 @@ class TestDocumentServiceRecoverDocument: document.dataset_id, document.id ) - def test_recover_document_not_paused_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_recover_document_not_paused_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to recover non-paused document. @@ -590,7 +603,9 @@ class TestDocumentServiceRetryDocument: "user_id": user_id, } - def test_retry_document_single_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_retry_document_single_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful retry of single document. @@ -629,7 +644,9 @@ class TestDocumentServiceRetryDocument: dataset.id, [document.id], mock_document_service_dependencies["user_id"] ) - def test_retry_document_multiple_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_retry_document_multiple_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful retry of multiple documents. @@ -675,7 +692,7 @@ class TestDocumentServiceRetryDocument: ) def test_retry_document_concurrent_retry_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when document is already being retried. @@ -708,7 +725,7 @@ class TestDocumentServiceRetryDocument: assert document.indexing_status == IndexingStatus.ERROR def test_retry_document_missing_current_user_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when current_user is missing. @@ -794,7 +811,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: } def test_batch_update_document_status_enable_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch enabling of documents. @@ -844,7 +861,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: assert mock_document_service_dependencies["add_task"].delay.call_count == 2 def test_batch_update_document_status_disable_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch disabling of documents. @@ -886,7 +903,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_archive_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch archiving of documents. @@ -928,7 +945,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_unarchive_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch unarchiving of documents. 
@@ -970,7 +987,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_empty_list( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test handling of empty document list. @@ -996,7 +1013,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_not_called() def test_batch_update_document_status_document_indexing_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when document is being indexed. @@ -1073,7 +1090,7 @@ class TestDocumentServiceRenameDocument: "current_user": mock_current_user, } - def test_rename_document_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_success(self, db_session_with_containers: Session, mock_document_service_dependencies): """ Test successful document renaming. @@ -1111,7 +1128,9 @@ class TestDocumentServiceRenameDocument: assert result == document assert document.name == new_name - def test_rename_document_with_built_in_fields(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_with_built_in_fields( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test document renaming with built-in fields enabled. @@ -1154,7 +1173,9 @@ class TestDocumentServiceRenameDocument: assert document.doc_metadata["document_name"] == new_name assert document.doc_metadata["existing_key"] == "existing_value" - def test_rename_document_with_upload_file(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_with_upload_file( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test document renaming with associated upload file. @@ -1202,7 +1223,7 @@ class TestDocumentServiceRenameDocument: assert upload_file.name == new_name def test_rename_document_dataset_not_found_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when dataset is not found. @@ -1224,7 +1245,9 @@ class TestDocumentServiceRenameDocument: with pytest.raises(ValueError, match="Dataset not found"): DocumentService.rename_document(dataset_id, document_id, new_name) - def test_rename_document_not_found_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_not_found_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when document is not found. @@ -1251,7 +1274,9 @@ class TestDocumentServiceRenameDocument: with pytest.raises(ValueError, match="Document not found"): DocumentService.rename_document(dataset.id, document_id, new_name) - def test_rename_document_permission_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_permission_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when user lacks permission. 
diff --git a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py index 4e8255d8ed..e73c2afe7f 100644 --- a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py +++ b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from redis import RedisError +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from models.account import TenantAccountJoin @@ -122,7 +123,7 @@ class TestSyncAccountDeletion: mock_queue_task.assert_not_called() def test_sync_account_deletion_multiple_workspaces( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_ids = [str(uuid4()) for _ in range(3)] @@ -144,7 +145,7 @@ class TestSyncAccountDeletion: assert queued_workspace_ids == set(tenant_ids) def test_sync_account_deletion_no_workspaces( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: mock_config.ENTERPRISE_ENABLED = True @@ -155,7 +156,7 @@ class TestSyncAccountDeletion: mock_queue_task.assert_not_called() def test_sync_account_deletion_partial_failure( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_ids = [str(uuid4()) for _ in range(3)] @@ -180,7 +181,7 @@ class TestSyncAccountDeletion: assert mock_queue_task.call_count == 3 def test_sync_account_deletion_all_failures( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_id = str(uuid4()) diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py new file mode 100644 index 0000000000..49d06986fd --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from uuid import uuid4 + +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from models.account import TenantPluginPermission +from services.plugin.plugin_permission_service import PluginPermissionService + + +def _tenant_id() -> str: + return str(uuid4()) + + +def _get_permission(session: Session, tenant_id: str) -> TenantPluginPermission | None: + session.expire_all() + stmt = select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id) + return session.scalars(stmt).one_or_none() + + +def _count_permissions(session: Session, tenant_id: str) -> int: + stmt = select(func.count()).select_from(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id) + return session.scalar(stmt) or 0 + + +class TestGetPermission: + """Integration tests for PluginPermissionService.get_permission using testcontainers.""" + + def test_returns_permission_when_found(self, 
db_session_with_containers: Session): + tenant_id = _tenant_id() + permission = TenantPluginPermission( + tenant_id=tenant_id, + install_permission=TenantPluginPermission.InstallPermission.ADMINS, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + db_session_with_containers.add(permission) + db_session_with_containers.commit() + + result = PluginPermissionService.get_permission(tenant_id) + + assert result is not None + assert result.id == permission.id + assert result.tenant_id == tenant_id + assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS + assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE + + def test_returns_none_when_not_found(self, db_session_with_containers: Session): + result = PluginPermissionService.get_permission(_tenant_id()) + + assert result is None + + +class TestChangePermission: + """Integration tests for PluginPermissionService.change_permission using testcontainers.""" + + def test_creates_new_permission_when_not_exists(self, db_session_with_containers: Session): + tenant_id = _tenant_id() + + result = PluginPermissionService.change_permission( + tenant_id, + TenantPluginPermission.InstallPermission.EVERYONE, + TenantPluginPermission.DebugPermission.EVERYONE, + ) + + permission = _get_permission(db_session_with_containers, tenant_id) + assert result is True + assert permission is not None + assert permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE + assert permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE + + def test_updates_existing_permission(self, db_session_with_containers: Session): + tenant_id = _tenant_id() + existing = TenantPluginPermission( + tenant_id=tenant_id, + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + db_session_with_containers.add(existing) + db_session_with_containers.commit() + + result = PluginPermissionService.change_permission( + tenant_id, + TenantPluginPermission.InstallPermission.ADMINS, + TenantPluginPermission.DebugPermission.ADMINS, + ) + + permission = _get_permission(db_session_with_containers, tenant_id) + assert result is True + assert permission is not None + assert permission.id == existing.id + assert permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS + assert permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS + assert _count_permissions(db_session_with_containers, tenant_id) == 1 diff --git a/api/tests/integration_tests/vdb/iris/__init__.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/__init__.py similarity index 100% rename from api/tests/integration_tests/vdb/iris/__init__.py rename to api/tests/test_containers_integration_tests/services/rag_pipeline/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py new file mode 100644 index 0000000000..8fc1809a46 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py @@ -0,0 +1,255 @@ +""" +Integration tests for RagPipelineService methods that interact with the database. + +Migrated from unit_tests/services/rag_pipeline/test_rag_pipeline_service.py, replacing +db.session.scalar/commit/delete mocker patches with real PostgreSQL operations. 
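+Each test class rolls back the shared session after every test via an autouse fixture.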
+ +Covers: +- get_pipeline: Dataset and Pipeline lookups +- update_customized_pipeline_template: find + unique-name check + commit +- delete_customized_pipeline_template: find + delete + commit +""" + +from collections.abc import Generator +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate +from models.enums import DataSourceType +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity +from services.rag_pipeline.rag_pipeline import RagPipelineService + + +class TestRagPipelineServiceGetPipeline: + """Integration tests for RagPipelineService.get_pipeline.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _make_service(self, flask_app_with_containers) -> RagPipelineService: + with ( + patch( + "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository", + return_value=None, + ), + patch( + "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=None, + ), + ): + session_factory = sessionmaker(bind=flask_app_with_containers.extensions["sqlalchemy"].engine) + return RagPipelineService(session_maker=session_factory) + + def _create_pipeline(self, db_session: Session, tenant_id: str, created_by: str) -> Pipeline: + pipeline = Pipeline( + tenant_id=tenant_id, + name=f"Pipeline {uuid4()}", + description="", + created_by=created_by, + ) + db_session.add(pipeline) + db_session.flush() + return pipeline + + def _create_dataset( + self, db_session: Session, tenant_id: str, created_by: str, pipeline_id: str | None = None + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=f"Dataset {uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by, + pipeline_id=pipeline_id, + ) + db_session.add(dataset) + db_session.flush() + return dataset + + def test_get_pipeline_raises_when_dataset_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline raises ValueError when dataset does not exist.""" + service = self._make_service(flask_app_with_containers) + + with pytest.raises(ValueError, match="Dataset not found"): + service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4())) + + def test_get_pipeline_raises_when_pipeline_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline raises ValueError when dataset exists but has no linked pipeline.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=None) + db_session_with_containers.flush() + + service = self._make_service(flask_app_with_containers) + + with pytest.raises(ValueError, match="(Dataset not found|Pipeline not found)"): + service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) + + def test_get_pipeline_returns_pipeline_when_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """get_pipeline returns the Pipeline when both Dataset and Pipeline exist.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + + pipeline = self._create_pipeline(db_session_with_containers, tenant_id, 
created_by) + dataset = self._create_dataset(db_session_with_containers, tenant_id, created_by, pipeline_id=pipeline.id) + db_session_with_containers.flush() + + service = self._make_service(flask_app_with_containers) + + result = service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) + + assert result.id == pipeline.id + + +class TestUpdateCustomizedPipelineTemplate: + """Integration tests for RagPipelineService.update_customized_pipeline_template.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_template( + self, db_session: Session, tenant_id: str, created_by: str, name: str = "Template" + ) -> PipelineCustomizedTemplate: + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name=name, + description="Original description", + chunk_structure="fixed_size", + icon={"type": "emoji", "value": "📄"}, + position=1, + yaml_content="{}", + install_count=0, + language="en-US", + created_by=created_by, + ) + db_session.add(template) + db_session.flush() + return template + + def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + """update_customized_pipeline_template updates name and description.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template = self._create_template(db_session_with_containers, tenant_id, created_by) + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="Updated Name", + description="Updated description", + icon_info=IconInfo(icon="🔥"), + ) + result = RagPipelineService.update_customized_pipeline_template(template.id, info) + + assert result.name == "Updated Name" + assert result.description == "Updated description" + + def test_update_template_raises_when_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """update_customized_pipeline_template raises ValueError when template doesn't exist.""" + fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="New Name", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.update_customized_pipeline_template(str(uuid4()), info) + + def test_update_template_raises_on_duplicate_name( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """update_customized_pipeline_template raises ValueError when new name already exists.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template1 = self._create_template(db_session_with_containers, tenant_id, created_by, name="Original") + self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate") + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + info = PipelineTemplateInfoEntity( + name="Duplicate", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Template name is already exists"): + 
RagPipelineService.update_customized_pipeline_template(template1.id, info) + + +class TestDeleteCustomizedPipelineTemplate: + """Integration tests for RagPipelineService.delete_customized_pipeline_template.""" + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_template(self, db_session: Session, tenant_id: str, created_by: str) -> PipelineCustomizedTemplate: + template = PipelineCustomizedTemplate( + tenant_id=tenant_id, + name=f"Template {uuid4()}", + description="Description", + chunk_structure="fixed_size", + icon={"type": "emoji", "value": "📄"}, + position=1, + yaml_content="{}", + install_count=0, + language="en-US", + created_by=created_by, + ) + db_session.add(template) + db_session.flush() + return template + + def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + """delete_customized_pipeline_template removes the template from the DB.""" + tenant_id = str(uuid4()) + created_by = str(uuid4()) + template = self._create_template(db_session_with_containers, tenant_id, created_by) + template_id = template.id + db_session_with_containers.flush() + + fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + RagPipelineService.delete_customized_pipeline_template(template_id) + + # Verify the record is deleted within the same context + from sqlalchemy import select + + from extensions.ext_database import db as ext_db + + remaining = ext_db.session.scalar( + select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id) + ) + assert remaining is None + + def test_delete_template_raises_when_not_found( + self, db_session_with_containers: Session, flask_app_with_containers + ) -> None: + """delete_customized_pipeline_template raises ValueError when template doesn't exist.""" + fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + + with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.delete_customized_pipeline_template(str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py index 2b842629a7..724dd19f92 100644 --- a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py @@ -3,6 +3,8 @@ from __future__ import annotations from unittest.mock import patch from uuid import uuid4 +from sqlalchemy.orm import Session + from models.model import App, RecommendedApp, Site from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval from services.recommend_app.recommend_app_type import RecommendAppType @@ -91,7 +93,7 @@ class TestDatabaseRecommendAppRetrieval: class TestFetchRecommendedAppsFromDb: - def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers): + def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) 
_create_site(db_session_with_containers, app_id=app1.id) @@ -111,7 +113,9 @@ class TestFetchRecommendedAppsFromDb: assert "assistant" in result["categories"] assert "writing" in result["categories"] - def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers): + def test_falls_back_to_default_language_when_empty( + self, flask_app_with_containers, db_session_with_containers: Session + ): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) @@ -124,7 +128,7 @@ class TestFetchRecommendedAppsFromDb: app_ids = {r["app_id"] for r in result["recommended_apps"]} assert app1.id in app_ids - def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers): + def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) _create_site(db_session_with_containers, app_id=app1.id) @@ -137,7 +141,7 @@ class TestFetchRecommendedAppsFromDb: app_ids = {r["app_id"] for r in result["recommended_apps"]} assert app1.id not in app_ids - def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers): + def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_recommended_app(db_session_with_containers, app_id=app1.id) @@ -151,12 +155,12 @@ class TestFetchRecommendedAppsFromDb: class TestFetchRecommendedAppDetailFromDb: - def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers): + def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers: Session): result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4())) assert result is None - def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers): + def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) _create_recommended_app(db_session_with_containers, app_id=app1.id) @@ -168,7 +172,7 @@ class TestFetchRecommendedAppDetailFromDb: assert result is None @patch("services.recommend_app.database.database_retrieval.AppDslService") - def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers): + def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index cc9596d15f..9a53ff087c 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -9,7 +9,7 @@ from werkzeug.exceptions import Unauthorized from configs import dify_config from controllers.console.error import AccountNotFound, NotAllowedCreateWorkspace -from models import 
AccountStatus, TenantAccountJoin +from models import AccountStatus, TenantAccountJoin, TenantStatus from services.account_service import AccountService, RegisterService, TenantService, TokenPair from services.errors.account import ( AccountAlreadyInTenantError, @@ -2851,7 +2851,7 @@ class TestRegisterService: interface_language="en-US", password=existing_pending_member_password, ) - existing_account.status = "pending" + existing_account.status = AccountStatus.PENDING db_session_with_containers.commit() @@ -2941,7 +2941,7 @@ class TestRegisterService: interface_language="en-US", password=already_in_tenant_password, ) - existing_account.status = "active" + existing_account.status = AccountStatus.ACTIVE db_session_with_containers.commit() @@ -3331,7 +3331,7 @@ class TestRegisterService: TenantService.create_tenant_member(tenant, account, role="normal") # Change tenant status to non-normal - tenant.status = "archive" + tenant.status = TenantStatus.ARCHIVE db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py index 3ec265d009..f78037e503 100644 --- a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -2,6 +2,7 @@ import copy import pytest from faker import Faker +from sqlalchemy.orm import Session from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, @@ -29,7 +30,9 @@ class TestAdvancedPromptTemplateService: # for consistency with other test files return {} - def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_baichuan_model_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful prompt generation for Baichuan model. @@ -64,7 +67,9 @@ class TestAdvancedPromptTemplateService: assert "{{#histories#}}" in prompt_text assert "{{#query#}}" in prompt_text - def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_common_model_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful prompt generation for common models. @@ -100,7 +105,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_prompt_case_insensitive_baichuan_detection( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan model detection is case insensitive. @@ -131,7 +136,7 @@ class TestAdvancedPromptTemplateService: assert BAICHUAN_CONTEXT in prompt_text def test_get_common_prompt_chat_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for chat app with completion mode. 
@@ -161,7 +166,9 @@ class TestAdvancedPromptTemplateService: assert "{{#histories#}}" in prompt_text assert "{{#query#}}" in prompt_text - def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_common_prompt_chat_app_chat_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test common prompt generation for chat app with chat mode. @@ -189,7 +196,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_common_prompt_completion_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for completion app with completion mode. @@ -217,7 +224,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_common_prompt_completion_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for completion app with chat mode. @@ -245,7 +252,9 @@ class TestAdvancedPromptTemplateService: assert CONTEXT in prompt_text assert "{{#pre_prompt#}}" in prompt_text - def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_common_prompt_no_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test common prompt generation without context. @@ -273,7 +282,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_common_prompt_unsupported_app_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation with unsupported app mode. @@ -291,7 +300,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_common_prompt_unsupported_model_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation with unsupported model mode. @@ -308,7 +317,9 @@ class TestAdvancedPromptTemplateService: # Assert: Verify empty dict is returned assert result == {} - def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_completion_prompt_with_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test completion prompt generation with context. @@ -339,7 +350,7 @@ class TestAdvancedPromptTemplateService: assert result_text == CONTEXT + original_text def test_get_completion_prompt_without_context( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test completion prompt generation without context. @@ -368,7 +379,9 @@ class TestAdvancedPromptTemplateService: assert result_text == original_text assert CONTEXT not in result_text - def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_chat_prompt_with_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test chat prompt generation with context. 
@@ -399,7 +412,9 @@ class TestAdvancedPromptTemplateService: assert original_text in result_text assert result_text == CONTEXT + original_text - def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_chat_prompt_without_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test chat prompt generation without context. @@ -429,7 +444,7 @@ class TestAdvancedPromptTemplateService: assert CONTEXT not in result_text def test_get_baichuan_prompt_chat_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for chat app with completion mode. @@ -460,7 +475,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_baichuan_prompt_chat_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for chat app with chat mode. @@ -489,7 +504,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_baichuan_prompt_completion_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for completion app with completion mode. @@ -517,7 +532,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_baichuan_prompt_completion_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for completion app with chat mode. @@ -545,7 +560,9 @@ class TestAdvancedPromptTemplateService: assert BAICHUAN_CONTEXT in prompt_text assert "{{#pre_prompt#}}" in prompt_text - def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_baichuan_prompt_no_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test Baichuan prompt generation without context. @@ -573,7 +590,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_baichuan_prompt_unsupported_app_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation with unsupported app mode. @@ -591,7 +608,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_baichuan_prompt_unsupported_model_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation with unsupported model mode. @@ -609,7 +626,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_prompt_all_app_modes_common_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prompt generation for all app modes with common model. 
@@ -641,7 +658,7 @@ class TestAdvancedPromptTemplateService: assert result != {} def test_get_prompt_all_app_modes_baichuan_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prompt generation for all app modes with Baichuan model. @@ -672,7 +689,7 @@ class TestAdvancedPromptTemplateService: assert result is not None assert result != {} - def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_edge_cases(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test prompt generation with edge cases. @@ -704,7 +721,7 @@ class TestAdvancedPromptTemplateService: # Should either return a valid result or empty dict, but not crash assert result is not None - def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + def test_template_immutability(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test that original templates are not modified. @@ -738,7 +755,9 @@ class TestAdvancedPromptTemplateService: assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG - def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + def test_baichuan_template_immutability( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that original Baichuan templates are not modified. @@ -772,7 +791,9 @@ class TestAdvancedPromptTemplateService: assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies): + def test_context_integration_consistency( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test consistency of context integration across different scenarios. @@ -828,7 +849,7 @@ class TestAdvancedPromptTemplateService: assert prompt_text.startswith(CONTEXT) def test_baichuan_context_integration_consistency( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test consistency of Baichuan context integration across different scenarios. 
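A note on the recurring change in the hunks above: annotating the injected `db_session_with_containers` fixture as `sqlalchemy.orm.Session` lets a type checker validate the session calls inside each test body instead of treating them as `Any`. A minimal sketch of the pattern, with an illustrative test body (the fixture names come from this suite; the body is a placeholder):

```python
from sqlalchemy.orm import Session


def test_example(flask_app_with_containers, db_session_with_containers: Session) -> None:
    # With the annotation, calls such as add(), commit(), and expire_all()
    # are checked against the Session API by the type checker.
    db_session_with_containers.commit()
```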
diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 4f3c0e4200..00a2f9a59f 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -842,7 +842,6 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from graphon.file import FileTransferMethod, FileType - from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 33955d5d84..7c5d2390ba 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -1,56 +1,127 @@ +from __future__ import annotations + +import base64 import json +from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest import yaml from faker import Faker +from flask import Flask +from sqlalchemy.orm import Session -from models.model import App, AppModelConfig +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) +from extensions.ext_redis import redis_client +from graphon.enums import BuiltinNodeTypes +from models import Account, App, AppMode +from models.model import AppModelConfig, IconType +from services import app_dsl_service from services.account_service import AccountService, TenantService -from services.app_dsl_service import AppDslService, ImportMode, ImportStatus +from services.app_dsl_service import ( + CHECK_DEPENDENCIES_REDIS_KEY_PREFIX, + CURRENT_DSL_VERSION, + DSL_MAX_SIZE, + IMPORT_INFO_REDIS_EXPIRY, + IMPORT_INFO_REDIS_KEY_PREFIX, + AppDslService, + CheckDependenciesPendingData, + ImportMode, + ImportStatus, + PendingData, + _check_version_compatibility, +) from services.app_service import AppService from tests.test_containers_integration_tests.helpers import generate_valid_password +_DEFAULT_TENANT_ID = "00000000-0000-0000-0000-000000000001" +_DEFAULT_ACCOUNT_ID = "00000000-0000-0000-0000-000000000002" + + +def _account_mock(*, tenant_id: str = _DEFAULT_TENANT_ID, account_id: str = _DEFAULT_ACCOUNT_ID) -> MagicMock: + account = MagicMock(spec=Account) + account.current_tenant_id = tenant_id + account.id = account_id + return account + + +def _yaml_dump(data: dict) -> str: + return yaml.safe_dump(data, allow_unicode=True) + + +def _workflow_yaml(*, version: str = CURRENT_DSL_VERSION) -> str: + return _yaml_dump( + { + "version": version, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + } + ) + + +def _pending_yaml_content(version: str = "99.0.0") -> bytes: + return (f'version: "{version}"\nkind: app\napp:\n name: Loop Test\n mode: workflow\n').encode() + + +def _app_stub(**overrides: Any) -> App: + """Create a stub App object for testing without hitting the database.""" + defaults = { + "id": str(uuid4()), + "tenant_id": _DEFAULT_TENANT_ID, + "mode": AppMode.WORKFLOW.value, + "name": "n", + "description": "d", + "icon_type": IconType.EMOJI, + "icon": "i", + "icon_background": "#fff", + "use_icon_as_answer_icon": False, + 
"app_model_config": None, + } + app = MagicMock(spec=App) + for key, value in (defaults | overrides).items(): + object.__setattr__(app, key, value) + return app + class TestAppDslService: """Integration tests for AppDslService using testcontainers.""" + @pytest.fixture + def app(self, flask_app_with_containers: Flask): + return flask_app_with_containers + @pytest.fixture def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( patch("services.app_dsl_service.WorkflowService") as mock_workflow_service, patch("services.app_dsl_service.DependenciesAnalysisService") as mock_dependencies_service, - patch("services.app_dsl_service.WorkflowDraftVariableService") as mock_draft_variable_service, - patch("services.app_dsl_service.ssrf_proxy") as mock_ssrf_proxy, - patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, - patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): - # Setup default mock returns mock_workflow_service.return_value.get_draft_workflow.return_value = None mock_workflow_service.return_value.sync_draft_workflow.return_value = MagicMock() mock_dependencies_service.generate_latest_dependencies.return_value = [] mock_dependencies_service.get_leaked_dependencies.return_value = [] mock_dependencies_service.generate_dependencies.return_value = [] - mock_draft_variable_service.return_value.delete_workflow_variables.return_value = None - mock_ssrf_proxy.get.return_value.content = b"test content" - mock_ssrf_proxy.get.return_value.raise_for_status.return_value = None - mock_redis_client.setex.return_value = None - mock_redis_client.get.return_value = None - mock_redis_client.delete.return_value = None mock_app_was_created.send.return_value = None - mock_app_model_config_was_updated.send.return_value = None - # Mock ModelManager for app service mock_model_instance = mock_model_manager.return_value mock_model_instance.get_default_model_instance.return_value = None - mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + mock_model_instance.get_default_provider_model_name.return_value = ( + "openai", + "gpt-3.5-turbo", + ) - # Mock FeatureService and EnterpriseService mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None @@ -58,34 +129,16 @@ class TestAppDslService: yield { "workflow_service": mock_workflow_service, "dependencies_service": mock_dependencies_service, - "draft_variable_service": mock_draft_variable_service, - "ssrf_proxy": mock_ssrf_proxy, - "redis_client": mock_redis_client, "app_was_created": mock_app_was_created, - "app_model_config_was_updated": mock_app_model_config_was_updated, "model_manager": mock_model_manager, "feature_service": mock_feature_service, "enterprise_service": mock_enterprise_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): - """ - Helper method to create a test app and account for testing. 
- - Args: - db_session_with_containers: Database session from testcontainers infrastructure - mock_external_service_dependencies: Mock dependencies - - Returns: - tuple: (app, account) - Created app and account instances - """ + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): fake = Faker() - - # Setup mocks for account creation with patch("services.account_service.FeatureService") as mock_account_feature_service: mock_account_feature_service.get_system_features.return_value.is_allow_register = True - - # Create account and tenant first account = AccountService.create_account( email=fake.email(), name=fake.name(), @@ -94,8 +147,6 @@ class TestAppDslService: ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant - - # Setup app creation arguments app_args = { "name": fake.company(), "description": fake.text(max_nb_chars=100), @@ -106,17 +157,11 @@ class TestAppDslService: "api_rph": 100, "api_rpm": 10, } - - # Create app app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) - return app, account - def _create_simple_yaml_content(self, app_name="Test App", app_mode="chat"): - """ - Helper method to create simple YAML content for testing. - """ + def _create_simple_yaml_content(self, app_name: str = "Test App", app_mode: str = "chat") -> str: yaml_data = { "version": "0.3.0", "kind": "app", @@ -145,88 +190,699 @@ class TestAppDslService: } return yaml.dump(yaml_data, allow_unicode=True) - def test_import_app_missing_yaml_content(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with missing YAML content. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + # ── Version Compatibility ───────────────────────────────────────── - # Import app without YAML content - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.import_app( - account=account, - import_mode=ImportMode.YAML_CONTENT, - name="Missing Content App", - ) + def test_check_version_compatibility_invalid_version_returns_failed(self): + assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED - # Verify import failed - assert result.status == ImportStatus.FAILED - assert result.app_id is None - assert "yaml_content is required" in result.error - assert result.imported_dsl_version == "" + def test_check_version_compatibility_newer_version_returns_pending(self): + assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app + def test_check_version_compatibility_major_older_returns_pending(self, monkeypatch): + monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0") + assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING - def test_import_app_missing_yaml_url(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with missing YAML URL. 
- """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + def test_check_version_compatibility_minor_older_returns_completed_with_warnings( + self, + ): + assert _check_version_compatibility("0.5.0") == ImportStatus.COMPLETED_WITH_WARNINGS - # Import app without YAML URL - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.import_app( - account=account, - import_mode=ImportMode.YAML_URL, - name="Missing URL App", - ) + def test_check_version_compatibility_equal_returns_completed(self): + assert _check_version_compatibility(CURRENT_DSL_VERSION) == ImportStatus.COMPLETED - # Verify import failed - assert result.status == ImportStatus.FAILED - assert result.app_id is None - assert "yaml_url is required" in result.error - assert result.imported_dsl_version == "" + # ── Import: Validation ──────────────────────────────────────────── - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app - - def test_import_app_invalid_import_mode(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test app import with invalid import mode. - """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Create YAML content - yaml_content = self._create_simple_yaml_content(fake.company(), "chat") - - # Import app with invalid mode should raise ValueError - dsl_service = AppDslService(db_session_with_containers) - with pytest.raises(ValueError, match="Invalid import_mode: invalid-mode"): - dsl_service.import_app( - account=account, + def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid import_mode"): + service.import_app( + account=_account_mock(), import_mode="invalid-mode", - yaml_content=yaml_content, - name="Invalid Mode App", + yaml_content="version: '0.1.0'", ) - # Verify no app was created in database - apps_count = db_session_with_containers.query(App).where(App.tenant_id == account.current_tenant_id).count() - assert apps_count == 1 # Only the original test app + def test_import_app_missing_yaml_content(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=None, + ) + assert result.status == ImportStatus.FAILED + assert "yaml_content is required" in result.error - def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export for chat app. 
- """ - fake = Faker() + def test_import_app_missing_yaml_url(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=None, + ) + assert result.status == ImportStatus.FAILED + assert "yaml_url is required" in result.error + + def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content="[]", + ) + assert result.status == ImportStatus.FAILED + assert "content must be a mapping" in result.error + + def test_import_app_version_not_str_returns_failed(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=yaml_content, + ) + assert result.status == ImportStatus.FAILED + assert "Invalid version type" in result.error + + def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump({"version": "0.6.0", "kind": "app"}), + ) + assert result.status == ImportStatus.FAILED + assert "Missing app data" in result.error + + def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): + def bad_safe_load(_content: str): + raise yaml.YAMLError("bad") + + monkeypatch.setattr(app_dsl_service.yaml, "safe_load", bad_safe_load) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content="x: y", + ) + assert result.status == ImportStatus.FAILED + assert result.error.startswith("Invalid YAML format:") + + def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("oops")), + ) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + ) + assert result.status == ImportStatus.FAILED + assert result.error == "oops" + + # ── Import: YAML URL ────────────────────────────────────────────── + + def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): + monkeypatch.setattr( + app_dsl_service.ssrf_proxy, + "get", + lambda _url, **_kw: (_ for _ in ()).throw(RuntimeError("boom")), + ) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "Error fetching YAML from URL: boom" in result.error + + def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch): + response = MagicMock() + response.content = b"" + response.raise_for_status.return_value = None + 
monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "Empty content" in result.error + + def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch): + response = MagicMock() + response.content = b"x" * (DSL_MAX_SIZE + 1) + response.raise_for_status.return_value = None + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", lambda _url, **_kw: response) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url="https://example.com/a.yml", + ) + assert result.status == ImportStatus.FAILED + assert "File size exceeds" in result.error + + def test_import_app_yaml_url_user_attachments_keeps_original_url( + self, db_session_with_containers: Session, monkeypatch + ): + yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml" + yaml_bytes = _pending_yaml_content() + + requested_urls: list[str] = [] + + def fake_get(url: str, **kwargs): + requested_urls.append(url) + response = MagicMock() + response.content = yaml_bytes + response.raise_for_status.return_value = None + return response + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=yaml_url, + ) + + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + assert requested_urls == [yaml_url] + + def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch): + yaml_url = "https://github.com/acme/repo/blob/main/app.yml" + raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml" + yaml_bytes = _pending_yaml_content() + + requested_urls: list[str] = [] + + def fake_get(url: str, **kwargs): + requested_urls.append(url) + assert url == raw_url + response = MagicMock() + response.content = yaml_bytes + response.raise_for_status.return_value = None + return response + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_URL, + yaml_url=yaml_url, + ) + + assert result.status == ImportStatus.PENDING + assert requested_urls == [raw_url] + + # ── Import: App ID checks ──────────────────────────────────────── + + def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id=str(uuid4()), + ) + assert result.status == ImportStatus.FAILED + assert result.error == "App not found" + + def test_import_app_overwrite_only_allows_workflow_and_advanced_chat( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + assert app.mode == "chat" + + service = AppDslService(db_session_with_containers) + result = 
service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id=app.id, + ) + assert result.status == ImportStatus.FAILED + assert "Only workflow or advanced chat apps" in result.error + + # ── Import: Flow ────────────────────────────────────────────────── + + def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(version="99.0.0"), + name="n", + description="d", + icon_type="emoji", + icon="i", + icon_background="#000000", + ) + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{result.id}" + stored = redis_client.get(redis_key) + assert stored is not None + + def test_import_app_completed_uses_declared_dependencies( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + dependencies_payload = [ + { + "type": "package", + "value": { + "plugin_unique_identifier": "langgenius/google", + "version": "1.0.0", + }, + } + ] + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump( + { + "version": CURRENT_DSL_VERSION, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + "dependencies": dependencies_payload, + } + ), + ) + + assert result.status == ImportStatus.COMPLETED + assert result.app_id is not None + + @pytest.mark.parametrize("has_workflow", [True, False]) + def test_import_app_legacy_versions_extract_dependencies( + self, db_session_with_containers: Session, monkeypatch, has_workflow: bool + ): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow_graph", + lambda *_args, **_kwargs: ["from-workflow"], + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["from-model-config"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_latest_dependencies", + lambda deps: [SimpleNamespace(model_dump=lambda: {"dep": deps[0]})], + ) + + created_app = SimpleNamespace( + id=str(uuid4()), + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + ) + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: created_app, + ) + + draft_var_service = MagicMock() + monkeypatch.setattr( + app_dsl_service, + "WorkflowDraftVariableService", + lambda *args, **kwargs: draft_var_service, + ) + + data: dict = { + "version": "0.1.5", + "kind": "app", + "app": {"name": "Legacy", "mode": AppMode.WORKFLOW.value}, + } + if has_workflow: + data["workflow"] = {"graph": {"nodes": []}, "features": {}} + else: + data["model_config"] = {"model": {"provider": "openai"}} + + service = AppDslService(db_session_with_containers) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump(data), + ) + assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS + draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id=created_app.id) + + # ── Confirm Import 
──────────────────────────────────────────────── + + def test_confirm_import_expired_returns_failed(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=str(uuid4()), account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "expired" in result.error + + def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + + pending = PendingData( + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + name="name", + description="desc", + icon_type="emoji", + icon="🤖", + icon_background="#fff", + app_id=None, + ) + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, pending.model_dump_json()) + + created_app = SimpleNamespace( + id=str(uuid4()), + mode=AppMode.WORKFLOW.value, + tenant_id=_DEFAULT_TENANT_ID, + ) + monkeypatch.setattr( + AppDslService, + "_create_or_update_app", + lambda *_args, **_kwargs: created_app, + ) + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.COMPLETED + assert result.app_id == created_app.id + assert redis_client.get(redis_key) is None + + def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers: Session): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123") + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "validation error" in result.error + + def test_confirm_import_exception_returns_failed(self, db_session_with_containers: Session): + import_id = str(uuid4()) + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json") + + service = AppDslService(db_session_with_containers) + result = service.confirm_import(import_id=import_id, account=_account_mock()) + assert result.status == ImportStatus.FAILED + + # ── Check Dependencies ──────────────────────────────────────────── + + def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + app_model = _app_stub() + result = service.check_dependencies(app_model=app_model) + assert result.leaked_dependencies == [] + + def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch): + app_id = str(uuid4()) + pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id) + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_id}", + IMPORT_INFO_REDIS_EXPIRY, + pending.model_dump_json(), + ) + + dep = app_dsl_service.PluginDependency.model_validate( + { + "type": "package", + "value": { + "plugin_unique_identifier": "acme/foo", + "version": "1.0.0", + }, + } + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [dep], + ) + + service = AppDslService(db_session_with_containers) + result = service.check_dependencies(app_model=_app_stub(id=app_id)) + assert len(result.leaked_dependencies) == 1 + + def test_check_dependencies_with_real_app( + self, 
db_session_with_containers: Session, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' + redis_client.setex( + f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}", + IMPORT_INFO_REDIS_EXPIRY, + mock_dependencies_json, + ) + + dsl_service = AppDslService(db_session_with_containers) + result = dsl_service.check_dependencies(app_model=app) + assert result.leaked_dependencies == [] + + # ── Create/Update App ───────────────────────────────────────────── + + def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="loss app mode"): + service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) + + def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch): + fixed_now = object() + monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + app = _app_stub( + mode=AppMode.WORKFLOW.value, + name="old", + description="old-desc", + icon_type=IconType.EMOJI, + icon="old-icon", + icon_background="#111111", + updated_by=None, + updated_at=None, + ) + service = AppDslService(db_session_with_containers) + updated = service._create_or_update_app( + app=app, + data={ + "app": { + "mode": AppMode.WORKFLOW.value, + "name": "yaml-name", + "icon_type": IconType.IMAGE, + "icon": "X", + }, + "workflow": {"graph": {"nodes": []}, "features": {}}, + }, + account=_account_mock(), + name="override-name", + description=None, + icon_background="#222222", + ) + assert updated is app + assert app.name == "override-name" + assert app.icon_type == IconType.IMAGE + assert app.icon == "X" + assert app.icon_background == "#222222" + assert app.updated_at is fixed_now + + def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers: Session): + account = _account_mock() + account.current_tenant_id = None + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Current tenant is not set"): + service._create_or_update_app( + app=None, + data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}}, + account=account, + ) + + def test_create_or_update_app_creates_workflow_app_and_saves_dependencies( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + + mock_wf_svc = mock_external_service_dependencies["workflow_service"] + mock_wf_svc.return_value.get_draft_workflow.return_value = MagicMock(unique_hash="uh") + + service = AppDslService(db_session_with_containers) + deps = [ + app_dsl_service.PluginDependency.model_validate( + { + "type": "package", + "value": { + "plugin_unique_identifier": "acme/foo", + "version": "1.0.0", + }, + } + ) + ] + data = { + 
"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}, + "workflow": { + "graph": {"nodes": []}, + "features": {}, + }, + } + + app = service._create_or_update_app(app=None, data=data, account=account, dependencies=deps) + + assert app.tenant_id == account.current_tenant_id + mock_external_service_dependencies["app_was_created"].send.assert_called_once() + mock_wf_svc.return_value.sync_draft_workflow.assert_called_once() + + stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}") + assert stored is not None + + def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Missing workflow data"): + service._create_or_update_app( + app=_app_stub(mode=AppMode.WORKFLOW.value), + data={"app": {"mode": AppMode.WORKFLOW.value}}, + account=_account_mock(), + ) + + def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Missing model_config"): + service._create_or_update_app( + app=_app_stub(mode=AppMode.CHAT.value), + data={"app": {"mode": AppMode.CHAT.value}}, + account=_account_mock(), + ) + + def test_create_or_update_app_chat_creates_model_config_and_sends_event( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + app.app_model_config_id = None + db_session_with_containers.commit() + + service = AppDslService(db_session_with_containers) + service._create_or_update_app( + app=app, + data={ + "app": {"mode": AppMode.CHAT.value}, + "model_config": {"model": {"provider": "openai"}}, + }, + account=account, + ) + + db_session_with_containers.expire_all() + assert app.app_model_config_id is not None + + def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers: Session): + service = AppDslService(db_session_with_containers) + with pytest.raises(ValueError, match="Invalid app mode"): + service._create_or_update_app( + app=_app_stub(mode=AppMode.RAG_PIPELINE.value), + data={"app": {"mode": AppMode.RAG_PIPELINE.value}}, + account=_account_mock(), + ) + + # ── Export ───────────────────────────────────────────────────────── + + def test_export_dsl_delegates_by_mode(self, monkeypatch): + workflow_calls: list[bool] = [] + model_calls: list[bool] = [] + monkeypatch.setattr( + AppDslService, + "_append_workflow_export_data", + lambda **_kwargs: workflow_calls.append(True), + ) + monkeypatch.setattr( + AppDslService, + "_append_model_config_export_data", + lambda *_args, **_kwargs: model_calls.append(True), + ) + + workflow_app = _app_stub( + mode=AppMode.WORKFLOW.value, + icon_type="emoji", + ) + AppDslService.export_dsl(workflow_app) + assert workflow_calls == [True] + + chat_app = _app_stub( + mode=AppMode.CHAT.value, + icon_type="emoji", + app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}), + ) + AppDslService.export_dsl(chat_app) + assert model_calls == [True] + + def test_export_dsl_preserves_icon_and_icon_type(self, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_append_workflow_export_data", + lambda **_kwargs: None, + ) + + emoji_app = _app_stub( + mode=AppMode.WORKFLOW.value, + name="Emoji App", + icon="🎨", + icon_type=IconType.EMOJI, + icon_background="#FF5733", + description="App with 
emoji icon", + use_icon_as_answer_icon=True, + ) + yaml_output = AppDslService.export_dsl(emoji_app) + data = yaml.safe_load(yaml_output) + assert data["app"]["icon"] == "🎨" + assert data["app"]["icon_type"] == "emoji" + assert data["app"]["icon_background"] == "#FF5733" + + image_app = _app_stub( + mode=AppMode.WORKFLOW.value, + name="Image App", + icon="https://example.com/icon.png", + icon_type=IconType.IMAGE, + icon_background="#FFEAD5", + description="App with image icon", + use_icon_as_answer_icon=False, + ) + yaml_output = AppDslService.export_dsl(image_app) + data = yaml.safe_load(yaml_output) + assert data["app"]["icon"] == "https://example.com/icon.png" + assert data["app"]["icon_type"] == "image" + assert data["app"]["icon_background"] == "#FFEAD5" + + def test_export_dsl_chat_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - # Create model config for the app model_config = AppModelConfig( app_id=app.id, provider="openai", @@ -247,53 +903,40 @@ class TestAppDslService: created_by=account.id, updated_by=account.id, ) - model_config.id = fake.uuid4() - - # Set the app_model_config_id to link the config + model_config.id = str(uuid4()) app.app_model_config_id = model_config.id db_session_with_containers.add(model_config) db_session_with_containers.commit() - # Export DSL exported_dsl = AppDslService.export_dsl(app, include_secret=False) - - # Parse exported YAML exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" assert exported_data["app"]["name"] == app.name assert exported_data["app"]["mode"] == app.mode - assert exported_data["app"]["icon"] == app.icon - assert exported_data["app"]["icon_background"] == app.icon_background - assert exported_data["app"]["description"] == app.description - - # Verify model config was exported assert "model_config" in exported_data - # The exported model_config structure may be different from the database structure - # Check that the model config exists and has the expected content - assert exported_data["model_config"] is not None - - # Verify dependencies were exported assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) - def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export for workflow app. 
- """ - fake = Faker() + def test_export_dsl_workflow_app_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return a workflow mock_workflow = MagicMock() mock_workflow.to_dict.return_value = { - "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "graph": { + "nodes": [ + { + "id": "start", + "type": "start", + "data": {"type": "start"}, + } + ], + "edges": [], + }, "features": {}, "environment_variables": [], "conversation_variables": [], @@ -302,54 +945,42 @@ class TestAppDslService: "workflow_service" ].return_value.get_draft_workflow.return_value = mock_workflow - # Export DSL exported_dsl = AppDslService.export_dsl(app, include_secret=False) - - # Parse exported YAML exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" - assert exported_data["app"]["name"] == app.name assert exported_data["app"]["mode"] == "workflow" - - # Verify workflow was exported assert "workflow" in exported_data - assert "graph" in exported_data["workflow"] - assert "nodes" in exported_data["workflow"]["graph"] - - # Verify dependencies were exported assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) - # Verify workflow service was called - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, None - ) - - def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful DSL export with specific workflow ID. 
- """ - fake = Faker() + def test_export_dsl_with_workflow_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return a workflow when specific workflow_id is provided mock_workflow = MagicMock() mock_workflow.to_dict.return_value = { - "graph": {"nodes": [{"id": "start", "type": "start", "data": {"type": "start"}}], "edges": []}, + "graph": { + "nodes": [ + { + "id": "start", + "type": "start", + "data": {"type": "start"}, + } + ], + "edges": [], + }, "features": {}, "environment_variables": [], "conversation_variables": [], } - # Mock the get_draft_workflow method to return different workflows based on workflow_id - def mock_get_draft_workflow(app_model, workflow_id=None): - if workflow_id == "specific-workflow-id": + workflow_id = str(uuid4()) + + def mock_get_draft_workflow(app_model, wf_id=None): + if wf_id == workflow_id: return mock_workflow return None @@ -357,78 +988,351 @@ class TestAppDslService: "workflow_service" ].return_value.get_draft_workflow.side_effect = mock_get_draft_workflow - # Export DSL with specific workflow ID - exported_dsl = AppDslService.export_dsl(app, include_secret=False, workflow_id="specific-workflow-id") - - # Parse exported YAML + exported_dsl = AppDslService.export_dsl(app, include_secret=False, workflow_id=workflow_id) exported_data = yaml.safe_load(exported_dsl) - # Verify exported data structure assert exported_data["kind"] == "app" - assert exported_data["app"]["name"] == app.name - assert exported_data["app"]["mode"] == "workflow" - - # Verify workflow was exported assert "workflow" in exported_data - assert "graph" in exported_data["workflow"] - assert "nodes" in exported_data["workflow"]["graph"] - - # Verify dependencies were exported - assert "dependencies" in exported_data - assert isinstance(exported_data["dependencies"], list) - - # Verify workflow service was called with specific workflow ID - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, "specific-workflow-id" - ) def test_export_dsl_with_invalid_workflow_id_raises_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): - """ - Test that export_dsl raises error when invalid workflow ID is provided. 
- """ - fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - - # Update app to workflow mode app.mode = "workflow" db_session_with_containers.commit() - # Mock workflow service to return None when invalid workflow ID is provided mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.return_value = None - # Export DSL with invalid workflow ID should raise ValueError - with pytest.raises(ValueError, match="Missing draft workflow configuration, please check."): - AppDslService.export_dsl(app, include_secret=False, workflow_id="invalid-workflow-id") + with pytest.raises( + ValueError, + match="Missing draft workflow configuration, please check.", + ): + AppDslService.export_dsl(app, include_secret=False, workflow_id=str(uuid4())) - # Verify workflow service was called with the invalid workflow ID - mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once_with( - app, "invalid-workflow-id" + # ── Workflow Export Data ─────────────────────────────────────────── + + def test_append_workflow_export_data_filters_and_overrides(self, monkeypatch): + workflow_dict = { + "graph": { + "nodes": [ + { + "data": { + "type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + "dataset_ids": ["d1", "d2"], + } + }, + { + "data": { + "type": BuiltinNodeTypes.TOOL, + "credential_id": "secret", + } + }, + { + "data": { + "type": BuiltinNodeTypes.AGENT, + "agent_parameters": {"tools": {"value": [{"credential_id": "secret"}]}}, + } + }, + { + "data": { + "type": TRIGGER_SCHEDULE_NODE_TYPE, + "config": {"x": 1}, + } + }, + { + "data": { + "type": TRIGGER_WEBHOOK_NODE_TYPE, + "webhook_url": "x", + "webhook_debug_url": "y", + } + }, + { + "data": { + "type": TRIGGER_PLUGIN_NODE_TYPE, + "subscription_id": "s", + } + }, + ] + } + } + + workflow = SimpleNamespace(to_dict=lambda *, include_secret: workflow_dict) + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + AppDslService, + "encrypt_dataset_id", + lambda *, dataset_id, tenant_id: f"enc:{tenant_id}:{dataset_id}", + ) + monkeypatch.setattr( + app_dsl_service.TriggerScheduleNode, + "get_default_config", + lambda: {"config": {"default": True}}, + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow", + lambda *_args, **_kwargs: ["dep-1"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace( + model_dump=lambda: { + "tenant": tenant_id, + "dep": dependencies[0], + } + ) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + export_data: dict = {} + AppDslService._append_workflow_export_data( + export_data=export_data, + app_model=_app_stub(), + include_secret=False, + workflow_id=None, ) - def test_check_dependencies_success(self, db_session_with_containers, mock_external_service_dependencies): - """ - Test successful dependency checking. 
- """ - fake = Faker() - app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) + nodes = export_data["workflow"]["graph"]["nodes"] + assert nodes[0]["data"]["dataset_ids"] == [ + f"enc:{_DEFAULT_TENANT_ID}:d1", + f"enc:{_DEFAULT_TENANT_ID}:d2", + ] + assert "credential_id" not in nodes[1]["data"] + assert "credential_id" not in nodes[2]["data"]["agent_parameters"]["tools"]["value"][0] + assert nodes[3]["data"]["config"] == {"default": True} + assert nodes[4]["data"]["webhook_url"] == "" + assert nodes[4]["data"]["webhook_debug_url"] == "" + assert nodes[5]["data"]["subscription_id"] == "" + assert export_data["dependencies"] == [{"tenant": _DEFAULT_TENANT_ID, "dep": "dep-1"}] - # Mock Redis to return dependencies - mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' - mock_external_service_dependencies["redis_client"].get.return_value = mock_dependencies_json + def test_append_workflow_export_data_missing_workflow_raises(self, monkeypatch): + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) - # Check dependencies - dsl_service = AppDslService(db_session_with_containers) - result = dsl_service.check_dependencies(app_model=app) + with pytest.raises(ValueError, match="Missing draft workflow configuration"): + AppDslService._append_workflow_export_data( + export_data={}, + app_model=_app_stub(), + include_secret=False, + workflow_id=None, + ) - # Verify result - assert result.leaked_dependencies == [] + # ── Model Config Export Data ────────────────────────────────────── - # Verify Redis was queried - mock_external_service_dependencies["redis_client"].get.assert_called_once_with( - f"app_check_dependencies:{app.id}" + def test_append_model_config_export_data_filters_credential_id(self, monkeypatch): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["dep-1"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace( + model_dump=lambda: { + "tenant": tenant_id, + "dep": dependencies[0], + } + ) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": [{"credential_id": "secret"}]}}) + app_model = _app_stub(app_model_config=app_model_config) + export_data: dict = {} + + AppDslService._append_model_config_export_data(export_data, app_model) + assert export_data["model_config"]["agent_mode"]["tools"] == [{}] + assert export_data["dependencies"] == [{"tenant": _DEFAULT_TENANT_ID, "dep": "dep-1"}] + + def test_append_model_config_export_data_requires_app_config(self): + with pytest.raises(ValueError, match="Missing app configuration"): + AppDslService._append_model_config_export_data({}, _app_stub(app_model_config=None)) + + # ── Dependency Extraction ───────────────────────────────────────── + + def test_extract_dependencies_from_workflow_graph_covers_all_node_types(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", ) - # Verify dependencies service was called - 
mock_external_service_dependencies["dependencies_service"].get_leaked_dependencies.assert_called_once() + monkeypatch.setattr( + app_dsl_service.ToolNodeData, + "model_validate", + lambda _d: SimpleNamespace(provider_id="p1"), + ) + monkeypatch.setattr( + app_dsl_service.LLMNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m1")), + ) + monkeypatch.setattr( + app_dsl_service.QuestionClassifierNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m2")), + ) + monkeypatch.setattr( + app_dsl_service.ParameterExtractorNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m3")), + ) + + def kr_validate(_d): + return SimpleNamespace( + retrieval_mode="multiple", + multiple_retrieval_config=SimpleNamespace( + reranking_mode="weighted_score", + weights=SimpleNamespace(vector_setting=SimpleNamespace(embedding_provider_name="m4")), + reranking_model=None, + ), + single_retrieval_config=None, + ) + + monkeypatch.setattr( + app_dsl_service.KnowledgeRetrievalNodeData, + "model_validate", + kr_validate, + ) + + graph = { + "nodes": [ + {"data": {"type": BuiltinNodeTypes.TOOL}}, + {"data": {"type": BuiltinNodeTypes.LLM}}, + {"data": {"type": BuiltinNodeTypes.QUESTION_CLASSIFIER}}, + {"data": {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}}, + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}, + {"data": {"type": "unknown"}}, + ] + } + + deps = AppDslService._extract_dependencies_from_workflow_graph(graph) + assert deps == [ + "tool:p1", + "model:m1", + "model:m2", + "model:m3", + "model:m4", + ] + + def test_extract_dependencies_from_workflow_graph_handles_exceptions(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.ToolNodeData, + "model_validate", + lambda _d: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_workflow_graph( + {"nodes": [{"data": {"type": BuiltinNodeTypes.TOOL}}]} + ) + assert deps == [] + + def test_extract_dependencies_from_model_config_parses_providers(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + + deps = AppDslService._extract_dependencies_from_model_config( + { + "model": {"provider": "p1"}, + "dataset_configs": { + "datasets": {"datasets": [{"reranking_model": {"reranking_provider_name": {"provider": "p2"}}}]} + }, + "agent_mode": {"tools": [{"provider_id": "t1"}]}, + } + ) + assert deps == ["model:p1", "model:p2", "tool:t1"] + + def test_extract_dependencies_from_model_config_handles_exceptions(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda _p: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_model_config({"model": {"provider": "p1"}}) + assert deps == [] + + # ── Leaked Dependencies ─────────────────────────────────────────── + + def test_get_leaked_dependencies_empty_returns_empty(self): + assert AppDslService.get_leaked_dependencies(_DEFAULT_TENANT_ID, []) == [] + + def test_get_leaked_dependencies_delegates(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: 
[SimpleNamespace(tenant_id=tenant_id, deps=dependencies)], + ) + res = AppDslService.get_leaked_dependencies(_DEFAULT_TENANT_ID, [SimpleNamespace(id="x")]) + assert len(res) == 1 + + # ── Encryption/Decryption ───────────────────────────────────────── + + def test_encrypt_decrypt_dataset_id_respects_config(self, monkeypatch): + tenant_id = _DEFAULT_TENANT_ID + dataset_uuid = "00000000-0000-0000-0000-000000000000" + + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + False, + ) + assert AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) == dataset_uuid + + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + encrypted = AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) + assert encrypted != dataset_uuid + assert base64.b64decode(encrypted.encode()) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=tenant_id) == dataset_uuid + + def test_decrypt_dataset_id_returns_plain_uuid_unchanged(self): + value = "00000000-0000-0000-0000-000000000000" + assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id=_DEFAULT_TENANT_ID) == value + + def test_decrypt_dataset_id_returns_none_on_invalid_data(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id=_DEFAULT_TENANT_ID) is None + + def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(self, monkeypatch): + monkeypatch.setattr( + app_dsl_service.dify_config, + "DSL_EXPORT_ENCRYPT_DATASET_ID", + True, + ) + encrypted = AppDslService.encrypt_dataset_id(dataset_id="not-a-uuid", tenant_id=_DEFAULT_TENANT_ID) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=_DEFAULT_TENANT_ID) is None + + # ── Utility ─────────────────────────────────────────────────────── + + def test_is_valid_uuid_handles_bad_inputs(self): + assert AppDslService._is_valid_uuid("00000000-0000-0000-0000-000000000000") is True + assert AppDslService._is_valid_uuid("nope") is False diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 5b1a4790f5..e2fe6c8476 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -7,6 +7,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom +from models import App from models.model import EndUser from models.workflow import Workflow from services.app_generate_service import AppGenerateService @@ -36,12 +37,19 @@ class TestAppGenerateService: ) as mock_message_based_generator, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config, + patch("services.quota_service.dify_config", autospec=True) as mock_quota_dify_config, patch("configs.dify_config", autospec=True) as mock_global_dify_config, ): # Setup default mock returns for billing service - mock_billing_service.update_tenant_feature_plan_usage.return_value = { - "result": "success", - "history_id": "test_history_id", + mock_billing_service.quota_reserve.return_value = { + "reservation_id": 
"test-reservation-id", + "available": 100, + "reserved": 1, + } + mock_billing_service.quota_commit.return_value = { + "available": 99, + "reserved": 0, + "refunded": 0, } # Setup default mock returns for workflow service @@ -101,6 +109,8 @@ class TestAppGenerateService: mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100 mock_dify_config.APP_DAILY_RATE_LIMIT = 1000 + mock_quota_dify_config.BILLING_ENABLED = False + mock_global_dify_config.BILLING_ENABLED = False mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100 mock_global_dify_config.APP_DAILY_RATE_LIMIT = 1000 @@ -118,6 +128,7 @@ class TestAppGenerateService: "message_based_generator": mock_message_based_generator, "account_feature_service": mock_account_feature_service, "dify_config": mock_dify_config, + "quota_dify_config": mock_quota_dify_config, "global_dify_config": mock_global_dify_config, } @@ -174,7 +185,7 @@ class TestAppGenerateService: return app, account - def _create_test_workflow(self, db_session_with_containers: Session, app): + def _create_test_workflow(self, db_session_with_containers: Session, app: App): """ Helper method to create a test workflow for testing. @@ -465,6 +476,7 @@ class TestAppGenerateService: # Set BILLING_ENABLED to True for this test mock_external_service_dependencies["dify_config"].BILLING_ENABLED = True + mock_external_service_dependencies["quota_dify_config"].BILLING_ENABLED = True mock_external_service_dependencies["global_dify_config"].BILLING_ENABLED = True # Setup test arguments @@ -478,8 +490,10 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - # Verify billing service was called to consume quota - mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() + # Verify billing two-phase quota (reserve + commit) + billing = mock_external_service_dependencies["billing_service"] + billing.quota_reserve.assert_called_once() + billing.quota_commit.assert_called_once() def test_generate_with_invalid_app_mode( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index fa57dd4a6f..b695ae9fd9 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -658,15 +658,17 @@ class TestAppService: # Update app icon new_icon = "🌟" new_icon_background = "#FFD93D" + new_icon_type = "image" mock_current_user = create_autospec(Account, instance=True) mock_current_user.id = account.id mock_current_user.current_tenant_id = account.current_tenant_id with patch("services.app_service.current_user", mock_current_user): - updated_app = app_service.update_app_icon(app, new_icon, new_icon_background) + updated_app = app_service.update_app_icon(app, new_icon, new_icon_background, new_icon_type) assert updated_app.icon == new_icon assert updated_app.icon_background == new_icon_background + assert str(updated_app.icon_type).lower() == new_icon_type assert updated_app.updated_by == account.id # Verify other fields remain unchanged diff --git a/api/tests/test_containers_integration_tests/services/test_attachment_service.py b/api/tests/test_containers_integration_tests/services/test_attachment_service.py index 768a8baee2..d0c07f0de8 100644 --- a/api/tests/test_containers_integration_tests/services/test_attachment_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_attachment_service.py @@ -7,7 +7,7 @@ from uuid import uuid4 import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound import services.attachment_service as attachment_service_module @@ -19,7 +19,7 @@ from services.attachment_service import AttachmentService class TestAttachmentService: - def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile: + def _create_upload_file(self, db_session_with_containers: Session, *, tenant_id: str | None = None) -> UploadFile: upload_file = UploadFile( tenant_id=tenant_id or str(uuid4()), storage_type=StorageType.OPENDAL, @@ -60,7 +60,7 @@ class TestAttachmentService: with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): AttachmentService(session_factory=invalid_session_factory) - def test_should_return_base64_when_file_exists(self, db_session_with_containers): + def test_should_return_base64_when_file_exists(self, db_session_with_containers: Session): upload_file = self._create_upload_file(db_session_with_containers) service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) @@ -70,7 +70,7 @@ class TestAttachmentService: assert result == base64.b64encode(b"binary-content").decode() mock_load.assert_called_once_with(upload_file.key) - def test_should_raise_not_found_when_file_missing(self, db_session_with_containers): + def test_should_raise_not_found_when_file_missing(self, db_session_with_containers: Session): service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) with patch.object(attachment_service_module.storage, "load_once") as mock_load: diff --git a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py new file mode 100644 index 0000000000..2593b53fe8 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py @@ -0,0 +1,211 @@ +""" +Integration tests for AudioService.transcript_tts message-ID path. + +Migrated from unit_tests/services/test_audio_service.py, replacing +db.session.get mock patches with real Message rows persisted in PostgreSQL. 
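+ + Note: the Conversation/Message helper rows below are created with flush() only (never committed), so the per-test rollback removes them; the shared console helpers commit their rows and are cleaned up explicitly by the fixture.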
+ +Covers: +- transcript_tts with valid message_id that resolves to a real Message +- transcript_tts returns None for invalid (non-UUID) message_id +- transcript_tts returns None when message_id is a valid UUID but no row exists +- transcript_tts returns None when message exists but has an empty answer +""" + +from collections.abc import Generator +from decimal import Decimal +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import InvokeFrom +from models.account import TenantAccountJoin +from models.enums import ConversationFromSource, MessageStatus +from models.model import App, AppMode, Conversation, Message +from services.audio_service import AudioService +from tests.test_containers_integration_tests.controllers.console.helpers import ( + create_console_account_and_tenant, + create_console_app, +) + + +def _create_conversation(db_session: Session, app: App, account_id: str) -> Conversation: + """Create a Conversation row via flush() so the rollback-based teardown can remove it.""" + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=InvokeFrom.WEB_APP.value, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=None, + from_account_id=account_id, + dialogue_count=0, + is_deleted=False, + ) + db_session.add(conversation) + db_session.flush() + return conversation + + +def _create_message( + db_session: Session, + app: App, + conversation: Conversation, + account_id: str, + *, + answer: str = "Message answer text", + status: MessageStatus | str = MessageStatus.NORMAL, +) -> Message: + """Create a Message row via flush() so the rollback-based teardown can remove it.""" + message = Message( + app_id=app.id, + model_provider=None, + model_id="", + override_model_configs=None, + conversation_id=conversation.id, + inputs={}, + query="Test query", + message={"messages": [{"role": "user", "content": "Test query"}]}, + message_tokens=0, + message_unit_price=Decimal(0), + message_price_unit=Decimal("0.001"), + answer=answer, + answer_tokens=0, + answer_unit_price=Decimal(0), + answer_price_unit=Decimal("0.001"), + parent_message_id=None, + provider_response_latency=0, + total_price=Decimal(0), + currency="USD", + status=status, + invoke_from=InvokeFrom.WEB_APP.value, + from_source=ConversationFromSource.CONSOLE, + from_end_user_id=None, + from_account_id=account_id, + ) + db_session.add(message) + db_session.flush() + return message + + +class TestAudioServiceTranscriptTTSMessageLookup: + """Integration tests for AudioService.transcript_tts message-ID lookup via real DB.""" + + @pytest.fixture(autouse=True) + def _setup_cleanup(self, db_session_with_containers: Session) -> Generator[None, None, None]: + """Track rows created by shared helpers that commit, then clean up after the test. + + The shared console helpers (create_console_account_and_tenant, create_console_app) + commit their inserts so the rows survive a simple rollback. This fixture records + the app/account/tenant created per test and explicitly deletes them after the test + so the DB does not accumulate state across tests. 
Conversation/Message rows are + created via flush() only, so the trailing rollback removes them. + """ + self._committed_rows: list = [] + yield + db_session_with_containers.rollback() + for entity in self._committed_rows: + db_session_with_containers.execute(delete(type(entity)).where(type(entity).id == entity.id)) + db_session_with_containers.commit() + + def _setup_app_and_account(self, db_session: Session) -> tuple[App, str, str]: + """Create committed app/account/tenant using shared helpers and track them for cleanup.""" + account, tenant = create_console_account_and_tenant(db_session) + app = create_console_app(db_session, tenant_id=tenant.id, account_id=account.id, mode=AppMode.CHAT) + + # Track rows in the order they must be deleted (FK-safe: app and join before account/tenant) + self._committed_rows.append(app) + join = db_session.scalar( + select(TenantAccountJoin).where( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant.id, + ) + ) + if join is not None: + self._committed_rows.append(join) + self._committed_rows.extend([account, tenant]) + return app, account.id, tenant.id + + def test_transcript_tts_with_message_id_success(self, db_session_with_containers: Session) -> None: + """transcript_tts invokes TTS with the message answer when message_id resolves to a real row.""" + app, account_id, _ = self._setup_app_and_account(db_session_with_containers) + conversation = _create_conversation(db_session_with_containers, app, account_id) + message = _create_message( + db_session_with_containers, + app, + conversation, + account_id, + answer="Hello from message", + ) + + mock_model_instance = MagicMock() + mock_model_instance.invoke_tts.return_value = b"audio from message" + mock_model_manager = MagicMock() + mock_model_manager.get_default_model_instance.return_value = mock_model_instance + + with patch("services.audio_service.ModelManager.for_tenant", return_value=mock_model_manager): + result = AudioService.transcript_tts( + app_model=app, + message_id=message.id, + voice="en-US-Neural", + ) + + assert result == b"audio from message" + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="Hello from message", + voice="en-US-Neural", + ) + + def test_transcript_tts_returns_none_for_invalid_message_id(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None immediately when message_id is not a valid UUID.""" + app, _, _ = self._setup_app_and_account(db_session_with_containers) + + result = AudioService.transcript_tts( + app_model=app, + message_id="invalid-uuid", + ) + + assert result is None + + def test_transcript_tts_returns_none_for_nonexistent_message(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None when message_id is a valid UUID but no Message row exists.""" + app, _, _ = self._setup_app_and_account(db_session_with_containers) + + result = AudioService.transcript_tts( + app_model=app, + message_id=str(uuid4()), + ) + + assert result is None + + def test_transcript_tts_returns_none_for_empty_message_answer(self, db_session_with_containers: Session) -> None: + """transcript_tts returns None when the resolved message has an empty answer.""" + app, account_id, _ = self._setup_app_and_account(db_session_with_containers) + conversation = _create_conversation(db_session_with_containers, app, account_id) + message = _create_message( + db_session_with_containers, + app, + conversation, + account_id, + answer="", + status=MessageStatus.NORMAL, + ) + + result = 
AudioService.transcript_tts( + app_model=app, + message_id=message.id, + ) + + assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py index 76708b36b1..4893126d7f 100644 --- a/api/tests/test_containers_integration_tests/services/test_billing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -1,9 +1,14 @@ import json +from collections.abc import Generator from unittest.mock import patch +from uuid import uuid4 import pytest +from flask import Flask +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.billing_service import BillingService @@ -20,7 +25,7 @@ class TestBillingServiceGetPlanBulkWithCache: """ @pytest.fixture(autouse=True) - def setup_redis_cleanup(self, flask_app_with_containers): + def setup_redis_cleanup(self, flask_app_with_containers: Flask): """Clean up Redis cache before and after each test.""" with flask_app_with_containers.app_context(): # Clean up before test @@ -52,7 +57,7 @@ class TestBillingServiceGetPlanBulkWithCache: return value return None - def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when all tenants are in cache.""" with flask_app_with_containers.app_context(): # Arrange @@ -83,7 +88,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify API was not called mock_get_plan_bulk.assert_not_called() - def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when all tenants are not in cache.""" with flask_app_with_containers.app_context(): # Arrange @@ -123,7 +128,7 @@ class TestBillingServiceGetPlanBulkWithCache: assert ttl_1 > 0 assert ttl_1 <= 600 # Should be <= 600 seconds - def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when some tenants are in cache, some are not.""" with flask_app_with_containers.app_context(): # Arrange @@ -154,7 +159,7 @@ class TestBillingServiceGetPlanBulkWithCache: cached_data_3 = json.loads(cached_3) assert cached_data_3 == missing_plan["tenant-3"] - def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers: Flask): """Test fallback to API when Redis mget fails.""" with flask_app_with_containers.app_context(): # Arrange @@ -185,7 +190,7 @@ class TestBillingServiceGetPlanBulkWithCache: assert cached_1 is not None assert cached_2 is not None - def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers: Flask): """Test fallback to API when cache contains invalid JSON.""" with flask_app_with_containers.app_context(): # Arrange @@ -237,7 +242,7 @@ class TestBillingServiceGetPlanBulkWithCache: cached_data_3 = json.loads(cached_3) assert cached_data_3 == expected_plans["tenant-3"] - def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, 
flask_app_with_containers): + def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers: Flask): """Test fallback to API when cache data doesn't match SubscriptionPlan schema.""" with flask_app_with_containers.app_context(): # Arrange @@ -270,7 +275,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify API was called for tenant-2 and tenant-3 mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"]) - def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers: Flask): """Test that pipeline failure doesn't affect return value.""" with flask_app_with_containers.app_context(): # Arrange @@ -299,7 +304,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify pipeline was attempted mock_pipeline.assert_called_once() - def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers: Flask): """Test with empty tenant_ids list.""" with flask_app_with_containers.app_context(): # Act @@ -317,7 +322,7 @@ class TestBillingServiceGetPlanBulkWithCache: # But we should check that mget was not called at all # Since we can't easily verify this without more mocking, we just verify the result - def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers: Flask): """Test that expired cache keys are treated as cache misses.""" with flask_app_with_containers.app_context(): # Arrange @@ -363,3 +368,62 @@ class TestBillingServiceGetPlanBulkWithCache: assert ttl_1_new <= 600 assert ttl_2 > 0 assert ttl_2 <= 600 + + +class TestBillingServiceIsTenantOwnerOrAdmin: + """ + Integration tests for BillingService.is_tenant_owner_or_admin. + + Verifies that non-privileged roles (EDITOR, DATASET_OPERATOR) raise ValueError + when checked against real TenantAccountJoin rows in PostgreSQL. 
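+ + Join rows are created with flush() only, so the autouse rollback fixture below restores the database after each test.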
+ """ + + @pytest.fixture(autouse=True) + def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]: + yield + db_session_with_containers.rollback() + + def _create_account_with_tenant_role(self, db_session: Session, role: TenantAccountRole) -> tuple[Account, Tenant]: + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session.add(tenant) + db_session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"billing_{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session.add(account) + db_session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session.add(join) + db_session.flush() + + # Wire up in-memory reference so current_tenant_id resolves + account._current_tenant = tenant + return account, tenant + + def test_is_tenant_owner_or_admin_editor_role_raises_error(self, db_session_with_containers: Session) -> None: + """is_tenant_owner_or_admin raises ValueError for EDITOR role.""" + account, _ = self._create_account_with_tenant_role(db_session_with_containers, TenantAccountRole.EDITOR) + + with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"): + BillingService.is_tenant_owner_or_admin(account) + + def test_is_tenant_owner_or_admin_dataset_operator_raises_error(self, db_session_with_containers: Session) -> None: + """is_tenant_owner_or_admin raises ValueError for DATASET_OPERATOR role.""" + account, _ = self._create_account_with_tenant_role( + db_session_with_containers, TenantAccountRole.DATASET_OPERATOR + ) + + with pytest.raises(ValueError, match="Only team owner or team admin can perform this action"): + BillingService.is_tenant_owner_or_admin(account) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 6180d98b1e..8aa10129c1 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin @@ -170,7 +171,7 @@ class ConversationServiceIntegrationTestDataFactory: class TestConversationServicePagination: """Test conversation pagination operations.""" - def test_pagination_with_non_empty_include_ids(self, db_session_with_containers): + def test_pagination_with_non_empty_include_ids(self, db_session_with_containers: Session): """ Test that non-empty include_ids filters properly. @@ -204,7 +205,7 @@ class TestConversationServicePagination: returned_ids = {conversation.id for conversation in result.data} assert returned_ids == {conversations[0].id, conversations[1].id} - def test_pagination_with_empty_exclude_ids(self, db_session_with_containers): + def test_pagination_with_empty_exclude_ids(self, db_session_with_containers: Session): """ Test that empty exclude_ids doesn't filter. 
@@ -237,7 +238,7 @@ class TestConversationServicePagination: # Assert assert len(result.data) == len(conversations) - def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers): + def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers: Session): """ Test that non-empty exclude_ids filters properly. @@ -271,7 +272,7 @@ class TestConversationServicePagination: returned_ids = {conversation.id for conversation in result.data} assert returned_ids == {conversations[2].id} - def test_pagination_with_sorting_descending(self, db_session_with_containers): + def test_pagination_with_sorting_descending(self, db_session_with_containers: Session): """ Test pagination with descending sort order. @@ -316,7 +317,7 @@ class TestConversationServiceMessageCreation: within conversations. """ - def test_pagination_by_first_id_without_first_id(self, db_session_with_containers): + def test_pagination_by_first_id_without_first_id(self, db_session_with_containers: Session): """ Test message pagination without specifying first_id. @@ -354,7 +355,7 @@ class TestConversationServiceMessageCreation: assert len(result.data) == 3 # All 3 messages returned assert result.has_more is False # No more messages available (3 < limit of 10) - def test_pagination_by_first_id_with_first_id(self, db_session_with_containers): + def test_pagination_by_first_id_with_first_id(self, db_session_with_containers: Session): """ Test message pagination with first_id specified. @@ -399,7 +400,9 @@ class TestConversationServiceMessageCreation: assert len(result.data) == 2 # Only 2 messages returned after first_id assert result.has_more is False # No more messages available (2 < limit of 10) - def test_pagination_by_first_id_raises_error_when_first_message_not_found(self, db_session_with_containers): + def test_pagination_by_first_id_raises_error_when_first_message_not_found( + self, db_session_with_containers: Session + ): """ Test that FirstMessageNotExistsError is raised when first_id doesn't exist. @@ -424,7 +427,7 @@ class TestConversationServiceMessageCreation: limit=10, ) - def test_pagination_with_has_more_flag(self, db_session_with_containers): + def test_pagination_with_has_more_flag(self, db_session_with_containers: Session): """ Test that has_more flag is correctly set when there are more messages. @@ -463,7 +466,7 @@ class TestConversationServiceMessageCreation: assert len(result.data) == limit # Extra message should be removed assert result.has_more is True # Flag should be set - def test_pagination_with_ascending_order(self, db_session_with_containers): + def test_pagination_with_ascending_order(self, db_session_with_containers: Session): """ Test message pagination with ascending order. @@ -512,7 +515,7 @@ class TestConversationServiceSummarization: """ @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers): + def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers: Session): """ Test successful auto-generation of conversation name. 
@@ -552,7 +555,7 @@ class TestConversationServiceSummarization: app_model.tenant_id, first_message.query, conversation.id, app_model.id ) - def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers): + def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers: Session): """ Test that MessageNotExistsError is raised when conversation has no messages. @@ -571,7 +574,9 @@ class TestConversationServiceSummarization: ConversationService.auto_generate_name(app_model, conversation) @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_llm_generator, db_session_with_containers): + def test_auto_generate_name_handles_llm_failure_gracefully( + self, mock_llm_generator, db_session_with_containers: Session + ): """ Test that LLM generation failures are suppressed and don't crash. @@ -604,7 +609,7 @@ class TestConversationServiceSummarization: assert conversation.name == original_name # Name remains unchanged @patch("services.conversation_service.naive_utc_now") - def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers): + def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers: Session): """ Test renaming conversation with manual name. @@ -637,6 +642,40 @@ class TestConversationServiceSummarization: assert conversation.name == new_name assert conversation.updated_at == mock_time + @patch("services.conversation_service.LLMGenerator.generate_conversation_name") + def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers: Session): + """ + Test rename delegates to auto_generate_name when auto_generate is True. + + When auto_generate is True, the service should call auto_generate_name + which uses an LLM to create a descriptive conversation title. + """ + # Arrange + app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, app_model, user + ) + ConversationServiceIntegrationTestDataFactory.create_message( + db_session_with_containers, app_model, conversation, user + ) + generated_name = "Auto Generated Name" + mock_llm_generator.return_value = generated_name + + # Act + result = ConversationService.rename( + app_model=app_model, + conversation_id=conversation.id, + user=user, + name=None, + auto_generate=True, + ) + + # Assert + assert result == conversation + assert conversation.name == generated_name + class TestConversationServiceMessageAnnotation: """ @@ -648,7 +687,9 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_create_annotation_from_message(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_create_annotation_from_message( + self, mock_current_account, mock_add_task, db_session_with_containers: Session + ): """ Test creating annotation from existing message. 
@@ -687,7 +728,9 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_create_annotation_without_message(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_create_annotation_without_message( + self, mock_current_account, mock_add_task, db_session_with_containers: Session + ): """ Test creating standalone annotation without message. @@ -719,7 +762,7 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers: Session): """ Test updating an existing annotation. @@ -766,7 +809,7 @@ class TestConversationServiceMessageAnnotation: mock_add_task.delay.assert_not_called() @patch("services.annotation_service.current_account_with_tenant") - def test_get_annotation_list(self, mock_current_account, db_session_with_containers): + def test_get_annotation_list(self, mock_current_account, db_session_with_containers: Session): """ Test retrieving paginated annotation list. @@ -802,7 +845,7 @@ class TestConversationServiceMessageAnnotation: assert result_total == 5 @patch("services.annotation_service.current_account_with_tenant") - def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers): + def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers: Session): """ Test retrieving annotations with keyword filtering. @@ -851,7 +894,7 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers: Session): """ Test direct annotation insertion without message reference. @@ -885,7 +928,7 @@ class TestConversationServiceExport: Tests retrieving conversation data for export purposes. 
""" - def test_get_conversation_success(self, db_session_with_containers): + def test_get_conversation_success(self, db_session_with_containers: Session): """Test successful retrieval of conversation.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -903,7 +946,7 @@ class TestConversationServiceExport: # Assert assert result == conversation - def test_get_conversation_not_found(self, db_session_with_containers): + def test_get_conversation_not_found(self, db_session_with_containers: Session): """Test ConversationNotExistsError when conversation doesn't exist.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -915,7 +958,7 @@ class TestConversationServiceExport: ConversationService.get_conversation(app_model=app_model, conversation_id=str(uuid4()), user=user) @patch("services.annotation_service.current_account_with_tenant") - def test_export_annotation_list(self, mock_current_account, db_session_with_containers): + def test_export_annotation_list(self, mock_current_account, db_session_with_containers: Session): """Test exporting all annotations for an app.""" # Arrange app_model, account = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -943,7 +986,7 @@ class TestConversationServiceExport: # Assert assert len(result) == 10 - def test_get_message_success(self, db_session_with_containers): + def test_get_message_success(self, db_session_with_containers: Session): """Test successful retrieval of a message.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -967,7 +1010,7 @@ class TestConversationServiceExport: # Assert assert result == message - def test_get_message_not_found(self, db_session_with_containers): + def test_get_message_not_found(self, db_session_with_containers: Session): """Test MessageNotExistsError when message doesn't exist.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -978,7 +1021,7 @@ class TestConversationServiceExport: with pytest.raises(MessageNotExistsError): MessageService.get_message(app_model=app_model, user=user, message_id=str(uuid4())) - def test_get_conversation_for_end_user(self, db_session_with_containers): + def test_get_conversation_for_end_user(self, db_session_with_containers: Session): """ Test retrieving conversation created by end user via API. @@ -1004,7 +1047,7 @@ class TestConversationServiceExport: assert result == conversation @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_conversation(self, mock_delete_task, db_session_with_containers): + def test_delete_conversation(self, mock_delete_task, db_session_with_containers: Session): """ Test conversation deletion with async cleanup. @@ -1037,7 +1080,7 @@ class TestConversationServiceExport: mock_delete_task.delay.assert_called_once_with(conversation_id) @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers): + def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers: Session): """ Test deletion is denied when conversation belongs to a different account. 
""" @@ -1066,3 +1109,32 @@ class TestConversationServiceExport: not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id)) assert not_deleted is not None mock_delete_task.delay.assert_not_called() + + @patch("services.conversation_service.delete_conversation_related_data") + def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers: Session): + """ + Test that delete propagates exceptions and does not trigger the cleanup task. + + When a DB error occurs during deletion, the service must rollback the + transaction and re-raise the exception without scheduling async cleanup. + """ + # Arrange + app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( + db_session_with_containers + ) + conversation = ConversationServiceIntegrationTestDataFactory.create_conversation( + db_session_with_containers, app_model, user + ) + conversation_id = conversation.id + + # Act — force an error during the delete to exercise the rollback path + with patch("services.conversation_service.db.session.delete", side_effect=Exception("DB error")): + with pytest.raises(Exception, match="DB error"): + ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user) + + # Assert — async cleanup must NOT have been scheduled + mock_delete_task.delay.assert_not_called() + + # Conversation is still present because the deletion was never committed + still_there = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation_id)) + assert still_there is not None diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py new file mode 100644 index 0000000000..6c292dbc4b --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask import Flask +from sqlalchemy.orm import Session, sessionmaker + +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from graphon.variables import FloatVariable, IntegerVariable, StringVariable +from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource +from models.model import App, Conversation, EndUser +from models.workflow import ConversationVariable +from services.conversation_service import ConversationService +from services.errors.conversation import ( + ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, + LastConversationNotExistsError, +) + + +class ConversationServiceVariableIntegrationFactory: + @staticmethod + def create_app_and_account(db_session_with_containers): + tenant = Tenant(name=f"Tenant {uuid4()}") + db_session_with_containers.add(tenant) + db_session_with_containers.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"conversation-variable-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + tenant_join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + 
db_session_with_containers.add(tenant_join) + db_session_with_containers.flush() + + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="chat", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + + return app, account + + @staticmethod + def create_end_user(db_session_with_containers, app: App): + end_user = EndUser( + tenant_id=app.tenant_id, + app_id=app.id, + type=InvokeFrom.SERVICE_API.value, + external_user_id=f"external-{uuid4()}", + name=f"End User {uuid4()}", + is_anonymous=False, + session_id=f"session-{uuid4()}", + ) + db_session_with_containers.add(end_user) + db_session_with_containers.commit() + return end_user + + @staticmethod + def create_conversation( + db_session_with_containers, + app: App, + user: Account | EndUser, + *, + name: str | None = None, + invoke_from: InvokeFrom = InvokeFrom.WEB_APP, + created_at: datetime | None = None, + updated_at: datetime | None = None, + ) -> Conversation: + conversation = Conversation( + app_id=app.id, + app_model_config_id=None, + model_provider=None, + model_id="", + override_model_configs=None, + mode=app.mode, + name=name or f"Conversation {uuid4()}", + summary="", + inputs={}, + introduction="", + system_instruction="", + system_instruction_tokens=0, + status="normal", + invoke_from=invoke_from.value, + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, + from_end_user_id=user.id if isinstance(user, EndUser) else None, + from_account_id=user.id if isinstance(user, Account) else None, + dialogue_count=0, + is_deleted=False, + ) + conversation.inputs = {} + if created_at is not None: + conversation.created_at = created_at + if updated_at is not None: + conversation.updated_at = updated_at + + db_session_with_containers.add(conversation) + db_session_with_containers.commit() + return conversation + + @staticmethod + def create_variable( + db_session_with_containers, + *, + app: App, + conversation: Conversation, + variable: StringVariable | FloatVariable | IntegerVariable, + created_at: datetime | None = None, + ) -> ConversationVariable: + row = ConversationVariable.from_variable(app_id=app.id, conversation_id=conversation.id, variable=variable) + if created_at is not None: + row.created_at = created_at + row.updated_at = created_at + + db_session_with_containers.add(row) + db_session_with_containers.commit() + return row + + +@pytest.fixture +def real_conversation_service_session_factory(flask_app_with_containers: Flask): + del flask_app_with_containers + real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + + with ( + patch("services.conversation_service.session_factory.create_session", side_effect=lambda: real_session_maker()), + patch("services.conversation_service.session_factory.get_session_maker", return_value=real_session_maker), + ): + yield + + +class TestConversationServiceVariables: + def test_get_conversational_variable_success( + self, db_session_with_containers: Session, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = 
factory.create_conversation(db_session_with_containers, app, account) + older_time = datetime(2024, 1, 1, 12, 0, 0) + newer_time = older_time + timedelta(minutes=5) + + first_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + created_at=older_time, + ) + second_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="priority", value="high"), + created_at=newer_time, + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=None, + ) + + assert [item["id"] for item in result.data] == [first_variable.id, second_variable.id] + assert [item["name"] for item in result.data] == ["topic", "priority"] + assert result.limit == 10 + assert result.has_more is False + + def test_get_conversational_variable_with_last_id( + self, db_session_with_containers: Session, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + base_time = datetime(2024, 1, 1, 9, 0, 0) + + first_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + created_at=base_time, + ) + second_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="priority", value="high"), + created_at=base_time + timedelta(minutes=1), + ) + third_variable = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="owner", value="alice"), + created_at=base_time + timedelta(minutes=2), + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=first_variable.id, + ) + + assert [item["id"] for item in result.data] == [second_variable.id, third_variable.id] + assert result.has_more is False + + def test_get_conversational_variable_last_id_not_found_raises_error( + self, db_session_with_containers: Session, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=10, + last_id=str(uuid4()), + ) + + def test_get_conversational_variable_sets_has_more( + self, db_session_with_containers: Session, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + for index in range(3): + 
factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name=f"var_{index}", value=f"value_{index}"), + created_at=datetime(2024, 1, 1, 10, 0, index), + ) + + result = ConversationService.get_conversational_variable( + app_model=app, + conversation_id=conversation.id, + user=account, + limit=2, + last_id=None, + ) + + assert len(result.data) == 2 + assert result.has_more is True + + def test_update_conversation_variable_success( + self, db_session_with_containers: Session, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=StringVariable(id=str(uuid4()), name="topic", value="billing"), + ) + updated_at = datetime(2024, 1, 1, 15, 0, 0) + + with patch("services.conversation_service.naive_utc_now", return_value=updated_at): + result = ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value="support", + ) + + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id)) + + assert persisted is not None + assert persisted.to_variable().value == "support" + assert result["id"] == existing.id + assert result["value"] == "support" + assert result["updated_at"] == updated_at + + def test_update_conversation_variable_not_found_raises_error( + self, db_session_with_containers: Session, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=str(uuid4()), + user=account, + new_value="support", + ) + + def test_update_conversation_variable_type_mismatch_raises_error( + self, db_session_with_containers: Session, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=FloatVariable(id=str(uuid4()), name="score", value=1.5), + ) + + with pytest.raises(ConversationVariableTypeMismatchError, match="expects float"): + ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value="wrong-type", + ) + + def test_update_conversation_variable_integer_number_compatibility( + self, db_session_with_containers: Session, real_conversation_service_session_factory + ): + del real_conversation_service_session_factory + factory = ConversationServiceVariableIntegrationFactory + app, account = 
factory.create_app_and_account(db_session_with_containers) + conversation = factory.create_conversation(db_session_with_containers, app, account) + existing = factory.create_variable( + db_session_with_containers, + app=app, + conversation=conversation, + variable=IntegerVariable(id=str(uuid4()), name="attempts", value=1), + ) + + result = ConversationService.update_conversation_variable( + app_model=app, + conversation_id=conversation.id, + variable_id=existing.id, + user=account, + new_value=42, + ) + + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(ConversationVariable, (existing.id, conversation.id)) + + assert persisted is not None + assert persisted.to_variable().value == 42 + assert result["value"] == 42 + + +class TestConversationServicePaginationWithContainers: + def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers: Session): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + + with pytest.raises(LastConversationNotExistsError): + ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=str(uuid4()), + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers: Session): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + base_time = datetime(2024, 1, 1, 8, 0, 0) + newest = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Newest", + updated_at=base_time + timedelta(minutes=2), + ) + middle = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Middle", + updated_at=base_time + timedelta(minutes=1), + ) + oldest = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Oldest", + updated_at=base_time, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=middle.id, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + ) + + assert newest.id != middle.id + assert [conversation.id for conversation in result.data] == [oldest.id] + + def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers: Session): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha") + beta = factory.create_conversation(db_session_with_containers, app, account, name="Beta") + gamma = factory.create_conversation(db_session_with_containers, app, account, name="Gamma") + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=beta.id, + limit=10, + invoke_from=InvokeFrom.WEB_APP, + sort_by="name", + ) + + assert alpha.id != beta.id + assert [conversation.id for conversation in result.data] == [gamma.id] + + def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers: Session): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + end_user = factory.create_end_user(db_session_with_containers, app) + account_conversation = factory.create_conversation( + 
db_session_with_containers, + app, + account, + name="Console Conversation", + invoke_from=InvokeFrom.WEB_APP, + ) + end_user_conversation = factory.create_conversation( + db_session_with_containers, + app, + end_user, + name="API Conversation", + invoke_from=InvokeFrom.SERVICE_API, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=end_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.SERVICE_API, + ) + + assert account_conversation.id != end_user_conversation.id + assert [conversation.id for conversation in result.data] == [end_user_conversation.id] + + def test_pagination_filters_to_account_console_source(self, db_session_with_containers: Session): + factory = ConversationServiceVariableIntegrationFactory + app, account = factory.create_app_and_account(db_session_with_containers) + end_user = factory.create_end_user(db_session_with_containers, app) + account_conversation = factory.create_conversation( + db_session_with_containers, + app, + account, + name="Console Conversation", + invoke_from=InvokeFrom.WEB_APP, + ) + factory.create_conversation( + db_session_with_containers, + app, + end_user, + name="API Conversation", + invoke_from=InvokeFrom.SERVICE_API, + ) + + result = ConversationService.pagination_by_last_id( + session=db_session_with_containers, + app_model=app, + user=account, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + assert [conversation.id for conversation in result.data] == [account_conversation.id] diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index fb0adbbcc2..638a962f18 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,17 +3,22 @@ from uuid import uuid4 import pytest -from graphon.variables import StringVariable -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from extensions.ext_database import db +from graphon.variables import StringVariable from models.workflow import ConversationVariable from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater class TestConversationVariableUpdater: def _create_conversation_variable( - self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None + self, + db_session_with_containers: Session, + *, + conversation_id: str, + variable: StringVariable, + app_id: str | None = None, ) -> ConversationVariable: row = ConversationVariable( id=variable.id, @@ -25,7 +30,7 @@ class TestConversationVariableUpdater: db_session_with_containers.commit() return row - def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers): + def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers: Session): conversation_id = str(uuid4()) variable = StringVariable(id=str(uuid4()), name="topic", value="old value") self._create_conversation_variable( @@ -42,7 +47,7 @@ class TestConversationVariableUpdater: assert row is not None assert row.data == updated_variable.model_dump_json() - def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers): + def test_should_raise_not_found_when_variable_missing(self, 
db_session_with_containers: Session): conversation_id = str(uuid4()) variable = StringVariable(id=str(uuid4()), name="topic", value="value") updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) @@ -50,7 +55,7 @@ class TestConversationVariableUpdater: with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): updater.update(conversation_id=conversation_id, variable=variable) - def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers): + def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers: Session): updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) result = updater.flush() diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 0f63d98642..09ba041244 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -3,6 +3,7 @@ from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.errors.error import QuotaExceededError from models import TenantCreditPool @@ -14,7 +15,7 @@ class TestCreditPoolService: def _create_tenant_id(self) -> str: return str(uuid4()) - def test_create_default_pool(self, db_session_with_containers): + def test_create_default_pool(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) @@ -25,7 +26,7 @@ class TestCreditPoolService: assert pool.quota_used == 0 assert pool.quota_limit > 0 - def test_get_pool_returns_pool_when_exists(self, db_session_with_containers): + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) @@ -35,17 +36,17 @@ class TestCreditPoolService: assert result.tenant_id == tenant_id assert result.pool_type == ProviderQuotaType.TRIAL - def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): + def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session): result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None - def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers): + def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session): result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) assert result is False - def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers): + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) @@ -53,7 +54,7 @@ class TestCreditPoolService: assert result is True - def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers): + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) # Exhaust credits @@ -64,11 +65,11 @@ class TestCreditPoolService: assert result is False - def 
test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers): + def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session): with pytest.raises(QuotaExceededError, match="Credit pool not found"): CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) - def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers): + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) pool.quota_used = pool.quota_limit @@ -77,7 +78,7 @@ class TestCreditPoolService: with pytest.raises(QuotaExceededError, match="No credits remaining"): CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) - def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers): + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) credits_required = 10 @@ -89,7 +90,7 @@ class TestCreditPoolService: pool = CreditPoolService.get_pool(tenant_id=tenant_id) assert pool.quota_used == credits_required - def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers): + def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) remaining = 5 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 71c8874f79..f9898e2cfa 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -8,6 +8,7 @@ checks with testcontainers-backed infrastructure instead of database-chain mocks from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db @@ -107,7 +108,7 @@ class DatasetPermissionTestDataFactory: class TestDatasetPermissionServiceGetPartialMemberList: """Verify partial-member list reads against persisted DatasetPermission rows.""" - def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers): + def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers: Session): """ Test retrieving partial member list with multiple members. """ @@ -138,7 +139,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: assert set(result) == set(expected_account_ids) assert len(result) == 3 - def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers): + def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers: Session): """ Test retrieving partial member list with single member. 
""" @@ -160,7 +161,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: assert set(result) == set(expected_account_ids) assert len(result) == 1 - def test_get_dataset_partial_member_list_empty(self, db_session_with_containers): + def test_get_dataset_partial_member_list_empty(self, db_session_with_containers: Session): """ Test retrieving partial member list when no members exist. """ @@ -179,7 +180,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: class TestDatasetPermissionServiceUpdatePartialMemberList: """Verify partial-member list updates against persisted DatasetPermission rows.""" - def test_update_partial_member_list_add_new_members(self, db_session_with_containers): + def test_update_partial_member_list_add_new_members(self, db_session_with_containers: Session): """ Test adding new partial members to a dataset. """ @@ -203,7 +204,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert set(result) == {member_1.id, member_2.id} - def test_update_partial_member_list_replace_existing(self, db_session_with_containers): + def test_update_partial_member_list_replace_existing(self, db_session_with_containers: Session): """ Test replacing existing partial members with new ones. """ @@ -239,7 +240,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert set(result) == {new_member_1.id, new_member_2.id} - def test_update_partial_member_list_empty_list(self, db_session_with_containers): + def test_update_partial_member_list_empty_list(self, db_session_with_containers: Session): """ Test updating with empty member list (clearing all members). """ @@ -264,7 +265,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers): + def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): """ Test error handling and rollback on database error. """ @@ -313,7 +314,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: class TestDatasetPermissionServiceClearPartialMemberList: """Verify partial-member clearing against persisted DatasetPermission rows.""" - def test_clear_partial_member_list_success(self, db_session_with_containers): + def test_clear_partial_member_list_success(self, db_session_with_containers: Session): """ Test successful clearing of partial member list. """ @@ -338,7 +339,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_clear_partial_member_list_empty_list(self, db_session_with_containers): + def test_clear_partial_member_list_empty_list(self, db_session_with_containers: Session): """ Test clearing partial member list when no members exist. """ @@ -353,7 +354,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers): + def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): """ Test error handling and rollback on database error. 
""" @@ -398,7 +399,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: class TestDatasetServiceCheckDatasetPermission: """Verify dataset access checks against persisted partial-member permissions.""" - def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers): + def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers: Session): """Test that users from different tenants cannot access dataset.""" owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) @@ -410,7 +411,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError): DatasetService.check_dataset_permission(dataset, other_user) - def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers): + def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers: Session): """Test that tenant owners can access any dataset regardless of permission level.""" owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -423,7 +424,7 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, owner) - def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers): + def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers: Session): """Test ONLY_ME permission allows only the dataset creator to access.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) @@ -433,7 +434,7 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, creator) - def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers): + def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers: Session): """Test ONLY_ME permission denies access to non-creators.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError): DatasetService.check_dataset_permission(dataset, other) - def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers): + def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers: Session): """Test ALL_TEAM permission allows any team member to access the dataset.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -460,7 +461,9 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, member) - def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): + def test_check_dataset_permission_partial_members_with_permission_success( + self, db_session_with_containers: Session + ): """ Test that user with explicit permission can access partial_members dataset. 
""" @@ -485,7 +488,9 @@ class TestDatasetServiceCheckDatasetPermission: permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert user.id in permissions - def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers): + def test_check_dataset_permission_partial_members_without_permission_error( + self, db_session_with_containers: Session + ): """ Test error when user without permission tries to access partial_members dataset. """ @@ -506,7 +511,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers): + def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers: Session): """Test PARTIAL_TEAM permission allows creator to access without explicit permission.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index f9bfa570cb..e6ee896a52 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from graphon.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus @@ -712,7 +712,7 @@ class TestDatasetServiceRetrievalConfiguration: class TestDocumentServicePauseRecoverRetry: """Tests for pause/recover/retry orchestration using real DB and Redis.""" - def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + def _create_indexing_document(self, db_session_with_containers: Session, indexing_status="indexing"): factory = DatasetServiceIntegrationDataFactory account, tenant = factory.create_account_with_tenant(db_session_with_containers) dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) @@ -721,7 +721,7 @@ class TestDocumentServicePauseRecoverRetry: db_session_with_containers.commit() return doc, account - def test_pause_document_success(self, db_session_with_containers): + def test_pause_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService @@ -740,7 +740,7 @@ class TestDocumentServicePauseRecoverRetry: assert redis_client.get(cache_key) is not None redis_client.delete(cache_key) - def test_pause_document_invalid_status_error(self, db_session_with_containers): + def test_pause_document_invalid_status_error(self, db_session_with_containers: Session): from services.dataset_service import 
DocumentService from services.errors.document import DocumentIndexingError @@ -751,7 +751,7 @@ class TestDocumentServicePauseRecoverRetry: with pytest.raises(DocumentIndexingError): DocumentService.pause_document(doc) - def test_recover_document_success(self, db_session_with_containers): + def test_recover_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService @@ -775,7 +775,7 @@ class TestDocumentServicePauseRecoverRetry: assert redis_client.get(cache_key) is None recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) - def test_retry_document_indexing_success(self, db_session_with_containers): + def test_retry_document_indexing_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py index c486ff5613..08de79f4b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -6,6 +6,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin from services.dataset_service import DatasetService @@ -48,7 +49,7 @@ class TestDatasetServiceCreateRagPipelineDataset: permission="only_me", ) - def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers: Session): tenant, _ = self._create_tenant_and_account(db_session_with_containers) mock_user = Mock(id=None) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index 3cac964d89..c43a5d5978 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,8 @@ from unittest.mock import patch from uuid import uuid4 +from sqlalchemy.orm import Session + from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -101,7 +103,7 @@ class DatasetDeleteIntegrationDataFactory: class TestDatasetServiceDeleteDataset: """Integration coverage for DatasetService.delete_dataset using testcontainers.""" - def test_delete_dataset_with_documents_success(self, db_session_with_containers): + def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset with documents and dispatch cleanup through the real signal handler.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -144,7 +146,7 @@ class TestDatasetServiceDeleteDataset: dataset.pipeline_id, ) - def test_delete_empty_dataset_success(self, db_session_with_containers): + def test_delete_empty_dataset_success(self, 
db_session_with_containers: Session): """Delete an empty dataset without scheduling cleanup when both gating fields are absent.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -172,7 +174,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session): """Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -200,7 +202,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers): + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers: Session): """Delete a dataset without cleanup when indexing exists but doc_form resolves to None.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -228,7 +230,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_not_found(self, db_session_with_containers): + def test_delete_dataset_not_found(self, db_session_with_containers: Session): """Return False without scheduling cleanup when the target dataset does not exist.""" # Arrange owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py new file mode 100644 index 0000000000..2bec703f0c --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py @@ -0,0 +1,650 @@ +"""Testcontainers integration tests for SQL-backed DocumentService paths.""" + +import datetime +import json +from unittest.mock import create_autospec, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from core.rag.index_processor.constant.index_type import IndexStructureType +from extensions.storage.storage_type import StorageType +from models import Account +from models.dataset import Dataset, Document +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus +from models.model import UploadFile +from services.dataset_service import DocumentService +from services.errors.account import NoPermissionError + +FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0) + + +class DocumentServiceIntegrationFactory: + @staticmethod + def create_dataset( + db_session_with_containers, + *, + tenant_id: str | None = None, + created_by: str | None = None, + name: str | None = None, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id or str(uuid4()), + name=name or f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by or str(uuid4()), + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + 
return dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + dataset: Dataset, + name: str = "doc.txt", + position: int = 1, + tenant_id: str | None = None, + indexing_status: str = IndexingStatus.COMPLETED, + enabled: bool = True, + archived: bool = False, + is_paused: bool = False, + need_summary: bool = False, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + batch: str | None = None, + data_source_type: str = DataSourceType.UPLOAD_FILE, + data_source_info: dict | None = None, + created_by: str | None = None, + ) -> Document: + document = Document( + tenant_id=tenant_id or dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type=data_source_type, + data_source_info=json.dumps(data_source_info or {}), + batch=batch or f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=created_by or dataset.created_by, + doc_form=doc_form, + ) + document.indexing_status = indexing_status + document.enabled = enabled + document.archived = archived + document.is_paused = is_paused + document.need_summary = need_summary + if indexing_status == IndexingStatus.COMPLETED: + document.completed_at = FIXED_UPLOAD_CREATED_AT + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + @staticmethod + def create_upload_file( + db_session_with_containers, + *, + tenant_id: str, + created_by: str, + file_id: str | None = None, + name: str = "source.txt", + ) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.LOCAL, + key=f"uploads/{uuid4()}", + name=name, + size=128, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=created_by, + created_at=FIXED_UPLOAD_CREATED_AT, + used=False, + ) + if file_id: + upload_file.id = file_id + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + +@pytest.fixture +def current_user_mock(): + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.id = str(uuid4()) + current_user.current_tenant_id = str(uuid4()) + current_user.current_role = None + yield current_user + + +def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_document(dataset.id, None) is None + + +def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset) + + result = DocumentService.get_document(dataset.id, document.id) + + assert result is not None + assert result.id == document.id + + +def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + result = DocumentService.get_documents_by_ids(dataset.id, []) + + assert result == [] + + +def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + doc_a = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, name="a.txt") + doc_b = 
DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + name="b.txt", + position=2, + ) + + result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id]) + + assert {document.id for document in result} == {doc_a.id, doc_b.id} + + +def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.update_documents_need_summary(dataset.id, []) == 0 + + +def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + paragraph_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + need_summary=True, + ) + qa_doc = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + need_summary=True, + doc_form=IndexStructureType.QA_INDEX, + ) + + updated_count = DocumentService.update_documents_need_summary( + dataset.id, + [paragraph_doc.id, qa_doc.id], + need_summary=False, + ) + + db_session_with_containers.expire_all() + refreshed_paragraph = db_session_with_containers.get(Document, paragraph_doc.id) + refreshed_qa = db_session_with_containers.get(Document, qa_doc.id) + assert updated_count == 1 + assert refreshed_paragraph is not None + assert refreshed_qa is not None + assert refreshed_paragraph.need_summary is False + assert refreshed_qa.need_summary is True + + +def test_get_document_download_url_uses_signed_url_helper(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url: + result = DocumentService.get_document_download_url(document) + + assert result == "signed-url" + get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True) + + +def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_type=DataSourceType.WEBSITE_CRAWL, + data_source_info={"url": "https://example.com"}, + ) + + with pytest.raises(NotFound, match="invalid source"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={}, + ) + + with pytest.raises(NotFound, match="missing file"): + DocumentService._get_upload_file_id_for_upload_file_document( + 
document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + +def test_get_upload_file_id_for_upload_file_document_returns_string_id(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": 99}, + ) + + result = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + assert result == "99" + + +def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": "missing-file"}, + ) + + with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): + with pytest.raises(NotFound, match="Uploaded file not found"): + DocumentService._get_upload_file_for_upload_file_document(document) + + +def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + result = DocumentService._get_upload_file_for_upload_file_document(document) + + assert result.id == upload_file.id + + +def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with pytest.raises(NotFound, match="Document not found"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[str(uuid4())], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + tenant_id=str(uuid4()), + data_source_info={"upload_file_id": upload_file.id}, + ) + + with pytest.raises(Forbidden, match="No permission"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": str(uuid4())}, + ) + + with 
pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document.id], + tenant_id=dataset.tenant_id, + ) + + +def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + mapping = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id=dataset.id, + document_ids=[document_a.id, document_b.id], + tenant_id=dataset.tenant_id, + ) + + assert mapping[document_a.id].id == upload_file_a.id + assert mapping[document_b.id].id == upload_file_b.id + + +def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset( + current_user_mock, flask_app_with_containers +): + with flask_app_with_containers.app_context(): + with pytest.raises(NotFound, match="Dataset not found"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=str(uuid4()), + document_ids=[str(uuid4())], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + + with patch( + "services.dataset_service.DatasetService.check_dataset_permission", + side_effect=NoPermissionError("denied"), + ): + with pytest.raises(Forbidden, match="denied"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + +def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order( + db_session_with_containers, + current_user_mock, +): + dataset = DocumentServiceIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=current_user_mock.current_tenant_id, + created_by=current_user_mock.id, + ) + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = 
DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=[document_b.id, document_a.id], + tenant_id=current_user_mock.current_tenant_id, + current_user=current_user_mock, + ) + + assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id] + assert download_name.endswith(".zip") + + +def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + enabled_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + enabled=False, + ) + + result = DocumentService.get_document_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [enabled_document.id] + + +def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchived_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + available_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.COMPLETED, + enabled=True, + archived=False, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.ERROR, + ) + + result = DocumentService.get_working_documents_by_dataset_id(dataset.id) + + assert [document.id for document in result] == [available_document.id] + + +def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + error_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + indexing_status=IndexingStatus.ERROR, + ) + paused_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + indexing_status=IndexingStatus.PAUSED, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=3, + indexing_status=IndexingStatus.COMPLETED, + ) + + result = DocumentService.get_error_documents_by_dataset_id(dataset.id) + + assert {document.id for document in result} == {error_document.id, paused_document.id} + + +def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + batch = f"batch-{uuid4()}" + matching_document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + batch=batch, + ) + DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + tenant_id=str(uuid4()), + batch=batch, + ) + + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.current_tenant_id = dataset.tenant_id + result = DocumentService.get_batch_documents(dataset.id, batch) + + assert [document.id for document 
in result] == [matching_document.id] + + +def test_get_document_file_detail_returns_upload_file(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + + result = DocumentService.get_document_file_detail(upload_file.id) + + assert result is not None + assert result.id == upload_file.id + + +def test_delete_document_emits_signal_and_commits(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + upload_file = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + ) + document = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file.id}, + ) + + with patch("services.dataset_service.document_was_deleted.send") as signal_send: + DocumentService.delete_document(document) + + assert db_session_with_containers.get(Document, document.id) is None + signal_send.assert_called_once_with( + document.id, + dataset_id=document.dataset_id, + doc_form=document.doc_form, + file_id=upload_file.id, + ) + + +def test_delete_documents_ignores_empty_input(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, []) + + delay.assert_not_called() + + +def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + dataset.chunk_structure = IndexStructureType.PARAGRAPH_INDEX + db_session_with_containers.commit() + upload_file_a = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="a.txt", + ) + upload_file_b = DocumentServiceIntegrationFactory.create_upload_file( + db_session_with_containers, + tenant_id=dataset.tenant_id, + created_by=dataset.created_by, + name="b.txt", + ) + document_a = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + data_source_info={"upload_file_id": upload_file_a.id}, + ) + document_b = DocumentServiceIntegrationFactory.create_document( + db_session_with_containers, + dataset=dataset, + position=2, + data_source_info={"upload_file_id": upload_file_b.id}, + ) + + with patch("services.dataset_service.batch_clean_document_task.delay") as delay: + DocumentService.delete_documents(dataset, [document_a.id, document_b.id]) + + assert db_session_with_containers.get(Document, document_a.id) is None + assert db_session_with_containers.get(Document, document_b.id) is None + delay.assert_called_once() + args = delay.call_args.args + assert args[0] == [document_a.id, document_b.id] + assert args[1] == dataset.id + assert set(args[3]) == {upload_file_a.id, upload_file_b.id} + + +def test_get_documents_position_returns_next_position_when_documents_exist(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, 
position=3) + + assert DocumentService.get_documents_position(dataset.id) == 4 + + +def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers): + dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) + + assert DocumentService.get_documents_position(dataset.id) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py new file mode 100644 index 0000000000..0603a1e27f --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -0,0 +1,614 @@ +"""Testcontainers integration tests for DatasetService permission and lifecycle SQL paths.""" + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from flask import Flask +from sqlalchemy.orm import Session +from werkzeug.exceptions import NotFound + +from core.rag.index_processor.constant.index_type import IndexTechniqueType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + AppDatasetJoin, + Dataset, + DatasetAutoDisableLog, + DatasetCollectionBinding, + DatasetPermission, + DatasetPermissionEnum, +) +from models.enums import DataSourceType +from services.dataset_service import DatasetCollectionBindingService, DatasetPermissionService, DatasetService +from services.errors.account import NoPermissionError + + +class DatasetPermissionIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account, tenant + + @staticmethod + def create_account_in_tenant( + db_session_with_containers: Session, + tenant: Tenant, + role: TenantAccountRole = TenantAccountRole.EDITOR, + ) -> Account: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.role = role + account._current_tenant = tenant + return account + + @staticmethod + def create_dataset( + db_session_with_containers: Session, + *, + tenant_id: str, + created_by: str, + name: str | None = None, + permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + indexing_technique: str | None = IndexTechniqueType.HIGH_QUALITY, + enable_api: bool = True, + ) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=name or f"dataset-{uuid4()}", + description="desc", + data_source_type=DataSourceType.UPLOAD_FILE, + indexing_technique=indexing_technique, + 
created_by=created_by, + provider="vendor", + permission=permission, + retrieval_model={"top_k": 2}, + ) + dataset.enable_api = enable_api + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_dataset_permission( + db_session_with_containers: Session, + *, + dataset_id: str, + tenant_id: str, + account_id: str, + ) -> DatasetPermission: + permission = DatasetPermission( + dataset_id=dataset_id, + tenant_id=tenant_id, + account_id=account_id, + has_permission=True, + ) + db_session_with_containers.add(permission) + db_session_with_containers.commit() + return permission + + @staticmethod + def create_app_dataset_join( + db_session_with_containers: Session, + *, + dataset_id: str, + ) -> AppDatasetJoin: + join = AppDatasetJoin( + app_id=str(uuid4()), + dataset_id=dataset_id, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + return join + + @staticmethod + def create_collection_binding( + db_session_with_containers: Session, + *, + provider_name: str, + model_name: str, + collection_type: str = "dataset", + ) -> DatasetCollectionBinding: + binding = DatasetCollectionBinding( + provider_name=provider_name, + model_name=model_name, + collection_name=f"collection_{uuid4().hex}", + type=collection_type, + ) + db_session_with_containers.add(binding) + db_session_with_containers.commit() + return binding + + @staticmethod + def create_auto_disable_log( + db_session_with_containers: Session, + *, + tenant_id: str, + dataset_id: str, + document_id: str, + ) -> DatasetAutoDisableLog: + log = DatasetAutoDisableLog( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + ) + db_session_with_containers.add(log) + db_session_with_containers.commit() + return log + + +class TestDatasetServicePermissionsAndLifecycle: + def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: Session): + owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + + result = DatasetService.delete_dataset(str(uuid4()), user=owner) + + assert result is False + + def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal: + result = DatasetService.delete_dataset(dataset.id, user=owner) + + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + send_deleted_signal.assert_called_once_with(dataset) + + def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + DatasetPermissionIntegrationFactory.create_app_dataset_join( + db_session_with_containers, + dataset_id=dataset.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is True + + def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session): + owner, tenant = 
DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + assert DatasetService.dataset_use_check(dataset.id) is False + + def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + outsider, _other_tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, outsider) + + def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_rejects_partial_team_user_without_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_allows_partial_team_creator(self, db_session_with_containers: Session): + creator, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.EDITOR, + ) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=creator.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_allows_partial_team_member_with_binding( + self, db_session_with_containers: Session + ): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetService.check_dataset_permission(dataset, member) + 
+    def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(
+        self, db_session_with_containers: Session
+    ):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
+            db_session_with_containers,
+            tenant,
+            role=TenantAccountRole.EDITOR,
+        )
+        dataset = DatasetPermissionIntegrationFactory.create_dataset(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            created_by=owner.id,
+            permission=DatasetPermissionEnum.ONLY_ME,
+        )
+
+        with pytest.raises(NoPermissionError, match="do not have permission"):
+            DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
+
+    def test_check_dataset_operator_permission_rejects_partial_team_without_binding(
+        self, db_session_with_containers: Session
+    ):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
+            db_session_with_containers,
+            tenant,
+            role=TenantAccountRole.EDITOR,
+        )
+        dataset = DatasetPermissionIntegrationFactory.create_dataset(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            created_by=owner.id,
+            permission=DatasetPermissionEnum.PARTIAL_TEAM,
+        )
+
+        with pytest.raises(NoPermissionError, match="do not have permission"):
+            DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
+
+    def test_check_dataset_operator_permission_allows_partial_team_with_binding(
+        self, db_session_with_containers: Session
+    ):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        operator = DatasetPermissionIntegrationFactory.create_account_in_tenant(
+            db_session_with_containers,
+            tenant,
+            role=TenantAccountRole.EDITOR,
+        )
+        dataset = DatasetPermissionIntegrationFactory.create_dataset(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            created_by=owner.id,
+            permission=DatasetPermissionEnum.PARTIAL_TEAM,
+        )
+        DatasetPermissionIntegrationFactory.create_dataset_permission(
+            db_session_with_containers,
+            dataset_id=dataset.id,
+            tenant_id=tenant.id,
+            account_id=operator.id,
+        )
+
+        DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset)
+
+    def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers: Flask):
+        with flask_app_with_containers.app_context():
+            with pytest.raises(NotFound, match="Dataset not found"):
+                DatasetService.update_dataset_api_status(str(uuid4()), True)
+
+    def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        dataset = DatasetPermissionIntegrationFactory.create_dataset(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            created_by=owner.id,
+            enable_api=False,
+        )
+
+        with patch("services.dataset_service.current_user", SimpleNamespace(id=None)):
+            with pytest.raises(ValueError, match="Current user or current user id not found"):
+                DatasetService.update_dataset_api_status(dataset.id, True)
+
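+    # The happy-path test below patches naive_utc_now so the updated_at assertion
+    # compares against a fixed timestamp instead of the real clock.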
+    def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        dataset = DatasetPermissionIntegrationFactory.create_dataset(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            created_by=owner.id,
+            enable_api=False,
+        )
+        now = datetime(2026, 4, 14, 18, 0, 0)
+
+        with (
+            patch("services.dataset_service.current_user", owner),
+            patch("services.dataset_service.naive_utc_now", return_value=now),
+        ):
+            DatasetService.update_dataset_api_status(dataset.id, True)
+
+        db_session_with_containers.refresh(dataset)
+        assert dataset.enable_api is True
+        assert dataset.updated_by == owner.id
+        assert dataset.updated_at == now
+
+    def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(
+        self, db_session_with_containers: Session
+    ):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        features = SimpleNamespace(
+            billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan="professional"))
+        )
+
+        with (
+            patch("services.dataset_service.current_user", owner),
+            patch("services.dataset_service.FeatureService.get_features", return_value=features),
+        ):
+            result = DatasetService.get_dataset_auto_disable_logs(str(uuid4()))
+
+        assert result == {"document_ids": [], "count": 0}
+
+    def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_session_with_containers: Session):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        dataset = DatasetPermissionIntegrationFactory.create_dataset(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            created_by=owner.id,
+        )
+        DatasetPermissionIntegrationFactory.create_auto_disable_log(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            dataset_id=dataset.id,
+            document_id=str(uuid4()),
+        )
+        DatasetPermissionIntegrationFactory.create_auto_disable_log(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            dataset_id=dataset.id,
+            document_id=str(uuid4()),
+        )
+        features = SimpleNamespace(
+            billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan="professional"))
+        )
+
+        with (
+            patch("services.dataset_service.current_user", owner),
+            patch("services.dataset_service.FeatureService.get_features", return_value=features),
+        ):
+            result = DatasetService.get_dataset_auto_disable_logs(dataset.id)
+
+        assert result["count"] == 2
+        assert len(result["document_ids"]) == 2
+
+
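+# The collection-binding tests run against the real database so that the
+# get-or-create behaviour of DatasetCollectionBindingService is covered end to end.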
+class TestDatasetCollectionBindingServiceIntegration:
+    def test_get_dataset_collection_binding_returns_existing_binding(self, db_session_with_containers: Session):
+        binding = DatasetPermissionIntegrationFactory.create_collection_binding(
+            db_session_with_containers,
+            provider_name="provider",
+            model_name="model",
+        )
+
+        result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")
+
+        assert result.id == binding.id
+
+    def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session):
+        result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model")
+
+        persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id)
+        assert persisted is not None
+        assert persisted.provider_name == "provider"
+        assert persisted.model_name == "missing-model"
+        assert persisted.type == "dataset"
+        assert persisted.collection_name
+
+    def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers: Flask):
+        with flask_app_with_containers.app_context():
+            with pytest.raises(ValueError, match="Dataset collection binding not found"):
+                DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4()))
+
+    def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session):
+        binding = DatasetPermissionIntegrationFactory.create_collection_binding(
+            db_session_with_containers,
+            provider_name="provider",
+            model_name="model",
+        )
+
+        result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id)
+
+        assert result.id == binding.id
+
+
+class TestDatasetPermissionServiceIntegration:
+    def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session_with_containers: Session):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
+        member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
+        dataset = DatasetPermissionIntegrationFactory.create_dataset(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            created_by=owner.id,
+            permission=DatasetPermissionEnum.PARTIAL_TEAM,
+        )
+        DatasetPermissionIntegrationFactory.create_dataset_permission(
+            db_session_with_containers,
+            dataset_id=dataset.id,
+            tenant_id=tenant.id,
+            account_id=member_a.id,
+        )
+        DatasetPermissionIntegrationFactory.create_dataset_permission(
+            db_session_with_containers,
+            dataset_id=dataset.id,
+            tenant_id=tenant.id,
+            account_id=member_b.id,
+        )
+
+        result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
+
+        assert set(result) == {member_a.id, member_b.id}
+
+    def test_update_partial_member_list_replaces_permissions_and_commits(self, db_session_with_containers: Session):
+        owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers)
+        member_a = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
+        member_b = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
+        dataset = DatasetPermissionIntegrationFactory.create_dataset(
+            db_session_with_containers,
+            tenant_id=tenant.id,
+            created_by=owner.id,
+            permission=DatasetPermissionEnum.PARTIAL_TEAM,
+        )
+        stale_member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant)
+        DatasetPermissionIntegrationFactory.create_dataset_permission(
+            db_session_with_containers,
+            dataset_id=dataset.id,
+            tenant_id=tenant.id,
+            account_id=stale_member.id,
+        )
+
+        DatasetPermissionService.update_partial_member_list(
+            tenant.id,
+            dataset.id,
+            [{"user_id": member_a.id}, {"user_id": member_b.id}],
+        )
+
+        permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all()
+        assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id}
+
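+    # The check_permission cases below use SimpleNamespace stand-ins rather than
+    # persisted rows, since only role flags and the dataset's permission attribute
+    # are read by the method under test.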
match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, []) + + def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with pytest.raises(ValueError, match="Partial member list is required"): + DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, []) + + def test_check_permission_rejects_dataset_operator_member_list_changes(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + with pytest.raises(ValueError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-2"}], + ) + + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + DatasetPermissionService.check_permission( + user, + dataset, + DatasetPermissionEnum.PARTIAL_TEAM, + [{"user_id": "user-1"}], + ) + + def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session): + owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) + member = DatasetPermissionIntegrationFactory.create_account_in_tenant(db_session_with_containers, tenant) + dataset = DatasetPermissionIntegrationFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionIntegrationFactory.create_dataset_permission( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant.id, + account_id=member.id, + ) + + DatasetPermissionService.clear_partial_member_list(dataset.id) + + remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() + assert remaining == [] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index a814466e14..ac0483a45d 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -1,13 +1,21 @@ +import json from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType -from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.dataset import Dataset, ExternalKnowledgeBindings +from graphon.model_runtime.entities.model_entities import ModelType +from models.account import ( + Account, + AccountStatus, + Tenant, + 
TenantAccountJoin, + TenantAccountRole, + TenantStatus, +) +from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings from models.enums import DataSourceType from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -25,12 +33,12 @@ class DatasetUpdateTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() - tenant = Tenant(name=f"tenant-{account.id}", status="normal") + tenant = Tenant(name=f"tenant-{account.id}", status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -103,6 +111,34 @@ class DatasetUpdateTestDataFactory: db_session_with_containers.commit() return binding + @staticmethod + def create_external_knowledge_api( + db_session_with_containers: Session, + tenant_id: str, + created_by: str, + api_id: str | None = None, + name: str = "test-api", + ) -> ExternalKnowledgeApis: + """Create a real external knowledge API template for tenant-scoped update validation.""" + external_api = ExternalKnowledgeApis( + tenant_id=tenant_id, + created_by=created_by, + updated_by=created_by, + name=name, + description="test description", + settings=json.dumps( + { + "endpoint": "https://example.com", + "api_key": "test-api-key", + } + ), + ) + if api_id is not None: + external_api.id = api_id + db_session_with_containers.add(external_api) + db_session_with_containers.commit() + return external_api + class TestDatasetServiceUpdateDataset: """ @@ -138,6 +174,11 @@ class TestDatasetServiceUpdateDataset: ) binding_id = binding.id db_session_with_containers.expunge(binding) + external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api( + db_session_with_containers, + tenant_id=tenant.id, + created_by=user.id, + ) update_data = { "name": "new_name", @@ -145,7 +186,7 @@ class TestDatasetServiceUpdateDataset: "external_retrieval_model": "new_model", "permission": "only_me", "external_knowledge_id": "new_knowledge_id", - "external_knowledge_api_id": str(uuid4()), + "external_knowledge_api_id": external_api.id, } result = DatasetService.update_dataset(dataset.id, update_data, user) @@ -218,11 +259,16 @@ class TestDatasetServiceUpdateDataset: created_by=user.id, provider="external", ) + external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api( + db_session_with_containers, + tenant_id=tenant.id, + created_by=user.id, + ) update_data = { "name": "new_name", "external_knowledge_id": "knowledge_id", - "external_knowledge_api_id": str(uuid4()), + "external_knowledge_api_id": external_api.id, } with pytest.raises(ValueError) as context: diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index c8f04e9215..69c39b8bfb 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,10 @@ Testcontainers integration tests for archived workflow run deletion service. 
 from datetime import UTC, datetime, timedelta
 from uuid import uuid4
 
-from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy import select
+from sqlalchemy.orm import Session
 
+from graphon.enums import WorkflowExecutionStatus
 from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom
 from models.workflow import WorkflowArchiveLog, WorkflowRun
 from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
@@ -46,7 +47,7 @@ class TestArchivedWorkflowRunDeletion:
         db_session_with_containers.commit()
         return run
 
-    def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None:
+    def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> None:
         archive_log = WorkflowArchiveLog(
             tenant_id=run.tenant_id,
             app_id=run.app_id,
@@ -72,7 +73,7 @@ class TestArchivedWorkflowRunDeletion:
         db_session_with_containers.add(archive_log)
         db_session_with_containers.commit()
 
-    def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers):
+    def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers: Session):
         deleter = ArchivedWorkflowRunDeletion()
         missing_run_id = str(uuid4())
 
@@ -81,7 +82,7 @@ class TestArchivedWorkflowRunDeletion:
         assert result.success is False
         assert result.error == f"Workflow run {missing_run_id} not found"
 
-    def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers):
+    def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers: Session):
         tenant_id = str(uuid4())
         run = self._create_workflow_run(
             db_session_with_containers,
@@ -95,7 +96,7 @@ class TestArchivedWorkflowRunDeletion:
         assert result.success is False
         assert result.error == f"Workflow run {run.id} is not archived"
 
-    def test_delete_batch_uses_repo(self, db_session_with_containers):
+    def test_delete_batch_uses_repo(self, db_session_with_containers: Session):
         tenant_id = str(uuid4())
         base_time = datetime.now(UTC)
         run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time)
@@ -124,7 +125,7 @@ class TestArchivedWorkflowRunDeletion:
         ).all()
         assert remaining_runs == []
 
-    def test_delete_run_calls_repo(self, db_session_with_containers):
+    def test_delete_run_calls_repo(self, db_session_with_containers: Session):
         tenant_id = str(uuid4())
         run = self._create_workflow_run(
             db_session_with_containers,
@@ -142,7 +143,7 @@ class TestArchivedWorkflowRunDeletion:
         deleted_run = db_session_with_containers.get(WorkflowRun, run_id)
         assert deleted_run is None
 
-    def test_delete_run_dry_run(self, db_session_with_containers):
+    def test_delete_run_dry_run(self, db_session_with_containers: Session):
         """Dry run should return success without actually deleting."""
         tenant_id = str(uuid4())
         run = self._create_workflow_run(
@@ -161,7 +162,7 @@ class TestArchivedWorkflowRunDeletion:
         db_session_with_containers.expire_all()
         assert db_session_with_containers.get(WorkflowRun, run_id) is not None
 
-    def test_delete_run_exception_returns_error(self, db_session_with_containers):
+    def test_delete_run_exception_returns_error(self, db_session_with_containers: Session):
         """Exception during deletion should return failure result."""
         from unittest.mock import MagicMock, patch
 
@@ -183,7 +184,7 @@ class TestArchivedWorkflowRunDeletion:
         assert result.success is False
         assert result.error == "Database error"
 
-    def test_delete_by_run_id_success(self, db_session_with_containers):
+    def test_delete_by_run_id_success(self, db_session_with_containers: Session):
         """Successfully delete an archived workflow run by ID."""
         tenant_id = str(uuid4())
         base_time = datetime.now(UTC)
@@ -202,7 +203,7 @@ class TestArchivedWorkflowRunDeletion:
         db_session_with_containers.expunge_all()
         assert db_session_with_containers.get(WorkflowRun, run_id) is None
 
-    def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers):
+    def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers: Session):
         """_get_workflow_run_repo should return a cached repo on subsequent calls."""
         deleter = ArchivedWorkflowRunDeletion()
diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py
index cafabc939b..074d448aab 100644
--- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py
@@ -4,6 +4,7 @@ from unittest.mock import patch
 from uuid import uuid4
 
 import pytest
+from sqlalchemy.orm import Session
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from models.account import Account, Tenant, TenantAccountJoin
@@ -102,7 +103,7 @@ class TestEndUserServiceGetOrCreateEndUser:
         """Provide test data factory."""
         return TestEndUserServiceFactory()
 
-    def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory):
+    def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers: Session, factory):
         """Test getting or creating end user with custom user_id."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -118,7 +119,7 @@ class TestEndUserServiceGetOrCreateEndUser:
         assert result.type == InvokeFrom.SERVICE_API
         assert result.is_anonymous is False
 
-    def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory):
+    def test_get_or_create_end_user_without_user_id(self, db_session_with_containers: Session, factory):
         """Test getting or creating end user without user_id uses default session."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -131,7 +132,7 @@ class TestEndUserServiceGetOrCreateEndUser:
         # Verify _is_anonymous is set correctly (property always returns False)
         assert result._is_anonymous is True
 
-    def test_get_existing_end_user(self, db_session_with_containers, factory):
+    def test_get_existing_end_user(self, db_session_with_containers: Session, factory):
         """Test retrieving an existing end user."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -167,7 +168,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
         """Provide test data factory."""
         return TestEndUserServiceFactory()
 
-    def test_create_end_user_service_api_type(self, db_session_with_containers, factory):
+    def test_create_end_user_service_api_type(self, db_session_with_containers: Session, factory):
         """Test creating new end user with SERVICE_API type."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -189,7 +190,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
         assert result.app_id == app_id
         assert result.session_id == user_id
 
-    def test_create_end_user_web_app_type(self, db_session_with_containers, factory):
+    def test_create_end_user_web_app_type(self, db_session_with_containers: Session, factory):
         """Test creating new end user with WEB_APP type."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -209,7 +210,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
         assert result.type == InvokeFrom.WEB_APP
 
     @patch("services.end_user_service.logger")
-    def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory):
+    def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers: Session, factory):
         """Test upgrading legacy end user with different type."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -243,7 +244,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
         assert "Upgrading legacy EndUser" in log_call
 
     @patch("services.end_user_service.logger")
-    def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory):
+    def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers: Session, factory):
         """Test retrieving existing end user with matching type."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -272,7 +273,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
         assert result.type == InvokeFrom.SERVICE_API
         mock_logger.info.assert_not_called()
 
-    def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory):
+    def test_create_anonymous_user_with_default_session(self, db_session_with_containers: Session, factory):
         """Test creating anonymous user when user_id is None."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -293,7 +294,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
         assert result._is_anonymous is True
         assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
 
-    def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory):
+    def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers: Session, factory):
         """Test that query ordering prioritizes records with matching type."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -328,7 +329,7 @@ class TestEndUserServiceGetOrCreateEndUserByType:
         assert result.id == matching.id
         assert result.id != non_matching.id
 
-    def test_external_user_id_matches_session_id(self, db_session_with_containers, factory):
+    def test_external_user_id_matches_session_id(self, db_session_with_containers: Session, factory):
         """Test that external_user_id is set to match session_id."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -357,7 +358,9 @@ class TestEndUserServiceGetOrCreateEndUserByType:
             InvokeFrom.DEBUGGER,
         ],
     )
-    def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory):
+    def test_create_end_user_with_different_invoke_types(
+        self, db_session_with_containers: Session, invoke_type, factory
+    ):
         """Test creating end users with different InvokeFrom types."""
         # Arrange
         app = factory.create_app_and_account(db_session_with_containers)
@@ -385,7 +388,7 @@ class TestEndUserServiceGetEndUserById:
         """Provide test data factory."""
         return TestEndUserServiceFactory()
 
-    def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory):
+    def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers: Session, factory):
         app = factory.create_app_and_account(db_session_with_containers)
         existing_user = factory.create_end_user(
             db_session_with_containers,
@@ -404,7 +407,7 @@ class TestEndUserServiceGetEndUserById:
         assert result is not None
         assert result.id == existing_user.id
 
-    def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory):
+    def test_get_end_user_by_id_returns_none(self, db_session_with_containers: Session, factory):
         app = factory.create_app_and_account(db_session_with_containers)
 
         result = EndUserService.get_end_user_by_id(
@@ -423,7 +426,7 @@ class TestEndUserServiceCreateBatch:
     def factory(self):
         return TestEndUserServiceFactory()
 
-    def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3):
+    def _create_multiple_apps(self, db_session_with_containers: Session, factory, count: int = 3):
         """Create multiple apps under the same tenant."""
         first_app = factory.create_app_and_account(db_session_with_containers)
         tenant_id = first_app.tenant_id
@@ -452,13 +455,13 @@ class TestEndUserServiceCreateBatch:
         all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all()
         return tenant_id, all_apps
 
-    def test_create_batch_empty_app_ids(self, db_session_with_containers):
+    def test_create_batch_empty_app_ids(self, db_session_with_containers: Session):
         result = EndUserService.create_end_user_batch(
             type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1"
         )
         assert result == {}
 
-    def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory):
+    def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers: Session, factory):
         tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
         app_ids = [a.id for a in apps]
         user_id = f"user-{uuid4()}"
@@ -473,7 +476,7 @@ class TestEndUserServiceCreateBatch:
             assert result[app_id].session_id == user_id
             assert result[app_id].type == InvokeFrom.SERVICE_API
 
-    def test_create_batch_default_session_id(self, db_session_with_containers, factory):
+    def test_create_batch_default_session_id(self, db_session_with_containers: Session, factory):
         tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
         app_ids = [a.id for a in apps]
 
@@ -486,7 +489,7 @@ class TestEndUserServiceCreateBatch:
             assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
             assert end_user._is_anonymous is True
 
-    def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory):
+    def test_create_batch_deduplicate_app_ids(self, db_session_with_containers: Session, factory):
         tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
         app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id]
         user_id = f"user-{uuid4()}"
@@ -497,7 +500,7 @@ class TestEndUserServiceCreateBatch:
 
         assert len(result) == 2
 
-    def test_create_batch_returns_existing_users(self, db_session_with_containers, factory):
+    def test_create_batch_returns_existing_users(self, db_session_with_containers: Session, factory):
         tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2)
         app_ids = [a.id for a in apps]
         user_id = f"user-{uuid4()}"
@@ -516,7 +519,7 @@ class TestEndUserServiceCreateBatch:
         for app_id in app_ids:
             assert first_result[app_id].id == second_result[app_id].id
 
-    def test_create_batch_partial_existing_users(self, db_session_with_containers, factory):
+    def test_create_batch_partial_existing_users(self, db_session_with_containers: Session, factory):
         tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3)
         user_id = f"user-{uuid4()}"
 
@@ -545,7 +548,7 @@ class TestEndUserServiceCreateBatch:
         "invoke_type",
         [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER],
     )
-    def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory):
+    def test_create_batch_all_invoke_types(self, db_session_with_containers: Session, invoke_type, factory):
         tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1)
         user_id = f"user-{uuid4()}"
diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py
index b3e7dd2a59..f78aeaf984 100644
--- a/api/tests/test_containers_integration_tests/services/test_feature_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py
@@ -2,6 +2,7 @@ from unittest.mock import patch
 
 import pytest
 from faker import Faker
+from sqlalchemy.orm import Session
 
 from enums.cloud_plan import CloudPlan
 from services.feature_service import (
@@ -81,7 +82,7 @@ class TestFeatureService:
         fake = Faker()
         return fake.uuid4()
 
-    def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test successful feature retrieval with billing and enterprise enabled.
 
@@ -156,7 +157,7 @@ class TestFeatureService:
             tenant_id
         )
 
-    def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_sandbox_plan(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test feature retrieval for sandbox plan with specific limitations.
 
@@ -222,7 +223,9 @@ class TestFeatureService:
         # Verify mock interactions
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
-    def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_knowledge_rate_limit_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test successful knowledge rate limit retrieval with billing enabled.
 
@@ -255,7 +258,7 @@ class TestFeatureService:
             tenant_id
         )
 
-    def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_system_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test successful system features retrieval with enterprise and marketplace enabled.
 
@@ -274,6 +277,7 @@ class TestFeatureService:
             mock_config.ENABLE_EMAIL_CODE_LOGIN = True
             mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
             mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
+            mock_config.ENABLE_COLLABORATION_MODE = True
             mock_config.ALLOW_REGISTER = False
             mock_config.ALLOW_CREATE_WORKSPACE = False
             mock_config.MAIL_TYPE = "smtp"
@@ -298,6 +302,7 @@ class TestFeatureService:
         # Verify authentication settings
         assert result.enable_email_code_login is True
         assert result.enable_email_password_login is False
+        assert result.enable_collaboration_mode is True
         assert result.is_allow_register is False
         assert result.is_allow_create_workspace is False
 
@@ -330,7 +335,9 @@ class TestFeatureService:
         # Verify mock interactions
         mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
 
-    def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_system_features_unauthenticated(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test system features retrieval for an unauthenticated user.
 
@@ -384,7 +391,9 @@ class TestFeatureService:
         # Marketplace should be visible
         assert result.enable_marketplace is True
 
-    def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_system_features_basic_config(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test system features retrieval with basic configuration (no enterprise).
 
@@ -401,6 +410,7 @@ class TestFeatureService:
             mock_config.ENABLE_EMAIL_CODE_LOGIN = True
             mock_config.ENABLE_EMAIL_PASSWORD_LOGIN = True
             mock_config.ENABLE_SOCIAL_OAUTH_LOGIN = False
+            mock_config.ENABLE_COLLABORATION_MODE = False
             mock_config.ALLOW_REGISTER = True
             mock_config.ALLOW_CREATE_WORKSPACE = True
             mock_config.MAIL_TYPE = "smtp"
@@ -422,6 +432,7 @@ class TestFeatureService:
         assert result.enable_email_code_login is True
         assert result.enable_email_password_login is True
         assert result.enable_social_oauth_login is False
+        assert result.enable_collaboration_mode is False
         assert result.is_allow_register is True
         assert result.is_allow_create_workspace is True
         assert result.is_email_setup is True
@@ -432,7 +443,9 @@ class TestFeatureService:
         # Verify plugin package size (uses default value from dify_config)
         assert result.max_plugin_package_size == 15728640
 
-    def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_billing_disabled(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test feature retrieval when billing is disabled.
 
@@ -488,7 +501,7 @@ class TestFeatureService:
         assert result.webapp_copyright_enabled is False
 
     def test_get_knowledge_rate_limit_billing_disabled(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test knowledge rate limit retrieval when billing is disabled.
 
@@ -519,7 +532,9 @@ class TestFeatureService:
         # Verify no billing service calls
         mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called()
 
-    def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_enterprise_only(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test feature retrieval with enterprise enabled but billing disabled.
 
@@ -579,7 +594,7 @@ class TestFeatureService:
         mock_external_service_dependencies["billing_service"].get_info.assert_not_called()
 
     def test_get_system_features_enterprise_disabled(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test system features retrieval when enterprise is disabled.
 
@@ -636,7 +651,7 @@ class TestFeatureService:
         # Verify no enterprise service calls
         mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called()
 
-    def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """
         Test feature retrieval without tenant ID (billing disabled).
 
@@ -682,7 +697,9 @@ class TestFeatureService:
         # Verify no billing service calls
         mock_external_service_dependencies["billing_service"].get_info.assert_not_called()
 
-    def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_partial_billing_info(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test feature retrieval with partial billing information.
 
@@ -742,7 +759,9 @@ class TestFeatureService:
         # Verify mock interactions
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
-    def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_edge_case_vector_space(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test feature retrieval with edge case vector space configuration.
 
@@ -803,7 +822,7 @@ class TestFeatureService:
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
     def test_get_system_features_edge_case_webapp_auth(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test system features retrieval with edge case webapp auth configuration.
 
@@ -859,7 +878,9 @@ class TestFeatureService:
         # Verify mock interactions
         mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
 
-    def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_edge_case_members_quota(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test feature retrieval with edge case members quota configuration.
 
@@ -920,7 +941,7 @@ class TestFeatureService:
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
     def test_plugin_installation_permission_scopes(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test system features retrieval with different plugin installation permission scopes.
@@ -1019,7 +1040,7 @@ class TestFeatureService:
         assert result.plugin_installation_permission.restrict_to_marketplace_only is True
 
     def test_get_features_workspace_members_missing(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test feature retrieval when workspace members info is missing from enterprise.
@@ -1060,7 +1081,9 @@ class TestFeatureService:
             tenant_id
         )
 
-    def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_system_features_license_inactive(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test system features retrieval with inactive license.
 
@@ -1113,7 +1136,7 @@ class TestFeatureService:
         mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
 
     def test_get_system_features_partial_enterprise_info(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test system features retrieval with partial enterprise information.
@@ -1182,7 +1205,9 @@ class TestFeatureService:
         # Verify mock interactions
         mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
 
-    def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_edge_case_limits(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test feature retrieval with edge case limit values.
 
@@ -1240,7 +1265,7 @@ class TestFeatureService:
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
     def test_get_system_features_edge_case_protocols(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test system features retrieval with edge case protocol values.
 
@@ -1293,7 +1318,9 @@ class TestFeatureService:
         # Verify mock interactions
         mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
 
-    def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_features_edge_case_education(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test feature retrieval with edge case education configuration.
 
@@ -1349,7 +1376,7 @@ class TestFeatureService:
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
     def test_license_limitation_model_is_available(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
         """
         Test LicenseLimitationModel.is_available method with various scenarios.
 
@@ -1390,7 +1417,7 @@ class TestFeatureService:
         assert exact_limit.is_available(3) is True
 
     def test_get_features_workspace_members_disabled(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test feature retrieval when workspace members are disabled in enterprise.
 
@@ -1429,7 +1456,9 @@ class TestFeatureService:
         # Verify mock interactions
         mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id)
 
-    def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_get_system_features_license_expired(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test system features retrieval with expired license.
 
@@ -1482,7 +1511,7 @@ class TestFeatureService:
         mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
 
     def test_get_features_edge_case_docs_processing(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test feature retrieval with edge case document processing configuration.
 
@@ -1540,7 +1569,7 @@ class TestFeatureService:
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
     def test_get_system_features_edge_case_branding(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test system features retrieval with edge case branding configuration.
 
@@ -1602,7 +1631,7 @@ class TestFeatureService:
         mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
 
     def test_get_features_edge_case_annotation_quota(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test feature retrieval with edge case annotation quota configuration.
 
@@ -1664,7 +1693,7 @@ class TestFeatureService:
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
     def test_get_features_edge_case_documents_upload(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test feature retrieval with edge case documents upload settings.
 
@@ -1729,7 +1758,7 @@ class TestFeatureService:
         mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id)
 
     def test_get_system_features_edge_case_license_lost(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test system features with lost license status.
 
@@ -1780,7 +1809,7 @@ class TestFeatureService:
         mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once()
 
     def test_get_features_edge_case_education_disabled(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test feature retrieval with education feature disabled.
diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py
index d82933ccb9..3dcd6586e2 100644
--- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py
@@ -13,6 +13,12 @@ from models.model import App, Conversation, Message
 from services.feedback_service import FeedbackService
 
 
+def _execute_result(rows):
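+    """Build a stand-in for the result of session.execute(); these tests only rely on .all()."""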
+    result = mock.Mock()
+    result.all.return_value = rows
+    return result
+
+
 class TestFeedbackService:
     """Test FeedbackService methods."""
 
@@ -81,25 +87,17 @@ class TestFeedbackService:
 
     def test_export_feedbacks_csv_format(self, mock_db_session, sample_data):
         """Test exporting feedback data in CSV format."""
-
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["user_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["user_feedback"].from_account,
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["user_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["user_feedback"].from_account,
+                )
+            ]
+        )
 
         # Test CSV export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -120,25 +118,17 @@ class TestFeedbackService:
 
     def test_export_feedbacks_json_format(self, mock_db_session, sample_data):
         """Test exporting feedback data in JSON format."""
-
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["admin_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["admin_feedback"].from_account,
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["admin_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["admin_feedback"].from_account,
+                )
+            ]
+        )
 
         # Test JSON export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -157,25 +147,17 @@ class TestFeedbackService:
 
     def test_export_feedbacks_with_filters(self, mock_db_session, sample_data):
         """Test exporting feedback with various filters."""
-
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["admin_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["admin_feedback"].from_account,
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["admin_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["admin_feedback"].from_account,
+                )
+            ]
+        )
 
         # Test with filters
         result = FeedbackService.export_feedbacks(
@@ -193,17 +175,7 @@ class TestFeedbackService:
 
     def test_export_feedbacks_no_data(self, mock_db_session, sample_data):
         """Test exporting feedback when no data exists."""
-
-        # Setup mock query result with no data
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = []
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result([])
 
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
 
@@ -251,24 +223,17 @@ class TestFeedbackService:
             created_at=datetime(2024, 1, 1, 10, 0, 0),
         )
 
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["user_feedback"],
-                long_message,
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["user_feedback"].from_account,
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["user_feedback"],
+                    long_message,
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["user_feedback"].from_account,
+                )
+            ]
+        )
 
         # Test export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
@@ -309,24 +274,17 @@ class TestFeedbackService:
             created_at=datetime(2024, 1, 1, 10, 0, 0),
         )
 
-        # Setup mock query result
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                chinese_feedback,
-                chinese_message,
-                sample_data["conversation"],
-                sample_data["app"],
-                None,  # No account for user feedback
-            )
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    chinese_feedback,
+                    chinese_message,
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    None,
+                )
+            ]
+        )
 
         # Test export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="csv")
@@ -339,32 +297,24 @@ class TestFeedbackService:
 
     def test_export_feedbacks_emoji_ratings(self, mock_db_session, sample_data):
         """Test that rating emojis are properly formatted in export."""
-
-        # Setup mock query result with both like and dislike feedback
-        mock_query = mock.Mock()
-        mock_query.join.return_value = mock_query
-        mock_query.outerjoin.return_value = mock_query
-        mock_query.where.return_value = mock_query
-        mock_query.filter.return_value = mock_query
-        mock_query.order_by.return_value = mock_query
-        mock_query.all.return_value = [
-            (
-                sample_data["user_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["user_feedback"].from_account,
-            ),
-            (
-                sample_data["admin_feedback"],
-                sample_data["message"],
-                sample_data["conversation"],
-                sample_data["app"],
-                sample_data["admin_feedback"].from_account,
-            ),
-        ]
-
-        mock_db_session.execute.return_value = mock_query
+        mock_db_session.execute.return_value = _execute_result(
+            [
+                (
+                    sample_data["user_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["user_feedback"].from_account,
+                ),
+                (
+                    sample_data["admin_feedback"],
+                    sample_data["message"],
+                    sample_data["conversation"],
+                    sample_data["app"],
+                    sample_data["admin_feedback"].from_account,
+                ),
+            ]
+        )
 
         # Test export
         result = FeedbackService.export_feedbacks(app_id=sample_data["app"].id, format_type="json")
diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py
similarity index 51%
rename from api/tests/unit_tests/services/test_hit_testing_service.py
rename to api/tests/test_containers_integration_tests/services/test_hit_testing_service.py
index 80e9729f5b..f332ba05ec 100644
--- a/api/tests/unit_tests/services/test_hit_testing_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_hit_testing_service.py
@@ -1,239 +1,193 @@
+from __future__ import annotations
+
 import json
 from typing import Any, cast
 from unittest.mock import ANY, MagicMock, patch
+from uuid import uuid4
 
 import pytest
+from sqlalchemy import func, select
+from sqlalchemy.orm import Session
 
 from core.rag.models.document import Document
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetQuery
 from services.hit_testing_service import HitTestingService
 
 
-class TestHitTestingService:
-    """Test suite for HitTestingService"""
+def _create_dataset(db_session: Session, *, provider: str = "vendor", **kwargs: Any) -> Dataset:
test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): - """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" - # Arrange - args = {} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query or attachment_ids is required" in str(exc_info.value) + with pytest.raises(ValueError, match="Query or attachment_ids is required"): + HitTestingService.hit_testing_args_check({}) def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): - """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" - # Arrange - args = {"query": "a" * 251} - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Query cannot exceed 250 characters" in str(exc_info.value) + with pytest.raises(ValueError, match="Query cannot exceed 250 characters"): + HitTestingService.hit_testing_args_check({"query": "a" * 251}) def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): - """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" - # Arrange - args = {"attachment_ids": "not a list"} + with pytest.raises(ValueError, match="Attachment_ids must be a list"): + HitTestingService.hit_testing_args_check({"attachment_ids": "not a list"}) - # Act & Assert - with pytest.raises(ValueError) as exc_info: - HitTestingService.hit_testing_args_check(args) - assert "Attachment_ids must be a list" in str(exc_info.value) - - # ===== Response Formatting Tests ===== + # ── Response formatting ──────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") def test_compact_retrieve_response_should_format_correctly(self, mock_format): - """Test that compact_retrieve_response formats the response correctly""" - # Arrange query = "test query" mock_doc = MagicMock(spec=Document) - documents = [mock_doc] mock_record = MagicMock() mock_record.model_dump.return_value = {"content": "formatted content"} mock_format.return_value = [mock_record] - # Act - result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, [mock_doc])) - # Assert assert cast(dict[str, Any], result["query"])["content"] == query assert len(result["records"]) == 1 assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" - mock_format.assert_called_once_with(documents) + mock_format.assert_called_once_with([mock_doc]) - def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): - """Test that compact_external_retrieve_response returns records when dataset provider is external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "external" - query = "test query" + def test_compact_external_retrieve_response_should_return_records_for_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") documents = [ {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, ] - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], 
HitTestingService.compact_external_retrieve_response(dataset, "test query", documents) + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert len(result["records"]) == 2 assert cast(dict[str, Any], result["records"][0])["content"] == "c1" assert cast(dict[str, Any], result["records"][1])["title"] == "t2" - def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): - """Test that compact_external_retrieve_response returns empty records for non-external provider""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" - documents = [{"content": "c1"}] + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider( + self, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="vendor") - # Act - result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + result = cast( + dict[str, Any], + HitTestingService.compact_external_retrieve_response(dataset, "test query", [{"content": "c1"}]), + ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== External Retrieve Tests ===== + # ── External retrieve (real DB) ──────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): - """Test that external_retrieve successfully retrieves from external provider and commits query""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - dataset.provider = "external" - query = 'test "query"' + def test_external_retrieve_should_succeed_for_external_provider( + self, mock_ext_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers, provider="external") + account_id = str(uuid4()) account = MagicMock() - account.id = "account_id" - + account.id = account_id mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.external_retrieve( dataset=dataset, - query=query, + query='test "query"', account=account, external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == 'test "query"' assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" - - # Verify call to RetrievalService.external_retrieve with escaped query mock_ext_retrieve.assert_called_once_with( - dataset_id="dataset_id", + dataset_id=dataset.id, query='test \\"query\\"', external_retrieval_model={"model": "test"}, metadata_filtering_conditions={"key": "val"}, ) - # Verify DatasetQuery record was added and committed - mock_add.assert_called_once() - mock_commit.assert_called_once() + db_session_with_containers.expire_all() + after_count = 
db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 - def test_external_retrieve_should_return_empty_for_non_external_provider(self): - """Test that external_retrieve returns empty results immediately if provider is not external""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.provider = "not_external" - query = "test query" + def test_external_retrieve_should_return_empty_for_non_external_provider(self, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers, provider="vendor") account = MagicMock() - # Act - result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, "test query", account)) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" assert result["records"] == [] - # ===== Retrieve Tests ===== + # ── Retrieve (real DB) ───────────────────────────────────────────── @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve uses default model when retrieval_model is not provided""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" + def test_retrieve_should_use_default_model_when_none_provided( + self, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) dataset.retrieval_model = None - query = "test query" account = MagicMock() - account.id = "account_id" - + account.id = str(uuid4()) mock_retrieve.return_value = [] - # Act + before_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + result = cast( dict[str, Any], HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + dataset=dataset, query="test query", account=account, retrieval_model=None, external_retrieval_model={} ), ) - # Assert - assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["query"])["content"] == "test query" mock_retrieve.assert_called_once() - # Verify top_k from default_retrieval_model (4) assert mock_retrieve.call_args.kwargs["top_k"] == 4 - mock_commit.assert_called_once() + + db_session_with_containers.expire_all() + after_count = db_session_with_containers.scalar(select(func.count()).select_from(DatasetQuery)) or 0 + assert after_count == before_count + 1 @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): - """Test that retrieve correctly calls metadata filtering when conditions are present""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_metadata_filtering( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = 
_create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "semantic_search", @@ -242,29 +196,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response - mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_get_meta.return_value = ({dataset.id: ["doc_id1"]}, "condition_string") mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_get_meta.assert_called_once() mock_retrieve.assert_called_once() assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") - def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): - """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_return_empty_if_metadata_filtering_fails( + self, mock_get_meta, mock_retrieve, db_session_with_containers: Session + ): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() retrieval_model = { @@ -274,37 +226,27 @@ class TestHitTestingService: "reranking_enable": False, "score_threshold_enabled": False, } - - # Mock metadata filtering response: condition returned but no IDs mock_get_meta.return_value = ({}, "condition_string") - # Act result = cast( dict[str, Any], HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, ), ) - # Assert assert result["records"] == [] mock_retrieve.assert_not_called() @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_attachments(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) attachment_ids = ["att1", "att2"] retrieval_model = { @@ -315,21 +257,19 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( dataset=dataset, - query=query, + query="test query", account=account, retrieval_model=retrieval_model, external_retrieval_model={}, attachment_ids=attachment_ids, ) - # Assert mock_retrieve.assert_called_once_with( retrieval_method=ANY, - dataset_id="dataset_id", - query=query, + dataset_id=dataset.id, + query="test query", attachment_ids=attachment_ids, top_k=4, score_threshold=0.0, @@ -338,26 +278,27 @@ class TestHitTestingService: weights=None, document_ids_filter=None, ) - # 
Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images) - # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) - called_query = mock_add.call_args[0][0] - query_content = json.loads(called_query.content) + + # Verify DatasetQuery was persisted with correct content structure + db_session_with_containers.expire_all() + latest = db_session_with_containers.scalar( + select(DatasetQuery) + .where(DatasetQuery.dataset_id == dataset.id) + .order_by(DatasetQuery.created_at.desc()) + .limit(1) + ) + assert latest is not None + query_content = json.loads(latest.content) assert len(query_content) == 3 # 1 text + 2 images assert query_content[0]["content_type"] == "text_query" assert query_content[1]["content_type"] == "image_query" assert query_content[1]["content"] == "att1" @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") - @patch("extensions.ext_database.db.session.add") - @patch("extensions.ext_database.db.session.commit") - def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): - """Test that retrieve passes reranking and threshold parameters correctly""" - # Arrange - dataset = MagicMock(spec=Dataset) - dataset.id = "dataset_id" - query = "test query" + def test_retrieve_should_handle_reranking_and_threshold(self, mock_retrieve, db_session_with_containers: Session): + dataset = _create_dataset(db_session_with_containers) account = MagicMock() - account.id = "account_id" + account.id = str(uuid4()) retrieval_model = { "search_method": "hybrid_search", @@ -371,12 +312,14 @@ class TestHitTestingService: } mock_retrieve.return_value = [] - # Act HitTestingService.retrieve( - dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + dataset=dataset, + query="test query", + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, ) - # Assert mock_retrieve.assert_called_once() kwargs = mock_retrieve.call_args.kwargs assert kwargs["score_threshold"] == 0.5 diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index c46b8fba0b..80f9083e81 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,15 +3,15 @@ import uuid from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.nodes.human_input.entities import HumanInputNodeData -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, ) +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 0f252515f7..ce63e7a71a 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -5,17 +5,18 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.runtime import VariablePool from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session from configs import dify_config -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, MemberRecipient, ) +from graphon.runtime import VariablePool from models.account import Account, TenantAccountJoin from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( @@ -88,7 +89,7 @@ class TestDeliveryTestRegistry: with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): registry.dispatch(context=context, method=method) - def test_default(self, flask_app_with_containers, db_session_with_containers): + def test_default(self, flask_app_with_containers, db_session_with_containers: Session): registry = DeliveryTestRegistry.default() assert len(registry._handlers) == 1 assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) @@ -260,7 +261,7 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers): + def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) account = Account(name="Test User", email="member@example.com") db_session_with_containers.add(account) @@ -282,7 +283,7 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"] - def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers): + def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com") account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com") diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 2340dd2a03..1a1efe0337 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -8,11 +8,11 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.file import FileType from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client +from graphon.file import FileType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ( ConversationFromSource, @@ -165,7 +165,7 @@ class TestMessagesCleanServiceIntegration: return app - def _create_conversation(self, db_session_with_containers: Session, app): + def _create_conversation(self, db_session_with_containers: Session, app: App): """Helper to create a conversation.""" conversation = Conversation( app_id=app.id, diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py 
b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py index b55a19eaa9..fffa82bf5c 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from sqlalchemy import select +from sqlalchemy.orm import Session from models.dataset import Dataset, DatasetMetadataBinding, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -65,7 +66,7 @@ class TestMetadataPartialUpdate: yield account def test_partial_update_merges_metadata( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -92,7 +93,7 @@ class TestMetadataPartialUpdate: assert updated_doc.doc_metadata["new_key"] == "new_value" def test_full_update_replaces_metadata( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -119,7 +120,7 @@ class TestMetadataPartialUpdate: assert "existing_key" not in updated_doc.doc_metadata def test_partial_update_skips_existing_binding( - self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, user_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -159,7 +160,7 @@ class TestMetadataPartialUpdate: assert len(bindings) == 1 def test_rollback_called_on_commit_failure( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index ba926bf675..8955a3b5f2 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -405,11 +405,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models + from core.entities.model_entities import ModelWithProviderEntity, 
SimpleModelProviderEntity from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.provider_entities import ProviderEntity - from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( ProviderEntity( @@ -644,9 +643,8 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response - from graphon.model_runtime.entities.common_entities import I18nObject - from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity + from graphon.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py index c146a5924b..5fa5de6d80 100644 --- a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest from models.model import OAuthProviderApp @@ -25,7 +26,7 @@ from services.oauth_server import ( class TestOAuthServerServiceGetProviderApp: """DB-backed tests for get_oauth_provider_app.""" - def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + def _create_oauth_provider_app(self, db_session_with_containers: Session, *, client_id: str) -> OAuthProviderApp: app = OAuthProviderApp( app_icon="icon.png", client_id=client_id, @@ -38,7 +39,7 @@ class TestOAuthServerServiceGetProviderApp: db_session_with_containers.commit() return app - def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers: Session): client_id = f"client-{uuid4()}" created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) @@ -48,7 +49,7 @@ class TestOAuthServerServiceGetProviderApp: assert result.client_id == client_id assert result.id == created.id - def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers: Session): result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py new file mode 100644 index 0000000000..e2e1a228b2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import TraceAppConfig +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.ops_service import OpsService +from 
tests.test_containers_integration_tests.helpers import generate_valid_password + + +class TestOpsService: + @pytest.fixture + def mock_external_service_dependencies(self): + with ( + patch("services.app_service.FeatureService") as mock_feature_service, + patch("services.app_service.EnterpriseService") as mock_enterprise_service, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + mock_feature_service.get_system_features.return_value.webapp_auth.enabled = False + mock_enterprise_service.WebAppAuth.update_app_access_mode.return_value = None + mock_enterprise_service.WebAppAuth.cleanup_webapp.return_value = None + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + mock_model_instance = mock_model_manager.return_value + mock_model_instance.get_default_model_instance.return_value = None + mock_model_instance.get_default_provider_model_name.return_value = ("openai", "gpt-3.5-turbo") + yield { + "feature_service": mock_feature_service, + "enterprise_service": mock_enterprise_service, + "model_manager": mock_model_manager, + "account_feature_service": mock_account_feature_service, + } + + @pytest.fixture + def mock_ops_trace_manager(self): + with patch("services.ops_service.OpsTraceManager") as mock: + yield mock + + def _create_app(self, db_session_with_containers: Session, mock_external_service_dependencies): + fake = Faker() + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🤖", + "icon_background": "#FF6B6B", + }, + account, + ) + return app, account + + _SENTINEL = object() + + def _insert_trace_config( + self, + db_session: Session, + app_id: str, + provider: str, + tracing_config: dict | None | object = _SENTINEL, + ) -> TraceAppConfig: + trace_config = TraceAppConfig( + app_id=app_id, + tracing_provider=provider, + tracing_config=tracing_config if tracing_config is not self._SENTINEL else {"some": "config"}, + ) + db_session.add(trace_config) + db_session.commit() + return trace_config + + # ── get_tracing_app_config ───────────────────────────────────────── + + def test_get_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.get_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_get_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, "arize") + result = OpsService.get_tracing_app_config(fake_app_id, "arize") + assert result is None + + def test_get_tracing_app_config_none_config( + self, db_session_with_containers: Session, mock_external_service_dependencies, mock_ops_trace_manager + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize", tracing_config=None) + + with pytest.raises(ValueError, match="Tracing config cannot be None."): + 
OpsService.get_tracing_app_config(app.id, "arize") + + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, default_url + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {} + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == default_url + + @pytest.mark.parametrize( + "provider", + ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"], + ) + def test_get_tracing_app_config_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.obfuscated_decrypt_token.return_value = {"project_url": "success_url"} + mock_otm.get_trace_config_project_url.return_value = "success_url" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, provider) + + result = OpsService.get_tracing_app_config(app.id, provider) + + assert result is not None + assert result["tracing_config"]["project_url"] == "success_url" + + def test_get_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.return_value = "key" + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + def test_get_tracing_app_config_langfuse_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_otm.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, 
mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "langfuse") + + result = OpsService.get_tracing_app_config(app.id, "langfuse") + + assert result is not None + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + # ── create_tracing_app_config ────────────────────────────────────── + + def test_create_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + def test_create_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_ops_trace_manager + ): + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + result = OpsService.create_tracing_app_config( + str(uuid.uuid4()), TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"} + ) + assert result == {"error": "Invalid Credentials"} + + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception( + self, db_session_with_containers: Session, mock_external_service_dependencies, provider, config + ): + # Existing config causes the service to return None before reaching the DB insert + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("error") + mock_otm.get_trace_config_project_key.side_effect = Exception("error") + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(provider)) + + result = OpsService.create_tracing_app_config(app.id, provider, config) + + assert result is None + + def test_create_tracing_app_config_langfuse_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_key.return_value = "key" + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config( + app.id, + TracingProviderEnum.LANGFUSE, + {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"}, + ) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is None + + def test_create_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + 
mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + result = OpsService.create_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_create_tracing_app_config_with_empty_other_keys( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + # "project" is in other_keys for Arize; providing "" triggers default substitution + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.side_effect = Exception("no url") + mock_otm.encrypt_tracing_config.return_value = {} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {"project": ""}) + + assert result == {"result": "success"} + + def test_create_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.check_trace_config_is_effective.return_value = True + mock_otm.get_trace_config_project_url.return_value = "http://project_url" + mock_otm.encrypt_tracing_config.return_value = {"encrypted": "config"} + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + result = OpsService.create_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result == {"result": "success"} + + # ── update_tracing_app_config ────────────────────────────────────── + + def test_update_tracing_app_config_invalid_provider(self, db_session_with_containers: Session): + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config(str(uuid.uuid4()), "invalid_provider", {}) + + def test_update_tracing_app_config_no_config(self, db_session_with_containers: Session, mock_ops_trace_manager): + result = OpsService.update_tracing_app_config(str(uuid.uuid4()), TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_no_app(self, db_session_with_containers: Session, mock_ops_trace_manager): + fake_app_id = str(uuid.uuid4()) + self._insert_trace_config(db_session_with_containers, fake_app_id, str(TracingProviderEnum.ARIZE)) + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + result = OpsService.update_tracing_app_config(fake_app_id, TracingProviderEnum.ARIZE, {}) + assert result is None + + def test_update_tracing_app_config_invalid_credentials( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {} + mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = False + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + def test_update_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + with patch("services.ops_service.OpsTraceManager") as mock_otm: + mock_otm.encrypt_tracing_config.return_value = {"updated": "config"} + 
mock_otm.decrypt_tracing_config.return_value = {} + mock_otm.check_trace_config_is_effective.return_value = True + + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, str(TracingProviderEnum.ARIZE)) + + result = OpsService.update_tracing_app_config(app.id, TracingProviderEnum.ARIZE, {}) + + assert result is not None + assert result["app_id"] == app.id + + # ── delete_tracing_app_config ────────────────────────────────────── + + def test_delete_tracing_app_config_no_config(self, db_session_with_containers: Session): + result = OpsService.delete_tracing_app_config(str(uuid.uuid4()), "arize") + assert result is None + + def test_delete_tracing_app_config_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + app, _ = self._create_app(db_session_with_containers, mock_external_service_dependencies) + self._insert_trace_config(db_session_with_containers, app.id, "arize") + + result = OpsService.delete_tracing_app_config(app.id, "arize") + + assert result is True + remaining = db_session_with_containers.scalar( + select(TraceAppConfig) + .where(TraceAppConfig.app_id == app.id, TraceAppConfig.tracing_provider == "arize") + .limit(1) + ) + assert remaining is None diff --git a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py index 7036524918..2f20949611 100644 --- a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py @@ -8,6 +8,7 @@ from datetime import datetime from uuid import uuid4 from sqlalchemy import select +from sqlalchemy.orm import Session from models.workflow import WorkflowPause, WorkflowRun from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore @@ -39,7 +40,7 @@ class TestWorkflowRunRestore: assert result["created_at"].month == 1 assert result["name"] == "test" - def test_restore_table_records_returns_rowcount(self, db_session_with_containers): + def test_restore_table_records_returns_rowcount(self, db_session_with_containers: Session): """Restore should return inserted rowcount.""" restore = WorkflowRunRestore() record_id = str(uuid4()) @@ -65,7 +66,7 @@ class TestWorkflowRunRestore: restored_pause = db_session_with_containers.scalar(select(WorkflowPause).where(WorkflowPause.id == record_id)) assert restored_pause is not None - def test_restore_table_records_unknown_table(self, db_session_with_containers): + def test_restore_table_records_unknown_table(self, db_session_with_containers: Session): """Unknown table names should be ignored gracefully.""" restore = WorkflowRunRestore() diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 70aa813142..7b9e9924cd 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from models import App, CreatorUserRole from models.enums import ConversationFromSource from models.model import EndUser, Message from models.web import SavedMessage @@ -88,7 +89,7 @@ 
class TestSavedMessageService: return app, account - def _create_test_end_user(self, db_session_with_containers: Session, app): + def _create_test_end_user(self, db_session_with_containers: Session, app: App): """ Helper method to create a test end user for testing. @@ -116,7 +117,7 @@ class TestSavedMessageService: return end_user - def _create_test_message(self, db_session_with_containers: Session, app, user): + def _create_test_message(self, db_session_with_containers: Session, app: App, user): """ Helper method to create a test message for testing. @@ -199,13 +200,13 @@ class TestSavedMessageService: saved_message1 = SavedMessage( app_id=app.id, message_id=message1.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) saved_message2 = SavedMessage( app_id=app.id, message_id=message2.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -272,13 +273,13 @@ class TestSavedMessageService: saved_message1 = SavedMessage( app_id=app.id, message_id=message1.id, - created_by_role="end_user", + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) saved_message2 = SavedMessage( app_id=app.id, message_id=message2.id, - created_by_role="end_user", + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -449,7 +450,7 @@ class TestSavedMessageService: saved_message = SavedMessage( app_id=app.id, message_id=message.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -540,7 +541,9 @@ class TestSavedMessageService: message = self._create_test_message(db_session_with_containers, app, account) # Pre-create a saved message - saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id) + saved = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id + ) db_session_with_containers.add(saved) db_session_with_containers.commit() @@ -571,7 +574,9 @@ class TestSavedMessageService: end_user = self._create_test_end_user(db_session_with_containers, app) message = self._create_test_message(db_session_with_containers, app, end_user) - saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id) + saved = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id + ) db_session_with_containers.add(saved) db_session_with_containers.commit() @@ -596,10 +601,10 @@ class TestSavedMessageService: # Both users save the same message saved_account = SavedMessage( - app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account1.id ) saved_end_user = SavedMessage( - app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id ) db_session_with_containers.add_all([saved_account, saved_end_user]) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_schedule_service.py b/api/tests/test_containers_integration_tests/services/test_schedule_service.py new file mode 100644 index 0000000000..87f3306258 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_schedule_service.py @@ -0,0 +1,387 @@ +"""Testcontainers 
integration tests for schedule service SQL-backed behavior.""" + +from datetime import datetime +from types import SimpleNamespace +from uuid import uuid4 + +import pytest +from sqlalchemy import delete, select +from sqlalchemy.orm import Session + +from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate +from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError +from events.event_handlers.sync_workflow_schedule_when_app_published import sync_schedule_from_workflow +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.trigger import WorkflowSchedulePlan +from services.errors.account import AccountNotFoundError +from services.trigger.schedule_service import ScheduleService + + +class ScheduleServiceIntegrationFactory: + @staticmethod + def create_account_with_tenant( + db_session_with_containers: Session, + role: TenantAccountRole = TenantAccountRole.OWNER, + ) -> tuple[Account, Tenant]: + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_schedule_plan( + db_session_with_containers: Session, + *, + tenant_id: str, + app_id: str | None = None, + node_id: str = "start", + cron_expression: str = "30 10 * * *", + timezone: str = "UTC", + next_run_at: datetime | None = None, + ) -> WorkflowSchedulePlan: + schedule = WorkflowSchedulePlan( + tenant_id=tenant_id, + app_id=app_id or str(uuid4()), + node_id=node_id, + cron_expression=cron_expression, + timezone=timezone, + next_run_at=next_run_at, + ) + db_session_with_containers.add(schedule) + db_session_with_containers.commit() + return schedule + + +def _cron_workflow( + *, + node_id: str = "start", + cron_expression: str = "30 10 * * *", + timezone: str = "UTC", +): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": node_id, + "data": { + "type": "trigger-schedule", + "mode": "cron", + "cron_expression": cron_expression, + "timezone": timezone, + }, + } + ] + } + ) + + +def _no_schedule_workflow(): + return SimpleNamespace( + graph_dict={ + "nodes": [ + { + "id": "node-1", + "data": {"type": "llm"}, + } + ] + } + ) + + +class TestScheduleServiceIntegration: + def test_create_schedule_persists_schedule(self, db_session_with_containers: Session): + account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + expected_next_run = datetime(2026, 1, 1, 10, 30, 0) + config = ScheduleConfig( + node_id="start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + schedule = ScheduleService.create_schedule( + session=db_session_with_containers, + tenant_id=tenant.id, + app_id=str(uuid4()), + config=config, + ) + + persisted = db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) + assert persisted is not None + assert persisted.tenant_id == tenant.id + assert persisted.node_id == "start" + 
assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_update_schedule_updates_fields_and_recomputes_next_run(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + cron_expression="30 10 * * *", + timezone="UTC", + ) + expected_next_run = datetime(2026, 1, 2, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate( + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + db_session_with_containers.refresh(updated) + assert updated.cron_expression == "0 12 * * *" + assert updated.timezone == "America/New_York" + assert updated.next_run_at == expected_next_run + + def test_update_schedule_updates_only_node_id_without_recomputing_time(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + initial_next_run = datetime(2026, 1, 1, 10, 0, 0) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + next_run_at=initial_next_run, + ) + + with pytest.MonkeyPatch.context() as monkeypatch: + calls: list[tuple] = [] + + def _track(*args, **kwargs): + calls.append((args, kwargs)) + return datetime(2026, 1, 9, 10, 0, 0) + + monkeypatch.setattr("services.trigger.schedule_service.calculate_next_run_at", _track) + updated = ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + db_session_with_containers.refresh(updated) + assert updated.node_id == "node-new" + assert updated.next_run_at == initial_next_run + assert calls == [] + + def test_update_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + updates=SchedulePlanUpdate(node_id="node-new"), + ) + + def test_delete_schedule_removes_row(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + db_session_with_containers.commit() + + assert db_session_with_containers.get(WorkflowSchedulePlan, schedule.id) is None + + def test_delete_schedule_not_found_raises(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.delete_schedule( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + def test_get_tenant_owner_returns_owner_account(self, db_session_with_containers: Session): + owner, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + 
db_session_with_containers, + role=TenantAccountRole.OWNER, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == owner.id + + def test_get_tenant_owner_falls_back_to_admin(self, db_session_with_containers: Session): + admin, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant( + db_session_with_containers, + role=TenantAccountRole.ADMIN, + ) + + result = ScheduleService.get_tenant_owner( + session=db_session_with_containers, + tenant_id=tenant.id, + ) + + assert result.id == admin.id + + def test_get_tenant_owner_raises_when_account_record_missing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + missing_account_id = str(uuid4()) + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=missing_account_id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=missing_account_id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_get_tenant_owner_raises_when_no_owner_or_admin_found(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.commit() + + with pytest.raises(AccountNotFoundError, match=tenant.id): + ScheduleService.get_tenant_owner(session=db_session_with_containers, tenant_id=tenant.id) + + def test_update_next_run_at_updates_persisted_value(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + schedule = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + ) + expected_next_run = datetime(2026, 1, 3, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=schedule.id, + ) + + db_session_with_containers.refresh(schedule) + assert result == expected_next_run + assert schedule.next_run_at == expected_next_run + + def test_update_next_run_at_raises_when_schedule_not_found(self, db_session_with_containers: Session): + with pytest.raises(ScheduleNotFoundError, match="Schedule not found"): + ScheduleService.update_next_run_at( + session=db_session_with_containers, + schedule_id=str(uuid4()), + ) + + +class TestSyncScheduleFromWorkflowIntegration: + def test_sync_schedule_create_new(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + expected_next_run = datetime(2026, 1, 4, 10, 30, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow(), + ) + + assert result is not None + persisted = 
db_session_with_containers.execute( + select(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app_id) + ).scalar_one() + assert persisted.node_id == "start" + assert persisted.cron_expression == "30 10 * * *" + assert persisted.timezone == "UTC" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_update_existing(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + node_id="old-start", + cron_expression="30 10 * * *", + timezone="UTC", + ) + existing_id = existing.id + expected_next_run = datetime(2026, 1, 5, 12, 0, 0) + + with pytest.MonkeyPatch.context() as monkeypatch: + monkeypatch.setattr( + "services.trigger.schedule_service.calculate_next_run_at", + lambda *_args, **_kwargs: expected_next_run, + ) + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_cron_workflow( + node_id="start", + cron_expression="0 12 * * *", + timezone="America/New_York", + ), + ) + + assert result is not None + db_session_with_containers.expire_all() + persisted = db_session_with_containers.get(WorkflowSchedulePlan, existing_id) + assert persisted is not None + assert persisted.node_id == "start" + assert persisted.cron_expression == "0 12 * * *" + assert persisted.timezone == "America/New_York" + assert persisted.next_run_at == expected_next_run + + def test_sync_schedule_remove_when_no_config(self, db_session_with_containers: Session): + _account, tenant = ScheduleServiceIntegrationFactory.create_account_with_tenant(db_session_with_containers) + app_id = str(uuid4()) + existing = ScheduleServiceIntegrationFactory.create_schedule_plan( + db_session_with_containers, + tenant_id=tenant.id, + app_id=app_id, + ) + existing_id = existing.id + + result = sync_schedule_from_workflow( + tenant_id=tenant.id, + app_id=app_id, + workflow=_no_schedule_workflow(), + ) + + assert result is None + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowSchedulePlan, existing_id) is None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 5a6bf0466e..583b6128e6 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -1099,38 +1099,39 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - # Create tag - tag = self._create_test_tags( - db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1 - )[0] + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2 + ) - # Create dataset and bind tag + # Create dataset and bind tags dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) self._create_test_tag_bindings( - db_session_with_containers, mock_external_service_dependencies, [tag], dataset.id, tenant.id + db_session_with_containers, mock_external_service_dependencies, tags, dataset.id, tenant.id ) - # Verify binding exists before deletion - - binding_before = ( + # Verify bindings exist before deletion + bindings_before = ( 
db_session_with_containers.query(TagBinding) - .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) - .first() + .where(TagBinding.tag_id.in_([tag.id for tag in tags]), TagBinding.target_id == dataset.id) + .all() ) - assert binding_before is not None + assert len(bindings_before) == 2 # Act: Execute the method under test - delete_payload = TagBindingDeletePayload(type="knowledge", target_id=dataset.id, tag_id=tag.id) + delete_payload = TagBindingDeletePayload( + type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags] + ) TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes - # Verify tag binding was deleted - binding_after = ( + # Verify tag bindings were deleted + bindings_after = ( db_session_with_containers.query(TagBinding) - .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) - .first() + .where(TagBinding.tag_id.in_([tag.id for tag in tags]), TagBinding.target_id == dataset.id) + .all() ) - assert binding_after is None + assert len(bindings_after) == 0 def test_delete_tag_binding_non_existent_binding( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1156,7 +1157,7 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Try to delete non-existent binding - delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_id=tag.id) + delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_ids=[tag.id]) TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index f2307fbd7d..797731d04b 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account +from models import Account, App from models.enums import ConversationFromSource from models.model import Conversation, EndUser from models.web import PinnedConversation @@ -93,7 +93,7 @@ class TestWebConversationService: return app, account - def _create_test_end_user(self, db_session_with_containers: Session, app): + def _create_test_end_user(self, db_session_with_containers: Session, app: App): """ Helper method to create a test end user for testing. 
diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index 4fe65d5803..7825f502f7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -233,11 +233,10 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None - assert result.password is not None - assert result.password_salt is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None + assert refreshed.password is not None + assert refreshed.password_salt is not None def test_authenticate_account_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -414,9 +413,8 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - - db_session_with_containers.refresh(result) - assert result.id is not None + refreshed = db_session_with_containers.get(Account, result.id) + assert refreshed is not None def test_get_user_through_email_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 970da98c55..6d5c7380b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from flask import Flask +from sqlalchemy.orm import Session from werkzeug.datastructures import FileStorage from models.enums import AppTriggerStatus, AppTriggerType @@ -52,7 +53,7 @@ class TestWebhookService: } @pytest.fixture - def test_data(self, db_session_with_containers, mock_external_dependencies): + def test_data(self, db_session_with_containers: Session, mock_external_dependencies): """Create test data for webhook service tests.""" fake = Faker() @@ -160,7 +161,7 @@ class TestWebhookService: "app_trigger": app_trigger, } - def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers): + def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers: Flask): """Test successful retrieval of webhook trigger and workflow.""" webhook_id = test_data["webhook_id"] @@ -175,7 +176,7 @@ class TestWebhookService: assert node_config["id"] == "webhook_node" assert node_config["data"].title == "Test Webhook" - def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers): + def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers: Flask): """Test webhook trigger not found scenario.""" with flask_app_with_containers.app_context(): with pytest.raises(ValueError, match="Webhook not found"): @@ -421,7 +422,9 @@ class TestWebhookService: assert result["files"] == {} - def test_trigger_workflow_execution_success(self, test_data, mock_external_dependencies, flask_app_with_containers): + def test_trigger_workflow_execution_success( + self, test_data, mock_external_dependencies, flask_app_with_containers: Flask + ): """Test successful workflow execution trigger.""" webhook_data = { 
"method": "POST", @@ -452,7 +455,7 @@ class TestWebhookService: mock_external_dependencies["async_service"].trigger_workflow_async.assert_called_once() def test_trigger_workflow_execution_end_user_service_failure( - self, test_data, mock_external_dependencies, flask_app_with_containers + self, test_data, mock_external_dependencies, flask_app_with_containers: Flask ): """Test workflow execution trigger when EndUserService fails.""" webhook_data = {"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py new file mode 100644 index 0000000000..69cde847f8 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -0,0 +1,518 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from enums.quota_type import QuotaType +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import AppTriggerStatus, AppTriggerType +from models.model import App +from models.trigger import AppTrigger, WorkflowWebhookTrigger +from models.workflow import Workflow +from services.errors.app import QuotaExceededError +from services.trigger.webhook_service import WebhookService + + +class WebhookServiceRelationshipFactory: + @staticmethod + def create_account_and_tenant(db_session_with_containers: Session) -> tuple[Account, Tenant]: + account = Account( + name=f"Account {uuid4()}", + email=f"webhook-{uuid4()}@example.com", + password="hashed-password", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status="normal") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_app(db_session_with_containers: Session, tenant: Tenant, account: Account) -> App: + app = App( + tenant_id=tenant.id, + name=f"Webhook App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + ) + db_session_with_containers.add(app) + db_session_with_containers.commit() + return app + + @staticmethod + def create_workflow( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_ids: list[str], + version: str, + ) -> Workflow: + graph = { + "nodes": [ + { + "id": node_id, + "data": { + "type": TRIGGER_WEBHOOK_NODE_TYPE, + "title": f"Webhook {node_id}", + "method": "post", + "content_type": "application/json", + "headers": [], + "params": [], + "body": [], + "status_code": 200, + "response_body": '{"status": "ok"}', + "timeout": 
30, + }, + } + for node_id in node_ids + ], + "edges": [], + } + + workflow = Workflow( + tenant_id=app.tenant_id, + app_id=app.id, + type="workflow", + graph=json.dumps(graph), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + version=version, + ) + db_session_with_containers.add(workflow) + db_session_with_containers.commit() + return workflow + + @staticmethod + def create_webhook_trigger( + db_session_with_containers: Session, + *, + app: App, + account: Account, + node_id: str, + webhook_id: str | None = None, + ) -> WorkflowWebhookTrigger: + webhook_trigger = WorkflowWebhookTrigger( + app_id=app.id, + node_id=node_id, + tenant_id=app.tenant_id, + webhook_id=webhook_id or uuid4().hex[:24], + created_by=account.id, + ) + db_session_with_containers.add(webhook_trigger) + db_session_with_containers.commit() + return webhook_trigger + + @staticmethod + def create_app_trigger( + db_session_with_containers: Session, + *, + app: App, + node_id: str, + status: AppTriggerStatus, + ) -> AppTrigger: + app_trigger = AppTrigger( + tenant_id=app.tenant_id, + app_id=app.id, + node_id=node_id, + trigger_type=AppTriggerType.TRIGGER_WEBHOOK, + provider_name="webhook", + title=f"Webhook {node_id}", + status=status, + ) + db_session_with_containers.add(app_trigger) + db_session_with_containers.commit() + return app_trigger + + +class TestWebhookServiceLookupWithContainers: + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with pytest.raises(ValueError, match="App trigger not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.RATE_LIMITED + ) + + with pytest.raises(ValueError, match="rate limited"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = 
factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.DISABLED + ) + + with pytest.raises(ValueError, match="disabled"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + factory.create_app_trigger( + db_session_with_containers, app=app, node_id="node-1", status=AppTriggerStatus.ENABLED + ) + + with pytest.raises(ValueError, match="Workflow not found"): + WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) + + def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["published-node"], + version="2026-04-14.001", + ) + draft_workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["debug-node"], + version=Workflow.VERSION_DRAFT, + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="debug-node" + ) + + got_trigger, got_workflow, got_node_config = WebhookService.get_webhook_trigger_and_workflow( + webhook_trigger.webhook_id, + is_debug=True, + ) + + assert got_trigger.id == webhook_trigger.id + assert got_workflow.id == draft_workflow.id + assert got_node_config["id"] == "debug-node" + + +class TestWebhookServiceTriggerExecutionWithContainers: + def test_trigger_workflow_execution_triggers_async_workflow_successfully( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + end_user = SimpleNamespace(id=str(uuid4())) + webhook_data = {"body": {"value": 1}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"} + + quota_charge = MagicMock() + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=end_user, + ), + patch( + 
"services.trigger.webhook_service.QuotaService.reserve", + return_value=quota_charge, + ) as mock_reserve, + patch("services.trigger.webhook_service.AsyncWorkflowService.trigger_workflow_async") as mock_trigger, + ): + WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow) + + mock_reserve.assert_called_once() + reserve_args = mock_reserve.call_args.args + assert reserve_args[0] == QuotaType.TRIGGER + assert reserve_args[1] == webhook_trigger.tenant_id + quota_charge.commit.assert_called_once() + mock_trigger.assert_called_once() + trigger_args = mock_trigger.call_args.args + assert trigger_args[1] is end_user + assert trigger_args[2].workflow_id == workflow.id + assert trigger_args[2].root_node_id == webhook_trigger.node_id + + def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + return_value=SimpleNamespace(id=str(uuid4())), + ), + patch( + "services.trigger.webhook_service.QuotaService.reserve", + side_effect=QuotaExceededError(feature="trigger", tenant_id=tenant.id, required=1), + ), + patch( + "services.trigger.webhook_service.AppTriggerService.mark_tenant_triggers_rate_limited" + ) as mock_mark_rate_limited, + ): + with pytest.raises(QuotaExceededError): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_mark_rate_limited.assert_called_once_with(tenant.id) + + def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version="2026-04-14.001" + ) + webhook_trigger = factory.create_webhook_trigger( + db_session_with_containers, app=app, account=account, node_id="node-1" + ) + + with ( + patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + side_effect=RuntimeError("boom"), + ), + patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + ): + with pytest.raises(RuntimeError, match="boom"): + WebhookService.trigger_workflow_execution( + webhook_trigger, + {"body": {}, "headers": {}, "query_params": {}, "files": {}, "method": "POST"}, + workflow, + ) + + mock_logger_exception.assert_called_once() + + +class TestWebhookServiceRelationshipSyncWithContainers: + def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del 
flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + node_ids = [f"node-{index}" for index in range(WebhookService.MAX_WEBHOOK_NODES_PER_WORKFLOW + 1)] + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=node_ids, version=Workflow.VERSION_DRAFT + ) + + with pytest.raises(ValueError, match="maximum webhook node limit"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_raises_when_lock_not_acquired( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, app=app, account=account, node_ids=["node-1"], version=Workflow.VERSION_DRAFT + ) + lock = MagicMock() + lock.acquire.return_value = False + + with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock): + with pytest.raises(RuntimeError, match="Failed to acquire lock"): + WebhookService.sync_webhook_relationships(app, workflow) + + def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + stale_trigger = factory.create_webhook_trigger( + db_session_with_containers, + app=app, + account=account, + node_id="node-stale", + webhook_id="stale-webhook-id-000001", + ) + stale_trigger_id = stale_trigger.id + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-new"], + version=Workflow.VERSION_DRAFT, + ) + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="new-webhook-id-000001" + ): + WebhookService.sync_webhook_relationships(app, workflow) + + db_session_with_containers.expire_all() + records = db_session_with_containers.scalars( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id) + ).all() + + assert [record.node_id for record in records] == ["node-new"] + assert records[0].webhook_id == "new-webhook-id-000001" + assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None + + def test_sync_webhook_relationships_sets_redis_cache_for_new_record( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ): + del flask_app_with_containers + factory = WebhookServiceRelationshipFactory + account, tenant = factory.create_account_and_tenant(db_session_with_containers) + app = factory.create_app(db_session_with_containers, tenant, account) + workflow = factory.create_workflow( + db_session_with_containers, + app=app, + account=account, + node_ids=["node-cache"], + version=Workflow.VERSION_DRAFT, + ) + cache_key = f"{WebhookService.__WEBHOOK_NODE_CACHE_KEY__}:{app.id}:node-cache" + + with patch( + "services.trigger.webhook_service.WebhookService.generate_webhook_id", return_value="cache-webhook-id-00001" + ): + 
WebhookService.sync_webhook_relationships(app, workflow)
+
+        cached_payload = _read_cache(cache_key)
+        assert cached_payload is not None
+        assert cached_payload["node_id"] == "node-cache"
+        assert cached_payload["webhook_id"] == "cache-webhook-id-00001"
+
+    def test_sync_webhook_relationships_logs_when_lock_release_fails(
+        self, db_session_with_containers: Session, flask_app_with_containers: Flask
+    ):
+        del flask_app_with_containers
+        factory = WebhookServiceRelationshipFactory
+        account, tenant = factory.create_account_and_tenant(db_session_with_containers)
+        app = factory.create_app(db_session_with_containers, tenant, account)
+        workflow = factory.create_workflow(
+            db_session_with_containers, app=app, account=account, node_ids=[], version=Workflow.VERSION_DRAFT
+        )
+        lock = MagicMock()
+        lock.acquire.return_value = True
+        lock.release.side_effect = RuntimeError("release failed")
+
+        with (
+            patch("services.trigger.webhook_service.redis_client.lock", return_value=lock),
+            patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
+        ):
+            WebhookService.sync_webhook_relationships(app, workflow)
+
+        mock_logger_exception.assert_called_once()
+
+
+def _read_cache(cache_key: str) -> dict[str, str] | None:
+    from extensions.ext_redis import redis_client
+
+    cached = redis_client.get(cache_key)
+    if not cached:
+        return None
+    if isinstance(cached, bytes):
+        cached = cached.decode("utf-8")
+    return json.loads(cached)
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
index 749c6fff5b..a2cdddad61 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py
@@ -8,9 +8,9 @@ from unittest.mock import patch
 import pytest
 from faker import Faker
-from graphon.enums import WorkflowExecutionStatus
 from sqlalchemy.orm import Session

+from graphon.enums import WorkflowExecutionStatus
 from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
 from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom
 from models.workflow import WorkflowAppLogCreatedFrom
@@ -1530,7 +1530,7 @@ class TestWorkflowAppService:
         assert result_cross_tenant["total"] == 0

     def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
         service = WorkflowAppService()
@@ -1543,7 +1543,7 @@ class TestWorkflowAppService:
         )

     def test_get_paginate_workflow_app_logs_filters_by_account(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
         service = WorkflowAppService()
@@ -1558,7 +1558,9 @@ class TestWorkflowAppService:
         assert result["total"] >= 0
         assert isinstance(result["data"], list)

-    def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies):
+    def 
test_get_paginate_workflow_archive_logs( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 0c281c8c33..82fe391b08 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker -from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from graphon.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable @@ -45,7 +45,9 @@ class TestWorkflowDraftVariableService: # WorkflowDraftVariableService doesn't have external dependencies that need mocking return {} - def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None): + def _create_test_app( + self, db_session_with_containers: Session, mock_external_service_dependencies, fake: Faker | None = None + ): """ Helper method to create a test app with realistic data for testing. @@ -80,7 +82,7 @@ class TestWorkflowDraftVariableService: db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, fake: Faker | None = None): """ Helper method to create a test workflow associated with an app. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index b5ce8a53de..9ba1fda08b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -12,7 +12,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from models import Account, App, Workflow +from models import Account, AccountStatus, App, TenantStatus, Workflow from models.model import AppMode from models.workflow import WorkflowType from services.workflow_service import WorkflowService @@ -33,7 +33,7 @@ class TestWorkflowService: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers: Session, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account with realistic data. 
@@ -49,7 +49,7 @@ class TestWorkflowService:
             email=fake.email(),
             name=fake.name(),
             avatar=fake.url(),
-            status="active",
+            status=AccountStatus.ACTIVE,
             interface_language="en-US",  # Set interface language for Site creation
         )
         account.created_at = fake.date_time_this_year()
@@ -62,7 +62,7 @@ class TestWorkflowService:
         tenant = Tenant(
             name=f"Test Tenant {fake.company()}",
             plan="basic",
-            status="normal",
+            status=TenantStatus.NORMAL,
         )
         tenant.id = account.current_tenant_id
         tenant.created_at = fake.date_time_this_year()
@@ -77,7 +77,7 @@ class TestWorkflowService:

         return account

-    def _create_test_app(self, db_session_with_containers: Session, fake=None):
+    def _create_test_app(self, db_session_with_containers: Session, fake: Faker | None = None):
         """
         Helper method to create a test app with realistic data.

@@ -109,7 +109,7 @@ class TestWorkflowService:
         db_session_with_containers.commit()
         return app

-    def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None):
+    def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake: Faker | None = None):
         """
         Helper method to create a test workflow associated with an app.

diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py
index d3e765055a..af83adaae0 100644
--- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py
+++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py
@@ -1,3 +1,5 @@
+import inspect
+import json
 from unittest.mock import patch

 import pytest
@@ -6,6 +8,8 @@ from pydantic import TypeAdapter, ValidationError
 from sqlalchemy.orm import Session

 from core.tools.entities.tool_entities import ApiProviderSchemaType
+from core.tools.errors import ApiToolProviderNotFoundError
+from core.tools.tool_label_manager import ToolLabelManager
 from models import Account, Tenant
 from models.tools import ApiToolProvider
 from services.tools.api_tools_manage_service import ApiToolManageService
@@ -590,30 +594,204 @@ class TestApiToolManageService:
         with pytest.raises(ValueError, match="you have not added provider"):
             ApiToolManageService.delete_api_tool_provider(account.id, tenant.id, "nonexistent")

-    def test_update_api_tool_provider_not_found(
+    def test_update_api_tool_provider_success(
         self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
     ):
-        """Test update raises ValueError when original provider not found."""
         fake = Faker()
+
+        # Stub the (encrypter, cache) pair so cache.delete() works in the update flow
+        mock_encrypter = mock_external_service_dependencies["encrypter"]
+        from unittest.mock import MagicMock
+
+        mock_cache = MagicMock()
+        mock_cache.delete.return_value = None
+        mock_encrypter.return_value = (mock_encrypter, mock_cache)
+
+        # Create test account and tenant
         account, tenant = self._create_test_account_and_tenant(
             db_session_with_containers, mock_external_service_dependencies
         )

-        with pytest.raises(ValueError, match="does not exists"):
-            ApiToolManageService.update_api_tool_provider(
+        # Original provider name
+        original_name = "original-provider"
+
+        # Create original provider
+        _ = ApiToolManageService.create_api_tool_provider(
+            user_id=account.id,
+            tenant_id=tenant.id,
+            provider_name=original_name,
+            icon={"type": "emoji", "value": "🔧"},
+            credentials={"auth_type": "none"},
+            schema_type=ApiProviderSchemaType.OPENAPI,
+            schema=self._create_test_openapi_schema(),
+            privacy_policy="",
+            custom_disclaimer="",
+            labels=["old-label"],
+        )
+
+        # New provider name and labels for the update
+        new_name = "updated-provider"
+        new_labels = ["new-label-1", "new-label-2"]
+
+        # Reset mock history so assertions focus on update path only
+        mock_external_service_dependencies["encrypter"].reset_mock()
+        mock_external_service_dependencies["provider_controller"].from_db.reset_mock()
+        mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock()
+
+        # Act: Update the provider with new values
+        result = ApiToolManageService.update_api_tool_provider(
+            user_id=account.id,
+            tenant_id=tenant.id,
+            # new provider name - changed 1
+            provider_name=new_name,
+            original_provider=original_name,
+            # new icon - changed 2
+            icon={"type": "emoji", "value": "🚀"},
+            credentials={"auth_type": "none"},
+            _schema_type=ApiProviderSchemaType.OPENAPI,
+            schema=self._create_test_openapi_schema(),
+            # new privacy policy - changed 3
+            privacy_policy="https://new-policy.com",
+            # new custom disclaimer - changed 4
+            custom_disclaimer="New disclaimer",
+            # new labels - changed 5 (not verified here; label persistence is not this layer's responsibility)
+            labels=new_labels,
+        )
+
+        # Assert: Verify the result
+        assert result == {"result": "success"}
+
+        # Get the updated provider from the database
+        updated_provider: ApiToolProvider | None = (
+            db_session_with_containers.query(ApiToolProvider)
+            .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == new_name)
+            .first()
+        )
+
+        # Verify the provider was updated successfully
+        assert updated_provider is not None
+
+        # Refresh so assertions read the latest persisted state
+        db_session_with_containers.refresh(updated_provider)
+        # Verify all the updated fields
+        # - changed 1
+        assert updated_provider.name == new_name
+        # - changed 2
+        icon_data = json.loads(updated_provider.icon)
+        assert icon_data["type"] == "emoji"
+        assert icon_data["value"] == "🚀"
+        # - changed 3
+        assert updated_provider.privacy_policy == "https://new-policy.com"
+        # - changed 4
+        assert updated_provider.custom_disclaimer == "New disclaimer"
+
+        # Verify old provider name no longer exists after rename
+        original_provider: ApiToolProvider | None = (
+            db_session_with_containers.query(ApiToolProvider)
+            .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == original_name)
+            .first()
+        )
+        assert original_provider is None
+
+        # Verify update flow calls critical collaborators
+        mock_external_service_dependencies["provider_controller"].from_db.assert_called_once()
+        mock_external_service_dependencies["encrypter"].assert_called_once()
+        mock_cache.delete.assert_called_once()
+
+        # Verify session propagation in the labels update logic:
+        # the refactor passes the session down to the label manager to keep the update atomic,
+        # and the assertions below guard that contract.
+        sig = inspect.signature(ToolLabelManager.update_tool_labels)
+        args, kwargs = mock_external_service_dependencies["tool_label_manager"].update_tool_labels.call_args
+        bound_args = sig.bind(*args, **kwargs)
+        passed_session = bound_args.arguments.get("session")
+        # Ensure a session was passed, and that it is a Session instance
+        assert passed_session is not None, (
+            "Atomicity failure: no session was passed to ToolLabelManager in update_api_tool_provider"
+        )
+        assert isinstance(passed_session, Session), f"Expected Session object, got {type(passed_session)}"
+
+    def test_update_api_tool_provider_not_found(
+        self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
+        """
+        Test update raises ApiToolProviderNotFoundError when the original provider is not found.
+
+        This test verifies:
+        - Proper error when trying to update a non-existing original provider
+        - No accidental upsert/new provider creation
+        - No external dependency invocation on early failure path
+        """
+        # Arrange: Create test account and tenant
+        account, tenant = self._create_test_account_and_tenant(
+            db_session_with_containers, mock_external_service_dependencies
+        )
+
+        # Keep an existing provider in DB to ensure unrelated data remains unchanged
+        existing_provider_name = "existing-provider"
+        _ = ApiToolManageService.create_api_tool_provider(
+            user_id=account.id,
+            tenant_id=tenant.id,
+            provider_name=existing_provider_name,
+            icon={"type": "emoji", "value": "🔧"},
+            credentials={"auth_type": "none"},
+            schema_type=ApiProviderSchemaType.OPENAPI,
+            schema=self._create_test_openapi_schema(),
+            privacy_policy="https://existing-policy.com",
+            custom_disclaimer="Existing disclaimer",
+            labels=["existing-label"],
+        )
+
+        # Reset mock history so assertions focus on update failure path only
+        mock_external_service_dependencies["tool_label_manager"].update_tool_labels.reset_mock()
+        mock_external_service_dependencies["encrypter"].reset_mock()
+        mock_external_service_dependencies["provider_controller"].from_db.reset_mock()
+
+        # Act & Assert: Verify update fails with a clear error message
+        target_new_name = "new-provider-name"
+        missing_original_name = "missing-original-provider"
+        with pytest.raises(ApiToolProviderNotFoundError) as exc_info:
+            _ = ApiToolManageService.update_api_tool_provider(
                 user_id=account.id,
                 tenant_id=tenant.id,
-                provider_name="new-name",
-                original_provider="nonexistent",
-                icon={},
+                provider_name=target_new_name,
+                original_provider=missing_original_name,
+                icon={"type": "emoji", "value": "🚀"},
                 credentials={"auth_type": "none"},
                 _schema_type=ApiProviderSchemaType.OPENAPI,
                 schema=self._create_test_openapi_schema(),
-                privacy_policy=None,
-                custom_disclaimer="",
-                labels=[],
+                privacy_policy="https://new-policy.com",
+                custom_disclaimer="New disclaimer",
+                labels=["new-label"],
             )
+        error = exc_info.value
+        assert error.provider_name == missing_original_name
+        assert error.tenant_id == tenant.id
+        assert error.error_code == "api_tool_provider_not_found"
+
+        # Assert: Existing provider should remain unchanged
+        existing_provider: ApiToolProvider | None = (
+            db_session_with_containers.query(ApiToolProvider)
+            .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == existing_provider_name)
+            .first()
+        )
+        assert existing_provider is not None
+        assert existing_provider.name == existing_provider_name
+
+        # Assert: No new provider should be created
+        unexpected_new_provider: ApiToolProvider | None = (
+            db_session_with_containers.query(ApiToolProvider)
+            .filter(ApiToolProvider.tenant_id == tenant.id, 
ApiToolProvider.name == target_new_name) + .first() + ) + assert unexpected_new_provider is None + + # Assert: Early failure should skip all downstream external interactions + mock_external_service_dependencies["tool_label_manager"].update_tool_labels.assert_not_called() + mock_external_service_dependencies["encrypter"].assert_not_called() + mock_external_service_dependencies["provider_controller"].from_db.assert_not_called() + def test_update_api_tool_provider_missing_auth_type( self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index ce2fd2eeb1..ce5c2bd162 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -5,9 +5,6 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import PromptMessageRole -from graphon.variables.input_entities import VariableEntity, VariableEntityType from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -21,6 +18,9 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py index 29e1e240b4..afc4908c15 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py @@ -100,7 +100,7 @@ class TestWorkflowDeletion: session.flush() return provider - def test_delete_workflow_success(self, db_session_with_containers): + def test_delete_workflow_success(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -118,7 +118,7 @@ class TestWorkflowDeletion: db_session_with_containers.expire_all() assert db_session_with_containers.get(Workflow, workflow_id) is None - def test_delete_draft_workflow_raises_error(self, db_session_with_containers): + def test_delete_draft_workflow_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -130,7 +130,7 @@ class TestWorkflowDeletion: with pytest.raises(DraftWorkflowDeletionError): service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) - def 
test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers): + def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -144,7 +144,7 @@ class TestWorkflowDeletion: with pytest.raises(WorkflowInUseError, match="currently in use by app"): service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) - def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers): + def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 7c43bf676b..32b76c3469 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 -from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel @@ -64,7 +64,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: db_session_with_containers.commit() return execution - def test_get_node_last_execution_found(self, db_session_with_containers): + def test_get_node_last_execution_found(self, db_session_with_containers: Session): """Test getting the last execution for a node when it exists.""" # Arrange tenant_id = str(uuid4()) @@ -110,7 +110,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result.id == expected.id assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - def test_get_node_last_execution_not_found(self, db_session_with_containers): + def test_get_node_last_execution_not_found(self, db_session_with_containers: Session): """Test getting the last execution for a node when it doesn't exist.""" # Arrange tenant_id = str(uuid4()) @@ -129,7 +129,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result is None - def test_get_executions_by_workflow_run_empty(self, db_session_with_containers): + def test_get_executions_by_workflow_run_empty(self, db_session_with_containers: Session): """Test getting executions for a workflow run when none exist.""" # Arrange tenant_id = str(uuid4()) @@ -147,7 +147,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == [] - def test_get_execution_by_id_found(self, db_session_with_containers): + def test_get_execution_by_id_found(self, db_session_with_containers: Session): """Test getting execution by ID when it exists.""" # Arrange execution = self._create_execution( @@ -170,7 +170,7 @@ class 
TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result is not None assert result.id == execution.id - def test_get_execution_by_id_not_found(self, db_session_with_containers): + def test_get_execution_by_id_not_found(self, db_session_with_containers: Session): """Test getting execution by ID when it doesn't exist.""" # Arrange repository = self._create_repository(db_session_with_containers) @@ -182,7 +182,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result is None - def test_delete_expired_executions(self, db_session_with_containers): + def test_delete_expired_executions(self, db_session_with_containers: Session): """Test deleting expired executions.""" # Arrange tenant_id = str(uuid4()) @@ -248,7 +248,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert old_execution_2_id not in remaining_ids assert kept_execution_id in remaining_ids - def test_delete_executions_by_app(self, db_session_with_containers): + def test_delete_executions_by_app(self, db_session_with_containers: Session): """Test deleting executions by app.""" # Arrange tenant_id = str(uuid4()) @@ -313,7 +313,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert deleted_2_id not in remaining_ids assert kept_id in remaining_ids - def test_get_expired_executions_batch(self, db_session_with_containers): + def test_get_expired_executions_batch(self, db_session_with_containers: Session): """Test getting expired executions batch for backup.""" # Arrange tenant_id = str(uuid4()) @@ -370,7 +370,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert old_execution_1.id in result_ids assert old_execution_2.id in result_ids - def test_delete_executions_by_ids(self, db_session_with_containers): + def test_delete_executions_by_ids(self, db_session_with_containers: Session): """Test deleting executions by IDs.""" # Arrange tenant_id = str(uuid4()) @@ -424,7 +424,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: ).all() assert remaining == [] - def test_delete_executions_by_ids_empty_list(self, db_session_with_containers): + def test_delete_executions_by_ids_empty_list(self, db_session_with_containers: Session): """Test deleting executions with empty ID list.""" # Arrange repository = self._create_repository(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 4b04c1accb..fcc15aad42 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -530,22 +531,18 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Verify logs exist before processing - existing_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + existing_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(existing_logs) == 2 # Act: Execute the task add_document_to_index_task(document.id) # 
Assert: Verify auto disable logs were deleted - remaining_logs = ( - db_session_with_containers.query(DatasetAutoDisableLog) - .where(DatasetAutoDisableLog.document_id == document.id) - .all() - ) + remaining_logs = db_session_with_containers.scalars( + select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id) + ).all() assert len(remaining_logs) == 0 # Verify index processing occurred normally diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index 6cbbe43137..e29ca7ebab 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,6 +11,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import func, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType @@ -267,11 +268,13 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_with_image_files( @@ -319,7 +322,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Verify that the task completed successfully by checking the log output @@ -360,14 +365,14 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify database cleanup db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( @@ -410,7 +415,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.scalar( + select(Document).where(Document.id == document_id).limit(1) + ) assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( @@ -453,11 +460,13 @@ class TestBatchCleanDocumentTask: 
db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None def test_batch_clean_document_task_multiple_documents( @@ -510,12 +519,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( @@ -564,7 +577,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None except Exception as e: @@ -574,7 +589,9 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -645,12 +662,16 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) + ) assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( @@ -699,8 +720,16 @@ class TestBatchCleanDocumentTask: db_session_with_containers.commit() # Verify initial state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert ( + 
db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) + ) + == 3 + ) + assert ( + db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == upload_file.id).limit(1)) + is not None + ) # Store original IDs for verification document_id = document.id @@ -720,13 +749,20 @@ class TestBatchCleanDocumentTask: # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1) + ) assert deleted_segment is None # Check that upload file is deleted - deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) assert deleted_file is None # Verify final database state - assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None + assert ( + db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document_id) + ) + == 0 + ) + assert db_session_with_containers.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index f9ae33b32f..05827112d4 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -37,13 +38,13 @@ class TestBatchCreateSegmentToIndexTask: from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -292,12 +293,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that segments were created - segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + segments = db_session_with_containers.scalars( + 
select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(segments) == 3 # Verify segment content and metadata @@ -367,11 +365,11 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created (since dataset doesn't exist) - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify no documents were modified - documents = db_session_with_containers.query(Document).all() + documents = db_session_with_containers.scalars(select(Document)).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( @@ -415,12 +413,14 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) db_session_with_containers.refresh(dataset) - segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments_for_dataset = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( @@ -516,7 +516,9 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id) + ).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( @@ -560,7 +562,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify no segments were created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -611,7 +613,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - segments = db_session_with_containers.query(DocumentSegment).all() + segments = db_session_with_containers.scalars(select(DocumentSegment)).all() assert len(segments) == 0 # Verify document remains unchanged @@ -682,12 +684,9 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions - all_segments = ( - db_session_with_containers.query(DocumentSegment) - .filter_by(document_id=document.id) - .order_by(DocumentSegment.position) - .all() - ) + all_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == document.id).order_by(DocumentSegment.position) + ).all() assert len(all_segments) == 6 # 3 existing + 3 new # Verify position ordering diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index 1dd37fbc92..32bc2fc0bd 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,6 
+16,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -52,18 +53,18 @@ class TestCleanDatasetTask: from extensions.ext_redis import redis_client # Clear all test data using the provided session fixture - db_session_with_containers.query(DatasetMetadataBinding).delete() - db_session_with_containers.query(DatasetMetadata).delete() - db_session_with_containers.query(AppDatasetJoin).delete() - db_session_with_containers.query(DatasetQuery).delete() - db_session_with_containers.query(DatasetProcessRule).delete() - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DatasetMetadataBinding)) + db_session_with_containers.execute(delete(DatasetMetadata)) + db_session_with_containers.execute(delete(AppDatasetJoin)) + db_session_with_containers.execute(delete(DatasetQuery)) + db_session_with_containers.execute(delete(DatasetProcessRule)) + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -302,28 +303,40 @@ class TestCleanDatasetTask: # Verify results # Check that dataset-related data was cleaned up - documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + documents = db_session_with_containers.scalars(select(Document).where(Document.dataset_id == dataset.id)).all() assert len(documents) == 0 - segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(segments) == 0 # Check that metadata and bindings were cleaned up - metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(metadata) == 0 - bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(bindings) == 0 # Check that process rules and queries were cleaned up - process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all() + process_rules = db_session_with_containers.scalars( + select(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset.id) + ).all() assert len(process_rules) == 0 - queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all() + queries = 
db_session_with_containers.scalars( + select(DatasetQuery).where(DatasetQuery.dataset_id == dataset.id) + ).all() assert len(queries) == 0 # Check that app dataset joins were cleaned up - app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all() + app_joins = db_session_with_containers.scalars( + select(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset.id) + ).all() assert len(app_joins) == 0 # Verify index processor was called @@ -414,24 +427,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that metadata and bindings were cleaned up - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify index processor was called @@ -485,12 +506,14 @@ class TestCleanDatasetTask: # Check that all data was cleaned up - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 - remaining_segments = ( - db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() - ) + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Recreate data for next test case @@ -538,11 +561,15 @@ class TestCleanDatasetTask: # Verify results - even with vector cleanup failure, documents and segments should be deleted # Check that documents were still deleted despite vector cleanup failure - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that segments were still deleted despite vector cleanup failure - remaining_segments = 
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Verify that index processor was called and failed @@ -622,18 +649,22 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all image files were deleted from database image_file_ids = [f.id for f in image_files] - remaining_image_files = ( - db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all() - ) + remaining_image_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(image_file_ids)) + ).all() assert len(remaining_image_files) == 0 # Verify that storage.delete was called for each image file @@ -738,24 +769,32 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ).all() assert len(remaining_files) == 0 # Check that all metadata and bindings were deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 - remaining_bindings = ( - db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all() - ) + remaining_bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset.id) + ).all() assert len(remaining_bindings) == 0 # Verify performance expectations @@ -826,7 +865,9 @@ class TestCleanDatasetTask: # Check that upload file was still deleted from database despite storage failure # Note: When storage operations fail, the upload file may not be deleted # This demonstrates that the cleanup process continues even with storage errors - remaining_files = 
db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file.id) + ).all() # The upload file should still be deleted from the database even if storage cleanup fails # However, this depends on the specific implementation of clean_dataset_task if len(remaining_files) > 0: @@ -976,19 +1017,27 @@ class TestCleanDatasetTask: # Verify results # Check that all documents were deleted - remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all() + remaining_documents = db_session_with_containers.scalars( + select(Document).where(Document.dataset_id == dataset.id) + ).all() assert len(remaining_documents) == 0 # Check that all segments were deleted - remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + remaining_segments = db_session_with_containers.scalars( + select(DocumentSegment).where(DocumentSegment.dataset_id == dataset.id) + ).all() assert len(remaining_segments) == 0 # Check that all upload files were deleted - remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all() + remaining_files = db_session_with_containers.scalars( + select(UploadFile).where(UploadFile.id == upload_file_id) + ).all() assert len(remaining_files) == 0 # Check that all metadata was deleted - remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all() + remaining_metadata = db_session_with_containers.scalars( + select(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset.id) + ).all() assert len(remaining_metadata) == 0 # Verify that storage.delete was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 926c839c8b..1c8d5969e0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -11,6 +11,8 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy import ColumnElement, func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -20,6 +22,14 @@ from tasks.clean_notion_document_task import clean_notion_document_task from tests.test_containers_integration_tests.helpers import generate_valid_password +def _count_documents(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(Document).where(condition)) or 0 + + +def _count_segments(session: Session, condition: ColumnElement[bool]) -> int: + return session.scalar(select(func.count()).select_from(DocumentSegment).where(condition)) or 0 + + class TestCleanNotionDocumentTask: """Integration tests for clean_notion_document_task using testcontainers.""" @@ -61,7 +71,7 @@ class TestCleanNotionDocumentTask: yield mock_factory def test_clean_notion_document_task_success( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test successful cleanup of Notion documents with proper database 
operations. @@ -145,24 +155,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 3 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(document_ids)) == 3 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 6 # Execute cleanup task clean_notion_document_task(document_ids, dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(document_ids)) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(document_ids)) == 0 # Verify index processor was called mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value @@ -176,7 +176,7 @@ class TestCleanNotionDocumentTask: # 5. The task completes without errors def test_clean_notion_document_task_dataset_not_found( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task behavior when dataset is not found. @@ -196,7 +196,7 @@ class TestCleanNotionDocumentTask: mock_index_processor_factory.return_value.init_index_processor.assert_not_called() def test_clean_notion_document_task_empty_document_list( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task behavior with empty document list. @@ -240,7 +240,7 @@ class TestCleanNotionDocumentTask: assert args[1] == [] def test_clean_notion_document_task_with_different_index_types( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with different dataset index types. @@ -322,18 +322,13 @@ class TestCleanNotionDocumentTask: # The task properly handles various index types and document configurations. # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == document.id) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Reset mock for next iteration mock_index_processor_factory.reset_mock() def test_clean_notion_document_task_with_segments_no_index_node_ids( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with segments that have no index_node_ids. 
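The `_count_documents` / `_count_segments` helpers introduced above centralize the 2.0-style count query. A minimal self-contained sketch of the same idiom, using a hypothetical `Item` model in place of Dify's `Document`/`DocumentSegment`:

```python
from sqlalchemy import ColumnElement, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):  # stand-in for Document/DocumentSegment in the real helpers
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    category: Mapped[str]


def count_items(session: Session, condition: ColumnElement[bool]) -> int:
    # Same shape as _count_documents/_count_segments: an explicit
    # select_from() plus a caller-supplied boolean column expression.
    return session.scalar(select(func.count()).select_from(Item).where(condition)) or 0


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Item(category="a"), Item(category="a"), Item(category="b")])
    session.commit()
    assert count_items(session, Item.category == "a") == 2
```

The trailing `or 0` exists because `Session.scalar()` is typed as optional, even though `COUNT(*)` always returns exactly one row.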
@@ -410,16 +405,13 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies that segments without index_node_ids # are properly deleted from the database. def test_clean_notion_document_task_partial_document_cleanup( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with partial document cleanup scenario. @@ -499,11 +491,8 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 5 - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 10 - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == 5 + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 10 # Clean up only first 3 documents documents_to_clean = [doc.id for doc in documents[:3]] @@ -513,28 +502,18 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(documents_to_clean, dataset.id) # Verify only specified documents' segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(documents_to_clean)) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(documents_to_clean)) == 0 # Verify remaining documents and segments are intact remaining_docs = [doc.id for doc in documents[3:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() - == 4 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 4 # Note: This test successfully verifies partial document cleanup operations. # The database operations work correctly, isolating only the specified documents. def test_clean_notion_document_task_with_mixed_segment_statuses( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with segments in different statuses. 
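Most hunks in this patch perform the same mechanical swap: `query(Model).filter_by(...).first()` becomes `session.scalar(select(...).limit(1))`. A hedged before/after sketch of that pattern, reusing the repo's `DocumentSegment` model:

```python
from sqlalchemy import select
from sqlalchemy.orm import Session

from models.dataset import DocumentSegment  # repo model used by these tests


def get_segment_legacy(session: Session, segment_id: str) -> DocumentSegment | None:
    # 1.x Query API: first() returns the first row or None.
    return session.query(DocumentSegment).filter_by(id=segment_id).first()


def get_segment(session: Session, segment_id: str) -> DocumentSegment | None:
    # 2.0 style: scalar() returns the first column of the first row or None;
    # limit(1) keeps the emitted SQL equivalent to the old first().
    return session.scalar(
        select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)
    )
```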
@@ -612,31 +591,36 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all segments exist before cleanup - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 4 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 4 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies database operations. # IndexProcessor verification would require more sophisticated mocking. - def test_clean_notion_document_task_database_transaction_rollback( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + def test_clean_notion_document_task_continues_when_index_processor_fails( + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ - Test cleanup task behavior when database operations fail. + Index processor failure (e.g. transient billing API error propagated via + ``FeatureService`` when ``Vector(dataset)`` lazily resolves the embedding + model) must NOT abort the cleanup task. The Document rows have already + been hard-deleted in the first session block before vector cleanup runs, + so any uncaught exception escaping the task would strand + ``DocumentSegment`` rows in PG with no parent ``Document``. - This test verifies that the task properly handles database errors - and maintains data consistency. + Contract: the task swallows the index_processor exception, logs it, and + proceeds to delete the segments — leaving PG consistent. (Vector orphans, + if any, can be reaped later by an offline scanner.) + + Regression guard for the production incident where ``clean_document_task`` + / ``clean_notion_document_task`` failed with + ``ValueError("Unable to retrieve billing information...")`` and left + tens of thousands of orphan segments per affected tenant. """ fake = Faker() @@ -699,20 +683,31 @@ class TestCleanNotionDocumentTask: db_session_with_containers.add(segment) db_session_with_containers.commit() - # Mock index processor to raise an exception + # Simulate the production failure mode: index_processor.clean() raises a + # ValueError mirroring ``BillingService._send_request`` returning non-200. mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - mock_index_processor.clean.side_effect = Exception("Index processor error") + mock_index_processor.clean.side_effect = ValueError( + "Unable to retrieve billing information. Please try again later or contact support." + ) - # Execute cleanup task - current implementation propagates the exception - with pytest.raises(Exception, match="Index processor error"): - clean_notion_document_task([document.id], dataset.id) + # Execute cleanup task — must NOT raise even though clean() raises. + # Before the safety-net wrapper this would have re-raised the ValueError, + # aborting the task and leaving DocumentSegment stranded in PG. + clean_notion_document_task([document.id], dataset.id) - # Note: This test demonstrates the task's error handling capability. 
- # Even with external service errors, the database operations complete successfully. - # In a production environment, proper error handling would determine transaction rollback behavior. + # Vector cleanup was attempted exactly once. + mock_index_processor.clean.assert_called_once() + + # The crucial assertion: despite the index processor failure, the + # final session block (line 51-52, ``DELETE FROM document_segments``) + # still ran and committed. This is what the wrapper buys us — without + # it the production incident left tens of thousands of orphan segments + # per affected tenant. Aligns with the assertion shape used by the + # happy-path test (``test_clean_notion_document_task_success``). + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 def test_clean_notion_document_task_with_large_number_of_documents( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with a large number of documents and segments. @@ -794,12 +789,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == num_documents assert ( - db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() - == num_documents - ) - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == num_documents * num_segments_per_doc ) @@ -808,16 +800,13 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies bulk document cleanup operations. # The database efficiently handles large-scale deletions. def test_clean_notion_document_task_with_documents_from_different_tenants( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents from different tenants. 
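The regression-guard test earlier in this file pins down a contract rather than an implementation detail: vector cleanup may raise, but the segment deletion must still commit. A minimal sketch of that contract with a generic task shape; `cleanup_task` and its parameters are illustrative, not the real `clean_notion_document_task` signature:

```python
import logging
from unittest.mock import MagicMock

logger = logging.getLogger(__name__)


def cleanup_task(index_processor, delete_segments) -> None:
    try:
        # Vector cleanup may fail transiently (e.g. a billing lookup error
        # raised while resolving the embedding model).
        index_processor.clean()
    except Exception:
        # Swallow and log: the DB cleanup below must still run, or the
        # already-deleted Documents leave orphan DocumentSegment rows behind.
        logger.exception("index cleanup failed; continuing with DB cleanup")
    delete_segments()


processor = MagicMock()
processor.clean.side_effect = ValueError("Unable to retrieve billing information.")
deleted = []
cleanup_task(processor, lambda: deleted.append(True))

processor.clean.assert_called_once()  # vector cleanup attempted exactly once
assert deleted == [True]  # segment deletion still ran despite the failure
```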
@@ -906,8 +895,8 @@ class TestCleanNotionDocumentTask: # Verify all data exists before cleanup # Note: There may be documents from previous tests, so we check for at least 3 - assert db_session_with_containers.query(Document).count() >= 3 - assert db_session_with_containers.query(DocumentSegment).count() >= 9 + assert db_session_with_containers.scalar(select(func.count()).select_from(Document)) >= 3 + assert db_session_with_containers.scalar(select(func.count()).select_from(DocumentSegment)) >= 9 # Clean up documents from only the first dataset target_dataset = datasets[0] @@ -918,28 +907,18 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([target_document.id], target_dataset.id) # Verify only documents' segments from target dataset are deleted - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id == target_document.id) - .count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == target_document.id) == 0 # Verify documents from other datasets remain intact remaining_docs = [doc.id for doc in all_documents[1:]] - assert db_session_with_containers.query(Document).filter(Document.id.in_(remaining_docs)).count() == 2 - assert ( - db_session_with_containers.query(DocumentSegment) - .filter(DocumentSegment.document_id.in_(remaining_docs)) - .count() - == 6 - ) + assert _count_documents(db_session_with_containers, Document.id.in_(remaining_docs)) == 2 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id.in_(remaining_docs)) == 6 # Note: This test successfully verifies multi-tenant isolation. # Only documents from the target dataset are affected, maintaining tenant separation. def test_clean_notion_document_task_with_documents_in_different_states( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents in different indexing states. @@ -1028,11 +1007,9 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify all data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == len( - document_statuses - ) + assert _count_documents(db_session_with_containers, Document.dataset_id == dataset.id) == len(document_statuses) assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() + _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == len(document_statuses) * 2 ) @@ -1041,16 +1018,13 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(all_document_ids, dataset.id) # Verify all segments are deleted regardless of status - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.dataset_id == dataset.id) == 0 # Note: This test successfully verifies cleanup of documents in various states. # All documents are deleted regardless of their indexing status. 
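The autouse cleanup fixtures in this patch replace `query(Model).delete()` with `session.execute(delete(Model))`, issuing deletes child-first so foreign keys stay satisfied. A trimmed sketch of that fixture shape, assuming the same `db_session_with_containers` fixture the tests above rely on:

```python
import pytest
from sqlalchemy import delete
from sqlalchemy.orm import Session

from models.dataset import Dataset, Document, DocumentSegment  # repo models


@pytest.fixture(autouse=True)
def cleanup_database(db_session_with_containers: Session):
    # Child tables first, then parents, so FK constraints stay satisfied;
    # execute(delete(...)) is the 2.0 replacement for query(Model).delete().
    db_session_with_containers.execute(delete(DocumentSegment))
    db_session_with_containers.execute(delete(Document))
    db_session_with_containers.execute(delete(Dataset))
    db_session_with_containers.commit()
```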
def test_clean_notion_document_task_with_documents_having_metadata( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents that have rich metadata. @@ -1142,20 +1116,14 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Verify data exists before cleanup - assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 1 - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 3 - ) + assert _count_documents(db_session_with_containers, Document.id == document.id) == 1 + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 3 # Execute cleanup task clean_notion_document_task([document.id], dataset.id) # Verify segments are deleted - assert ( - db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() - == 0 - ) + assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 # Note: This test successfully verifies cleanup of documents with rich metadata. # The task properly handles complex document structures and metadata fields. diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 9f8e37fc9e..80289c448a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -11,6 +11,8 @@ from uuid import uuid4 import pytest from faker import Faker +from sqlalchemy import delete +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client @@ -24,16 +26,16 @@ class TestCreateSegmentToIndexTask: """Integration tests for create_segment_to_index_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database and Redis before each test to ensure isolation.""" # Clear all test data using fixture session - db_session_with_containers.query(DocumentSegment).delete() - db_session_with_containers.query(Document).delete() - db_session_with_containers.query(Dataset).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(DocumentSegment)) + db_session_with_containers.execute(delete(Document)) + db_session_with_containers.execute(delete(Dataset)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -54,7 +56,7 @@ class TestCreateSegmentToIndexTask: "index_processor": mock_processor, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. 
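For multi-row reads, the corresponding swap is `query(...).all()` to `session.scalars(select(...)).all()`, including the `in_()` membership filters used throughout the clean-dataset hunks. A small sketch with the repo's `DocumentSegment` model:

```python
from collections.abc import Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session

from models.dataset import DocumentSegment  # repo model used by these tests


def remaining_segments(session: Session, segment_ids: list[str]) -> Sequence[DocumentSegment]:
    # scalars(...).all() yields ORM instances, replacing the legacy
    # query(Model).filter(...).all(); in_() expresses membership filters.
    return session.scalars(
        select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
    ).all()
```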
@@ -101,7 +103,7 @@ class TestCreateSegmentToIndexTask: return account, tenant - def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id): + def _create_test_dataset_and_document(self, db_session_with_containers: Session, tenant_id, account_id): """ Helper method to create a test dataset and document for testing. @@ -150,7 +152,13 @@ class TestCreateSegmentToIndexTask: return dataset, document def _create_test_segment( - self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING + self, + db_session_with_containers: Session, + dataset_id, + document_id, + tenant_id, + account_id, + status=SegmentStatus.WAITING, ): """ Helper method to create a test document segment for testing. @@ -188,7 +196,9 @@ class TestCreateSegmentToIndexTask: return segment - def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful creation of segment to index. @@ -224,7 +234,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 0 def test_create_segment_to_index_segment_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent segment ID. @@ -245,7 +255,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_invalid_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with invalid status. @@ -276,7 +286,9 @@ class TestCreateSegmentToIndexTask: # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() - def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_no_dataset( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test handling of segment without associated dataset. @@ -329,7 +341,9 @@ class TestCreateSegmentToIndexTask: # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() - def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_no_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test handling of segment without associated document. @@ -366,7 +380,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with disabled document. 
@@ -402,7 +416,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_archived( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with archived document. @@ -438,7 +452,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_indexing_incomplete( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with document that has incomplete indexing. @@ -474,7 +488,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_processor_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of index processor exceptions. @@ -510,7 +524,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 0 def test_create_segment_to_index_with_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with custom keywords. @@ -542,7 +556,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_different_doc_forms( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with different document forms. @@ -585,7 +599,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) def test_create_segment_to_index_performance_timing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing performance and timing. @@ -616,7 +630,7 @@ class TestCreateSegmentToIndexTask: assert segment.status == SegmentStatus.COMPLETED def test_create_segment_to_index_concurrent_execution( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test concurrent execution of segment indexing tasks. @@ -653,7 +667,7 @@ class TestCreateSegmentToIndexTask: assert mock_external_service_dependencies["index_processor_factory"].call_count == 3 def test_create_segment_to_index_large_content( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with large content. @@ -702,7 +716,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_redis_failure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing when Redis operations fail. 
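Several of these tests drive the Redis indexing cache directly: they seed a `processing` marker with a TTL, run the task, then assert the marker was cleared (or, in the Redis-failure test, left behind). A sketch of that lifecycle; the exact key format is an assumption here, since the real tests build it from the segment id:

```python
# Illustrative key format; the real tests derive the key from the segment id.
from extensions.ext_redis import redis_client

cache_key = "segment_<segment-id>_indexing"

# Arrange: seed an in-flight marker with a TTL, as the tests above do.
redis_client.set(cache_key, "processing", ex=300)
assert redis_client.exists(cache_key) == 1

# A successful task run is expected to clear the marker afterwards.
redis_client.delete(cache_key)
assert redis_client.exists(cache_key) == 0
```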
@@ -742,7 +756,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 1 def test_create_segment_to_index_database_transaction_rollback( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with database transaction handling. @@ -774,7 +788,7 @@ class TestCreateSegmentToIndexTask: assert segment.error is not None def test_create_segment_to_index_metadata_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with metadata validation. @@ -816,7 +830,7 @@ class TestCreateSegmentToIndexTask: assert doc is not None def test_create_segment_to_index_status_transition_flow( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test complete status transition flow during indexing. @@ -851,7 +865,7 @@ class TestCreateSegmentToIndexTask: assert segment.indexing_at <= segment.completed_at def test_create_segment_to_index_with_empty_content( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with empty or minimal content. @@ -893,7 +907,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_with_special_characters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with special characters and unicode content. @@ -939,7 +953,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_with_long_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with long keyword lists. @@ -973,7 +987,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with proper tenant isolation. @@ -1016,7 +1030,7 @@ class TestCreateSegmentToIndexTask: assert segment1.tenant_id != segment2.tenant_id def test_create_segment_to_index_with_none_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with None keywords parameter. @@ -1047,7 +1061,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_comprehensive_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Comprehensive integration test covering multiple scenarios. 
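A large share of this patch's churn is the `: Session` annotation added to every `db_session_with_containers` parameter. A minimal sketch of what that buys: with the annotation in place, a type checker can verify the 2.0-only members the migrated tests now call:

```python
from sqlalchemy import select
from sqlalchemy.orm import Session

from models.dataset import Dataset  # repo model, reused for illustration


def test_no_datasets_remain(db_session_with_containers: Session):
    # With the parameter typed as Session, a checker can confirm that
    # scalars()/scalar()/execute() exist and that select() is accepted,
    # instead of treating the fixture as an untyped Any.
    assert db_session_with_containers.scalars(select(Dataset)).all() == []
```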
diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 13ea94348a..a5a3cd10b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -6,6 +6,8 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -173,11 +175,11 @@ class TestDatasetIndexingTaskIntegration: return dataset, documents - def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: + def _query_document(self, db_session_with_containers: Session, document_id: str) -> Document | None: """Return the latest persisted document state.""" - return db_session_with_containers.query(Document).where(Document.id == document_id).first() + return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) - def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: + def _assert_documents_parsing(self, db_session_with_containers: Session, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" db_session_with_containers.expire_all() for document_id in document_ids: @@ -211,7 +213,9 @@ class TestDatasetIndexingTaskIntegration: assert len(opened) >= 2 assert opened_ids <= closed_ids - def test_legacy_document_indexing_task_still_works(self, db_session_with_containers, patched_external_dependencies): + def test_legacy_document_indexing_task_still_works( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Ensure the legacy task entrypoint still updates parsing status.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2) @@ -224,7 +228,9 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_batch_processing_multiple_documents(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_multiple_documents( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Process multiple documents in one batch.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3) @@ -239,7 +245,9 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == len(document_ids) self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_batch_processing_with_limit_check(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_with_limit_check( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Reject batches larger than configured upload limit. This test patches config only to force a deterministic limit branch while keeping SQL writes real. 
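The `_assert_documents_parsing` helper in this file calls `expire_all()` before re-querying, because the task under test commits through its own session and the test session's identity map may be stale. A sketch of that pattern; the `indexing_status` attribute name is assumed from the parsing-status assertions, not shown in this hunk:

```python
from sqlalchemy import select
from sqlalchemy.orm import Session

from models.dataset import Document  # repo model used by these tests


def assert_document_parsing(session: Session, document_id: str) -> None:
    # The Celery task commits via its own session, so expire the test
    # session's identity map before re-reading persisted state.
    session.expire_all()
    document = session.scalar(select(Document).where(Document.id == document_id).limit(1))
    assert document is not None
    # Attribute name assumed; the real helper asserts the parsing status.
    assert document.indexing_status == "parsing"
```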
@@ -262,7 +270,7 @@ class TestDatasetIndexingTaskIntegration:
         self._assert_documents_error_contains(db_session_with_containers, document_ids, "batch upload limit")
 
     def test_batch_processing_sandbox_plan_single_document_only(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Reject multi-document upload under sandbox plan."""
         # Arrange
@@ -279,7 +287,9 @@ class TestDatasetIndexingTaskIntegration:
         patched_external_dependencies["indexing_runner_instance"].run.assert_not_called()
         self._assert_documents_error_contains(db_session_with_containers, document_ids, "does not support batch upload")
 
-    def test_batch_processing_empty_document_list(self, db_session_with_containers, patched_external_dependencies):
+    def test_batch_processing_empty_document_list(
+        self, db_session_with_containers: Session, patched_external_dependencies
+    ):
         """Handle empty list input without failing."""
         # Arrange
         dataset, _ = self._create_test_dataset_and_documents(db_session_with_containers, document_count=0)
@@ -291,7 +301,7 @@ class TestDatasetIndexingTaskIntegration:
         patched_external_dependencies["indexing_runner_instance"].run.assert_called_once_with([])
 
     def test_tenant_queue_dispatches_next_task_after_completion(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Dispatch the next queued task after current tenant task completes.
 
@@ -336,7 +346,7 @@ class TestDatasetIndexingTaskIntegration:
         delete_key_spy.assert_not_called()
 
     def test_tenant_queue_deletes_running_key_when_no_follow_up_tasks(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Delete tenant running flag when queue has no pending tasks.
 
@@ -361,7 +371,7 @@ class TestDatasetIndexingTaskIntegration:
         delete_key_spy.assert_called_once()
 
     def test_validation_failure_sets_error_status_when_vector_space_at_limit(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Set error status when vector space validation fails before runner phase."""
         # Arrange
@@ -381,7 +391,7 @@ class TestDatasetIndexingTaskIntegration:
         self._assert_documents_error_contains(db_session_with_containers, document_ids, "over the limit")
 
     def test_runner_exception_does_not_crash_indexing_task(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Catch generic runner exceptions without crashing the task."""
         # Arrange
@@ -396,7 +406,7 @@ class TestDatasetIndexingTaskIntegration:
         patched_external_dependencies["indexing_runner_instance"].run.assert_called_once()
         self._assert_documents_parsing(db_session_with_containers, document_ids)
 
-    def test_document_paused_error_handling(self, db_session_with_containers, patched_external_dependencies):
+    def test_document_paused_error_handling(self, db_session_with_containers: Session, patched_external_dependencies):
         """Handle DocumentIsPausedError and keep persisted state consistent."""
         # Arrange
         dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2)
@@ -423,7 +433,7 @@ class TestDatasetIndexingTaskIntegration:
         patched_external_dependencies["indexing_runner_instance"].run.assert_not_called()
 
     def test_tenant_queue_error_handling_still_processes_next_task(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Even on current task failure, enqueue the next waiting tenant task.
 
@@ -490,7 +500,7 @@ class TestDatasetIndexingTaskIntegration:
         self._assert_all_opened_sessions_closed(session_close_tracker)
 
     def test_multiple_documents_with_mixed_success_and_failure(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Process only existing documents when request includes missing ids."""
         # Arrange
@@ -507,7 +517,7 @@ class TestDatasetIndexingTaskIntegration:
         self._assert_documents_parsing(db_session_with_containers, existing_ids)
 
     def test_tenant_queue_dispatches_up_to_concurrency_limit(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Dispatch only up to configured concurrency under queued backlog burst.
 
@@ -542,7 +552,7 @@ class TestDatasetIndexingTaskIntegration:
         assert task_dispatch_spy.apply_async.call_count == concurrency_limit
         assert set_waiting_spy.call_count == concurrency_limit
 
-    def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
+    def test_task_queue_fifo_ordering(self, db_session_with_containers: Session, patched_external_dependencies):
         """Keep FIFO ordering when dispatching next queued tasks.
 
         Queue APIs are patched to isolate dispatch side effects while preserving DB assertions.
@@ -575,7 +585,9 @@ class TestDatasetIndexingTaskIntegration:
             call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
             assert call_kwargs.get("document_ids") == expected_task["document_ids"]
 
-    def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
+    def test_billing_disabled_skips_limit_checks(
+        self, db_session_with_containers: Session, patched_external_dependencies
+    ):
         """Skip limit checks when billing feature is disabled."""
         # Arrange
         large_document_ids = [str(uuid.uuid4()) for _ in range(100)]
@@ -594,7 +606,7 @@ class TestDatasetIndexingTaskIntegration:
         assert len(run_args) == 100
         self._assert_documents_parsing(db_session_with_containers, large_document_ids)
 
-    def test_complete_workflow_normal_task(self, db_session_with_containers, patched_external_dependencies):
+    def test_complete_workflow_normal_task(self, db_session_with_containers: Session, patched_external_dependencies):
         """Run end-to-end normal queue workflow with tenant queue cleanup.
 
         Queue APIs are patched to isolate dispatch side effects while preserving DB assertions.
@@ -617,7 +629,7 @@ class TestDatasetIndexingTaskIntegration:
         self._assert_documents_parsing(db_session_with_containers, document_ids)
         delete_key_spy.assert_called_once()
 
-    def test_complete_workflow_priority_task(self, db_session_with_containers, patched_external_dependencies):
+    def test_complete_workflow_priority_task(self, db_session_with_containers: Session, patched_external_dependencies):
         """Run end-to-end priority queue workflow with tenant queue cleanup.
 
         Queue APIs are patched to isolate dispatch side effects while preserving DB assertions.
@@ -640,7 +652,7 @@ class TestDatasetIndexingTaskIntegration:
         self._assert_documents_parsing(db_session_with_containers, document_ids)
         delete_key_spy.assert_called_once()
 
-    def test_single_document_processing(self, db_session_with_containers, patched_external_dependencies):
+    def test_single_document_processing(self, db_session_with_containers: Session, patched_external_dependencies):
         """Process the minimum batch size (single document)."""
         # Arrange
         dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=1)
@@ -654,7 +666,9 @@ class TestDatasetIndexingTaskIntegration:
         assert len(run_args) == 1
         self._assert_documents_parsing(db_session_with_containers, [document_id])
 
-    def test_document_with_special_characters_in_id(self, db_session_with_containers, patched_external_dependencies):
+    def test_document_with_special_characters_in_id(
+        self, db_session_with_containers: Session, patched_external_dependencies
+    ):
         """Handle standard UUID ids with hyphen characters safely."""
         # Arrange
         special_document_id = str(uuid.uuid4())
@@ -669,7 +683,9 @@ class TestDatasetIndexingTaskIntegration:
         # Assert
         self._assert_documents_parsing(db_session_with_containers, [special_document_id])
 
-    def test_zero_vector_space_limit_allows_unlimited(self, db_session_with_containers, patched_external_dependencies):
+    def test_zero_vector_space_limit_allows_unlimited(
+        self, db_session_with_containers: Session, patched_external_dependencies
+    ):
         """Treat vector limit 0 as unlimited and continue indexing."""
         # Arrange
         dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3)
@@ -688,7 +704,7 @@ class TestDatasetIndexingTaskIntegration:
         self._assert_documents_parsing(db_session_with_containers, document_ids)
 
     def test_negative_vector_space_values_handled_gracefully(
-        self, db_session_with_containers, patched_external_dependencies
+        self, db_session_with_containers: Session, patched_external_dependencies
     ):
         """Treat negative vector limits as non-blocking and continue indexing."""
         # Arrange
@@ -707,7 +723,7 @@ class TestDatasetIndexingTaskIntegration:
         patched_external_dependencies["indexing_runner_instance"].run.assert_called_once()
         self._assert_documents_parsing(db_session_with_containers, document_ids)
 
-    def test_large_document_batch_processing(self, db_session_with_containers, patched_external_dependencies):
+    def test_large_document_batch_processing(self, db_session_with_containers: Session, patched_external_dependencies):
         """Process a batch exactly at configured upload limit.
 
         This test patches config only to force a deterministic limit branch while keeping SQL writes real.
diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py
index d457b59d58..e4cbb9e589 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py
@@ -11,6 +11,8 @@ from unittest.mock import ANY, Mock, patch
 
 import pytest
 from faker import Faker
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 
 from core.rag.index_processor.constant.index_type import IndexStructureType
 from models.dataset import Dataset, Document, DocumentSegment
@@ -54,7 +56,7 @@ class TestDealDatasetVectorIndexTask:
         yield mock_factory
 
     @pytest.fixture
-    def account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
+    def account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
         """Create an account with an owner tenant for testing.
 
         Returns a tuple of (account, tenant) where tenant is guaranteed to be non-None.
@@ -72,7 +74,7 @@ class TestDealDatasetVectorIndexTask:
         return account, tenant
 
     def test_deal_dataset_vector_index_task_remove_action_success(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test successful removal of dataset vector index.
@@ -130,7 +132,7 @@ class TestDealDatasetVectorIndexTask:
         assert mock_processor.clean.call_count >= 0  # For now, just check it doesn't fail
 
     def test_deal_dataset_vector_index_task_add_action_success(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test successful addition of dataset vector index.
@@ -221,7 +223,9 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify document status was updated to indexing then completed
-        updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == document.id).limit(1)
+        )
         assert updated_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify index processor load method was called
@@ -230,7 +234,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_called_once()
 
     def test_deal_dataset_vector_index_task_update_action_success(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test successful update of dataset vector index.
@@ -322,7 +326,9 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "update")
 
         # Verify document status was updated to indexing then completed
-        updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == document.id).limit(1)
+        )
         assert updated_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify index processor clean and load methods were called
@@ -332,7 +338,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_called_once()
 
     def test_deal_dataset_vector_index_task_dataset_not_found_error(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test task behavior when dataset is not found.
@@ -352,7 +358,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_not_called()
 
     def test_deal_dataset_vector_index_task_add_action_no_documents(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test add action when no documents exist for the dataset.
@@ -384,7 +390,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_not_called()
 
     def test_deal_dataset_vector_index_task_add_action_no_segments(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test add action when documents exist but have no segments.
@@ -431,7 +437,9 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify document status was updated to indexing then completed
-        updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == document.id).limit(1)
+        )
         assert updated_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify that no index processor load was called since no segments exist
@@ -440,7 +448,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_not_called()
 
     def test_deal_dataset_vector_index_task_update_action_no_documents(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test update action when no documents exist for the dataset.
@@ -473,7 +481,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_not_called()
 
     def test_deal_dataset_vector_index_task_add_action_with_exception_handling(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test add action with exception handling during processing.
@@ -564,12 +572,14 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify document status was updated to error
-        updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == document.id).limit(1)
+        )
         assert updated_document.indexing_status == IndexingStatus.ERROR
         assert "Test exception during indexing" in updated_document.error
 
     def test_deal_dataset_vector_index_task_with_custom_index_type(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test task behavior with custom index type (QA_INDEX).
@@ -635,7 +645,9 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify document status was updated to indexing then completed
-        updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == document.id).limit(1)
+        )
         assert updated_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify index processor was initialized with custom index type
@@ -645,7 +657,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_called_once()
 
     def test_deal_dataset_vector_index_task_with_default_index_type(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test task behavior with default index type (PARAGRAPH_INDEX).
@@ -711,7 +723,9 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify document status was updated to indexing then completed
-        updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == document.id).limit(1)
+        )
         assert updated_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify index processor was initialized with the document's index type
@@ -721,7 +735,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_called_once()
 
     def test_deal_dataset_vector_index_task_multiple_documents_processing(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test task processing with multiple documents and segments.
@@ -815,7 +829,9 @@ class TestDealDatasetVectorIndexTask:
 
         # Verify all documents were processed
         for document in documents:
-            updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
+            updated_document = db_session_with_containers.scalar(
+                select(Document).where(Document.id == document.id).limit(1)
+            )
             assert updated_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify index processor load was called multiple times
@@ -824,7 +840,7 @@ class TestDealDatasetVectorIndexTask:
         assert mock_processor.load.call_count == 3
 
     def test_deal_dataset_vector_index_task_document_status_transitions(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test document status transitions during task execution.
@@ -917,11 +933,13 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify final document status
-        updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == document.id).limit(1)
+        )
         assert updated_document.indexing_status == IndexingStatus.COMPLETED
 
     def test_deal_dataset_vector_index_task_with_disabled_documents(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test task behavior with disabled documents.
@@ -1027,12 +1045,14 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify only enabled document was processed
-        updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first()
+        updated_enabled_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == enabled_document.id).limit(1)
+        )
         assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify disabled document status remains unchanged
-        updated_disabled_document = (
-            db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first()
+        updated_disabled_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == disabled_document.id).limit(1)
         )
         assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED  # Should not change
@@ -1042,7 +1062,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_called_once()
 
     def test_deal_dataset_vector_index_task_with_archived_documents(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test task behavior with archived documents.
@@ -1148,12 +1168,14 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify only active document was processed
-        updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first()
+        updated_active_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == active_document.id).limit(1)
+        )
         assert updated_active_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify archived document status remains unchanged
-        updated_archived_document = (
-            db_session_with_containers.query(Document).filter_by(id=archived_document.id).first()
+        updated_archived_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == archived_document.id).limit(1)
         )
         assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED  # Should not change
@@ -1163,7 +1185,7 @@ class TestDealDatasetVectorIndexTask:
         mock_processor.load.assert_called_once()
 
     def test_deal_dataset_vector_index_task_with_incomplete_documents(
-        self, db_session_with_containers, mock_index_processor_factory, account_and_tenant
+        self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant
     ):
         """
         Test task behavior with documents that have incomplete indexing status.
@@ -1269,14 +1291,14 @@ class TestDealDatasetVectorIndexTask:
         deal_dataset_vector_index_task(dataset.id, "add")
 
         # Verify only completed document was processed
-        updated_completed_document = (
-            db_session_with_containers.query(Document).filter_by(id=completed_document.id).first()
+        updated_completed_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == completed_document.id).limit(1)
         )
         assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED
 
         # Verify incomplete document status remains unchanged
-        updated_incomplete_document = (
-            db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first()
+        updated_incomplete_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == incomplete_document.id).limit(1)
         )
         assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING  # Should not change
diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py
index 8a69707b38..f4a71040c1 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py
@@ -11,9 +11,19 @@ import logging
 from unittest.mock import MagicMock, patch
 
 from faker import Faker
+from sqlalchemy.orm import Session
 
 from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
-from models import Account, Dataset, Document, DocumentSegment, Tenant
+from models import (
+    Account,
+    AccountStatus,
+    Dataset,
+    DatasetPermissionEnum,
+    Document,
+    DocumentSegment,
+    Tenant,
+    TenantStatus,
+)
 from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
 
@@ -37,7 +47,7 @@ class TestDeleteSegmentFromIndexTask:
     and realistic testing environment with actual database interactions.
     """
 
-    def _create_test_tenant(self, db_session_with_containers, fake=None):
+    def _create_test_tenant(self, db_session_with_containers: Session, fake: Faker | None = None):
         """
         Helper method to create a test tenant with realistic data.
 
@@ -49,7 +59,7 @@ class TestDeleteSegmentFromIndexTask:
             Tenant: Created test tenant instance
         """
         fake = fake or Faker()
-        tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal")
+        tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status=TenantStatus.NORMAL)
         tenant.id = fake.uuid4()
         tenant.created_at = fake.date_time_this_year()
         tenant.updated_at = tenant.created_at
@@ -58,7 +68,7 @@ class TestDeleteSegmentFromIndexTask:
         db_session_with_containers.commit()
         return tenant
 
-    def _create_test_account(self, db_session_with_containers, tenant, fake=None):
+    def _create_test_account(self, db_session_with_containers: Session, tenant, fake: Faker | None = None):
         """
         Helper method to create a test account with realistic data.
 
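Alongside the session typing, the delete-segment hunks above replace raw status strings (`"normal"`, `"active"`, `"only_me"`) with model enums. A sketch of the motivation; the enum names follow the diff, but the definition below is illustrative rather than the real Dify model (and assumes Python 3.11+ for `StrEnum`):

```python
from enum import StrEnum


class TenantStatus(StrEnum):  # illustrative; the real enum lives in the models package
    NORMAL = "normal"
    ARCHIVE = "archive"  # hypothetical second member


# A raw string type-checks against any str, so a typo like "nomral" only fails
# at runtime, if at all. The enum member is validated at import time and is
# refactor-safe, while StrEnum keeps equality with the stored string value.
assert TenantStatus.NORMAL == "normal"
```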
@@ -75,7 +85,7 @@ class TestDeleteSegmentFromIndexTask:
             name=fake.name(),
             email=fake.email(),
             avatar=fake.url(),
-            status="active",
+            status=AccountStatus.ACTIVE,
             interface_language="en-US",
         )
         account.id = fake.uuid4()
@@ -86,7 +96,9 @@ class TestDeleteSegmentFromIndexTask:
         db_session_with_containers.commit()
         return account
 
-    def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None):
+    def _create_test_dataset(
+        self, db_session_with_containers: Session, tenant: Tenant, account: Account, fake: Faker | None = None
+    ):
         """
         Helper method to create a test dataset with realistic data.
 
@@ -106,7 +118,7 @@ class TestDeleteSegmentFromIndexTask:
         dataset.name = f"Test Dataset {fake.word()}"
         dataset.description = fake.text(max_nb_chars=200)
         dataset.provider = "vendor"
-        dataset.permission = "only_me"
+        dataset.permission = DatasetPermissionEnum.ONLY_ME
         dataset.data_source_type = DataSourceType.UPLOAD_FILE
         dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY
         dataset.index_struct = '{"type": "paragraph"}'
@@ -122,7 +134,7 @@ class TestDeleteSegmentFromIndexTask:
         db_session_with_containers.commit()
         return dataset
 
-    def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs):
+    def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None, **kwargs):
         """
         Helper method to create a test document with realistic data.
 
@@ -172,7 +184,14 @@ class TestDeleteSegmentFromIndexTask:
         db_session_with_containers.commit()
         return document
 
-    def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None):
+    def _create_test_document_segments(
+        self,
+        db_session_with_containers: Session,
+        document: Document,
+        account: Account,
+        count: int = 3,
+        fake: Faker | None = None,
+    ):
         """
         Helper method to create test document segments with realistic data.
 
@@ -218,7 +237,9 @@ class TestDeleteSegmentFromIndexTask:
         return segments
 
     @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
-    def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers):
+    def test_delete_segment_from_index_task_success(
+        self, mock_index_processor_factory, db_session_with_containers: Session
+    ):
         """
         Test successful segment deletion from index with comprehensive verification.
 
@@ -267,7 +288,7 @@ class TestDeleteSegmentFromIndexTask:
         assert call_args[1]["with_keywords"] is True
         assert call_args[1]["delete_child_chunks"] is True
 
-    def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers):
+    def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers: Session):
         """
         Test task behavior when dataset is not found.
 
@@ -288,7 +309,7 @@ class TestDeleteSegmentFromIndexTask:
         # Verify the task completed without exceptions
         assert result is None  # Task should return None when dataset not found
 
-    def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers):
+    def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers: Session):
         """
         Test task behavior when document is not found.
 
@@ -314,7 +335,7 @@ class TestDeleteSegmentFromIndexTask:
         # Verify the task completed without exceptions
         assert result is None  # Task should return None when document not found
 
-    def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers):
+    def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers: Session):
         """
         Test task behavior when document is disabled.
 
@@ -342,7 +363,7 @@ class TestDeleteSegmentFromIndexTask:
         # Verify the task completed without exceptions
         assert result is None  # Task should return None when document is disabled
 
-    def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers):
+    def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers: Session):
         """
         Test task behavior when document is archived.
 
@@ -370,7 +391,7 @@ class TestDeleteSegmentFromIndexTask:
         # Verify the task completed without exceptions
         assert result is None  # Task should return None when document is archived
 
-    def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers):
+    def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers: Session):
         """
         Test task behavior when document indexing is not completed.
 
diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py
index 3e9a0c8f7f..6bfb1e1f1e 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py
@@ -9,10 +9,11 @@ The task is responsible for removing document segments from the search index whe
 from unittest.mock import MagicMock, patch
 
 from faker import Faker
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
-from models import Account, Dataset, DocumentSegment
+from models import Account, AccountStatus, Dataset, DocumentSegment, TenantAccountRole, TenantStatus
 from models import Document as DatasetDocument
 from models.dataset import DatasetProcessRule
 from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus
@@ -34,7 +35,7 @@ class TestDisableSegmentsFromIndexTask:
     and realistic testing environment with actual database interactions.
     """
 
-    def _create_test_account(self, db_session_with_containers: Session, fake=None):
+    def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None):
         """
         Helper method to create a test account with realistic data.
 
@@ -50,24 +51,23 @@ class TestDisableSegmentsFromIndexTask:
             email=fake.email(),
             name=fake.name(),
             avatar=fake.url(),
-            status="active",
+            status=AccountStatus.ACTIVE,
             interface_language="en-US",
         )
-        account.id = fake.uuid4()  # monkey-patch attributes for test setup
+        account.updated_at = fake.date_time_this_year()
+        account.created_at = fake.date_time_this_year()
+        account.role = TenantAccountRole.OWNER
+        account.id = fake.uuid4()
         account.tenant_id = fake.uuid4()
         account.type = "normal"
-        account.role = "owner"
-        account.created_at = fake.date_time_this_year()
-        account.updated_at = account.created_at
-
         # Create a tenant for the account
         from models.account import Tenant
 
         tenant = Tenant(
             name=f"Test Tenant {fake.company()}",
             plan="basic",
-            status="normal",
+            status=TenantStatus.NORMAL,
         )
         tenant.id = account.tenant_id
         tenant.created_at = fake.date_time_this_year()
@@ -82,7 +82,7 @@ class TestDisableSegmentsFromIndexTask:
 
         return account
 
-    def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None):
+    def _create_test_dataset(self, db_session_with_containers: Session, account, fake: Faker | None = None):
         """
         Helper method to create a test dataset with realistic data.
 
@@ -116,7 +116,9 @@ class TestDisableSegmentsFromIndexTask:
 
         return dataset
 
-    def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None):
+    def _create_test_document(
+        self, db_session_with_containers: Session, dataset, account: Account, fake: Faker | None = None
+    ):
         """
         Helper method to create a test document with realistic data.
 
@@ -215,7 +217,7 @@ class TestDisableSegmentsFromIndexTask:
 
         return segments
 
-    def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None):
+    def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake: Faker | None = None):
         """
         Helper method to create a dataset process rule.
 
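The next hunk in this file swaps a multi-row `Query(...).all()` for `Session.scalars(...)`. A minimal sketch of that pattern, under the same stand-in-model caveat as the earlier sketches:

```python
from sqlalchemy import String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DocumentSegment(Base):  # stand-in for the real model
    __tablename__ = "document_segments"
    id: Mapped[str] = mapped_column(String, primary_key=True)


def fetch_segments(session: Session, segment_ids: list[str]) -> list[DocumentSegment]:
    # Legacy: session.query(DocumentSegment).where(DocumentSegment.id.in_(ids)).all()
    # 2.0: scalars() unwraps each one-column row to the entity; all() materializes.
    return list(session.scalars(select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))).all())
```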
@@ -471,9 +473,9 @@ class TestDisableSegmentsFromIndexTask:
         db_session_with_containers.refresh(segments[1])
 
         # Check that segments are re-enabled after error
-        updated_segments = (
-            db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
-        )
+        updated_segments = db_session_with_containers.scalars(
+            select(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
+        ).all()
         for segment in updated_segments:
             assert segment.enabled is True
diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py
index d4021143ef..77cd259833 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py
@@ -12,10 +12,12 @@ from unittest.mock import Mock, patch
 from uuid import uuid4
 
 import pytest
+from sqlalchemy import delete, func, select, update
+from sqlalchemy.orm import Session
 
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
-from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
 from models.dataset import Dataset, Document, DocumentSegment
 from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
 from tasks.document_indexing_sync_task import document_indexing_sync_task
@@ -30,12 +32,12 @@ class DocumentIndexingSyncTaskTestDataFactory:
             email=f"{uuid4()}@example.com",
             name=f"user-{uuid4()}",
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         db_session_with_containers.add(account)
         db_session_with_containers.flush()
 
-        tenant = Tenant(name=f"tenant-{account.id}", status="normal")
+        tenant = Tenant(name=f"tenant-{account.id}", status=TenantStatus.NORMAL)
         db_session_with_containers.add(tenant)
         db_session_with_containers.flush()
 
@@ -161,7 +163,7 @@ class TestDocumentIndexingSyncTask:
             "indexing_runner": indexing_runner,
         }
 
-    def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None):
+    def _create_notion_sync_context(self, db_session_with_containers: Session, *, data_source_info: dict | None = None):
         account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers)
         dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset(
             db_session_with_containers,
@@ -205,7 +207,7 @@ class TestDocumentIndexingSyncTask:
             "notion_info": notion_info,
         }
 
-    def test_document_not_found(self, db_session_with_containers, mock_external_dependencies):
+    def test_document_not_found(self, db_session_with_containers: Session, mock_external_dependencies):
         """Test that task handles missing document gracefully."""
         # Arrange
         dataset_id = str(uuid4())
@@ -218,7 +220,7 @@ class TestDocumentIndexingSyncTask:
         mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called()
         mock_external_dependencies["indexing_runner"].run.assert_not_called()
 
-    def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies):
+    def test_missing_notion_workspace_id(self, db_session_with_containers: Session, mock_external_dependencies):
         """Test that task raises error when notion_workspace_id is missing."""
         # Arrange
         context = self._create_notion_sync_context(
@@ -234,7 +236,7 @@ class TestDocumentIndexingSyncTask:
         with pytest.raises(ValueError, match="no notion page found"):
             document_indexing_sync_task(context["dataset"].id, context["document"].id)
 
-    def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies):
+    def test_missing_notion_page_id(self, db_session_with_containers: Session, mock_external_dependencies):
         """Test that task raises error when notion_page_id is missing."""
         # Arrange
         context = self._create_notion_sync_context(
@@ -250,12 +252,12 @@ class TestDocumentIndexingSyncTask:
         with pytest.raises(ValueError, match="no notion page found"):
             document_indexing_sync_task(context["dataset"].id, context["document"].id)
 
-    def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies):
+    def test_empty_data_source_info(self, db_session_with_containers: Session, mock_external_dependencies):
         """Test that task raises error when data_source_info is empty."""
         # Arrange
         context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
-        db_session_with_containers.query(Document).where(Document.id == context["document"].id).update(
-            {"data_source_info": None}
+        db_session_with_containers.execute(
+            update(Document).where(Document.id == context["document"].id).values(data_source_info=None)
         )
         db_session_with_containers.commit()
@@ -263,7 +265,7 @@ class TestDocumentIndexingSyncTask:
         with pytest.raises(ValueError, match="no notion page found"):
             document_indexing_sync_task(context["dataset"].id, context["document"].id)
 
-    def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies):
+    def test_credential_not_found(self, db_session_with_containers: Session, mock_external_dependencies):
         """Test that task sets document error state when credential is missing."""
         # Arrange
         context = self._create_notion_sync_context(db_session_with_containers)
@@ -274,8 +276,8 @@ class TestDocumentIndexingSyncTask:
 
         # Assert
         db_session_with_containers.expire_all()
-        updated_document = (
-            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == context["document"].id).limit(1)
         )
         assert updated_document is not None
         assert updated_document.indexing_status == IndexingStatus.ERROR
@@ -283,7 +285,7 @@ class TestDocumentIndexingSyncTask:
         assert updated_document.stopped_at is not None
         mock_external_dependencies["indexing_runner"].run.assert_not_called()
 
-    def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies):
+    def test_page_not_updated(self, db_session_with_containers: Session, mock_external_dependencies):
         """Test that task exits early when notion page is unchanged."""
         # Arrange
         context = self._create_notion_sync_context(db_session_with_containers)
@@ -294,13 +296,13 @@ class TestDocumentIndexingSyncTask:
 
         # Assert
         db_session_with_containers.expire_all()
-        updated_document = (
-            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
+        updated_document = db_session_with_containers.scalar(
+            select(Document).where(Document.id == context["document"].id).limit(1)
         )
-        remaining_segments = (
-            db_session_with_containers.query(DocumentSegment)
+        remaining_segments = db_session_with_containers.scalar(
+            select(func.count())
+            .select_from(DocumentSegment)
             .where(DocumentSegment.document_id == context["document"].id)
context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.COMPLETED @@ -309,7 +311,7 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies): + def test_successful_sync_when_page_updated(self, db_session_with_containers: Session, mock_external_dependencies): """Test full successful sync flow with SQL state updates and side effects.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -319,13 +321,13 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None @@ -348,13 +350,13 @@ class TestDocumentIndexingSyncTask: assert len(run_documents) == 1 assert getattr(run_documents[0], "id", None) == context["document"].id - def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies): + def test_dataset_not_found_during_cleaning(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task still updates document and reindexes if dataset vanishes before clean.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) def _delete_dataset_before_clean() -> str: - db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete() + db_session_with_containers.execute(delete(Dataset).where(Dataset.id == context["dataset"].id)) db_session_with_containers.commit() return "2024-01-02T00:00:00Z" @@ -367,15 +369,17 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_called_once() - def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies): + def test_cleaning_error_continues_to_indexing( + self, db_session_with_containers: Session, mock_external_dependencies + ): """Test that indexing continues when index cleanup fails.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -386,20 +390,22 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == 
context["document"].id).limit(1) ) - remaining_segments = ( - db_session_with_containers.query(DocumentSegment) + remaining_segments = db_session_with_containers.scalar( + select(func.count()) + .select_from(DocumentSegment) .where(DocumentSegment.document_id == context["document"].id) - .count() ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING assert remaining_segments == 0 mock_external_dependencies["indexing_runner"].run.assert_called_once() - def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies): + def test_indexing_runner_document_paused_error( + self, db_session_with_containers: Session, mock_external_dependencies + ): """Test that DocumentIsPausedError does not flip document into error state.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -410,14 +416,14 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.error is None - def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies): + def test_indexing_runner_general_error(self, db_session_with_containers: Session, mock_external_dependencies): """Test that indexing errors are persisted to document state.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -428,8 +434,8 @@ class TestDocumentIndexingSyncTask: # Assert db_session_with_containers.expire_all() - updated_document = ( - db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() + updated_document = db_session_with_containers.scalar( + select(Document).where(Document.id == context["document"].id).limit(1) ) assert updated_document is not None assert updated_document.indexing_status == IndexingStatus.ERROR diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index cf1a8666f3..6c1454b6d8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -3,11 +3,12 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.entities.document_task import DocumentTask from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( @@ -51,7 +52,7 @@ class TestDocumentIndexingTasks: } def _create_test_dataset_and_documents( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3 ): """ Helper method 
@@ -71,14 +72,14 @@ class TestDocumentIndexingTasks:
             email=fake.email(),
             name=fake.name(),
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         db_session_with_containers.add(account)
         db_session_with_containers.commit()
 
         tenant = Tenant(
             name=fake.company(),
-            status="normal",
+            status=TenantStatus.NORMAL,
         )
         db_session_with_containers.add(tenant)
         db_session_with_containers.commit()
@@ -133,7 +134,7 @@ class TestDocumentIndexingTasks:
         return dataset, documents
 
     def _create_test_dataset_with_billing_features(
-        self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True
+        self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True
     ):
         """
         Helper method to create a test dataset with billing features configured.
@@ -153,14 +154,14 @@ class TestDocumentIndexingTasks:
             email=fake.email(),
             name=fake.name(),
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         db_session_with_containers.add(account)
         db_session_with_containers.commit()
 
         tenant = Tenant(
             name=fake.company(),
-            status="normal",
+            status=TenantStatus.NORMAL,
         )
         db_session_with_containers.add(tenant)
         db_session_with_containers.commit()
@@ -221,7 +222,9 @@ class TestDocumentIndexingTasks:
 
         return dataset, documents
 
-    def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_document_indexing_task_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test successful document indexing with multiple documents.
 
@@ -262,7 +265,7 @@ class TestDocumentIndexingTasks:
         assert len(processed_documents) == 3
 
     def test_document_indexing_task_dataset_not_found(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test handling of non-existent dataset.
@@ -286,7 +289,7 @@ class TestDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()
 
     def test_document_indexing_task_document_not_found_in_dataset(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test handling when some documents don't exist in the dataset.
@@ -332,7 +335,7 @@ class TestDocumentIndexingTasks:
         assert len(processed_documents) == 2  # Only existing documents
 
     def test_document_indexing_task_indexing_runner_exception(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test handling of IndexingRunner exceptions.
@@ -373,7 +376,7 @@ class TestDocumentIndexingTasks:
         assert updated_document.processing_started_at is not None
 
     def test_document_indexing_task_mixed_document_states(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test processing documents with mixed initial states.
@@ -456,7 +459,7 @@ class TestDocumentIndexingTasks:
         assert len(processed_documents) == 4
 
     def test_document_indexing_task_billing_sandbox_plan_batch_limit(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test billing validation for sandbox plan batch upload limit.
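The sync-task hunks above also migrate the write paths: `Query.update()` and `Query.delete()` become explicit `update()`/`delete()` statements handed to `Session.execute()`. A hedged sketch of both, again with stand-in models rather than the real Dify ones:

```python
from sqlalchemy import JSON, String, delete, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Document(Base):  # stand-in for the real model
    __tablename__ = "documents"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    data_source_info: Mapped[dict | None] = mapped_column(JSON, nullable=True)


def clear_data_source_info(session: Session, document_id: str) -> None:
    # Legacy: session.query(Document).where(...).update({"data_source_info": None})
    session.execute(update(Document).where(Document.id == document_id).values(data_source_info=None))
    session.commit()


def drop_document(session: Session, document_id: str) -> None:
    # Legacy: session.query(Document).where(...).delete()
    session.execute(delete(Document).where(Document.id == document_id))
    session.commit()
```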
@@ -518,7 +521,7 @@ class TestDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner"].assert_not_called()
 
     def test_document_indexing_task_billing_disabled_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test successful processing when billing is disabled.
@@ -554,7 +557,7 @@ class TestDocumentIndexingTasks:
         assert updated_document.processing_started_at is not None
 
     def test_document_indexing_task_document_is_paused_error(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test handling of DocumentIsPausedError from IndexingRunner.
@@ -597,7 +600,9 @@ class TestDocumentIndexingTasks:
         assert updated_document.processing_started_at is not None
 
     # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ====================
-    def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies):
+    def test_old_document_indexing_task_success(
+        self, db_session_with_containers: Session, mock_external_service_dependencies
+    ):
         """
         Test document_indexing_task basic functionality.
 
@@ -619,7 +624,7 @@ class TestDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
 
     def test_normal_document_indexing_task_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test normal_document_indexing_task basic functionality.
@@ -643,7 +648,7 @@ class TestDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
 
     def test_priority_document_indexing_task_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test priority_document_indexing_task basic functionality.
@@ -667,7 +672,7 @@ class TestDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
 
     def test_document_indexing_with_tenant_queue_success(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test _document_indexing_with_tenant_queue function with no waiting tasks.
@@ -717,7 +722,7 @@ class TestDocumentIndexingTasks:
         mock_task_func.delay.assert_not_called()
 
     def test_document_indexing_with_tenant_queue_with_waiting_tasks(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
    ):
         """
         Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis.
@@ -776,7 +781,7 @@ class TestDocumentIndexingTasks:
         assert len(remaining_tasks) == 1
 
     def test_document_indexing_with_tenant_queue_error_handling(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test error handling in _document_indexing_with_tenant_queue using real Redis.
@@ -848,7 +853,7 @@ class TestDocumentIndexingTasks:
         assert len(remaining_tasks) == 0
 
     def test_document_indexing_with_tenant_queue_tenant_isolation(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test tenant isolation in _document_indexing_with_tenant_queue using real Redis.
diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py
index d94abf2b40..208fc1aa1d 100644
--- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py
+++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py
@@ -2,9 +2,11 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from faker import Faker
+from sqlalchemy import func, select
+from sqlalchemy.orm import Session
 
 from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
-from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
+from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
 from models.dataset import Dataset, Document, DocumentSegment
 from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus
 from tasks.document_indexing_update_task import document_indexing_update_task
@@ -32,7 +34,7 @@ class TestDocumentIndexingUpdateTask:
             "runner_instance": runner_instance,
         }
 
-    def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
+    def _create_dataset_document_with_segments(self, db_session_with_containers: Session, *, segment_count: int = 2):
         fake = Faker()
 
         # Account and tenant
@@ -40,12 +42,12 @@ class TestDocumentIndexingUpdateTask:
             email=fake.email(),
             name=fake.name(),
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         db_session_with_containers.add(account)
         db_session_with_containers.commit()
 
-        tenant = Tenant(name=fake.company(), status="normal")
+        tenant = Tenant(name=fake.company(), status=TenantStatus.NORMAL)
         db_session_with_containers.add(tenant)
         db_session_with_containers.commit()
@@ -113,7 +115,7 @@ class TestDocumentIndexingUpdateTask:
 
         return dataset, document, node_ids
 
-    def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
+    def test_cleans_segments_and_reindexes(self, db_session_with_containers: Session, mock_external_dependencies):
         dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
 
         # Act
@@ -123,13 +125,13 @@ class TestDocumentIndexingUpdateTask:
         db_session_with_containers.expire_all()
 
         # Assert document status updated before reindex
-        updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
+        updated = db_session_with_containers.scalar(select(Document).where(Document.id == document.id).limit(1))
         assert updated.indexing_status == IndexingStatus.PARSING
         assert updated.processing_started_at is not None
 
         # Segments should be deleted
-        remaining = (
-            db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
+        remaining = db_session_with_containers.scalar(
+            select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id)
         )
         assert remaining == 0
@@ -152,7 +154,9 @@ class TestDocumentIndexingUpdateTask:
         first = run_docs[0]
run_docs[0] assert getattr(first, "id", None) == document.id - def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies): + def test_clean_error_is_logged_and_indexing_continues( + self, db_session_with_containers: Session, mock_external_dependencies + ): dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) # Force clean to raise; task should continue to indexing @@ -167,12 +171,12 @@ class TestDocumentIndexingUpdateTask: mock_external_dependencies["runner_instance"].run.assert_called_once() # Segments should remain (since clean failed before DB delete) - remaining = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + remaining = db_session_with_containers.scalar( + select(func.count()).select_from(DocumentSegment).where(DocumentSegment.document_id == document.id) ) assert remaining > 0 - def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies): + def test_document_not_found_noop(self, db_session_with_containers: Session, mock_external_dependencies): fake = Faker() # Act with non-existent document id document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 6a8e186958..12440f3e6b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import select +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -61,7 +63,7 @@ class TestDuplicateDocumentIndexingTasks: } def _create_test_dataset_and_documents( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3 ): """ Helper method to create a test dataset and documents for testing. @@ -144,7 +146,11 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents def _create_test_dataset_with_segments( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2 + self, + db_session_with_containers: Session, + mock_external_service_dependencies, + document_count=3, + segments_per_doc=2, ): """ Helper method to create a test dataset with documents and segments. @@ -196,7 +202,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents, segments def _create_test_dataset_with_billing_features( - self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True ): """ Helper method to create a test dataset with billing features configured. 
@@ -286,7 +292,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents def _test_duplicate_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful duplicate document indexing with multiple documents. @@ -317,7 +323,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were updated to parsing status # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -328,7 +334,7 @@ class TestDuplicateDocumentIndexingTasks: assert len(processed_documents) == 3 def _test_duplicate_document_indexing_task_with_segment_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test duplicate document indexing with existing segments that need cleanup. @@ -362,14 +368,14 @@ class TestDuplicateDocumentIndexingTasks: # Verify segments were deleted from database # Re-query segments from database using captured IDs to avoid stale ORM instances for seg_id in segment_ids: - deleted_segment = ( - db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first() + deleted_segment = db_session_with_containers.scalar( + select(DocumentSegment).where(DocumentSegment.id == seg_id).limit(1) ) assert deleted_segment is None # Verify documents were updated to parsing status for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None @@ -378,7 +384,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def _test_duplicate_document_indexing_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -403,7 +409,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["index_processor"].clean.assert_not_called() def test_duplicate_document_indexing_task_document_not_found_in_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when some documents don't exist in the dataset. 
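The dominant mechanical change in this file is converting legacy `Query.first()` calls to `session.scalar(select(...).limit(1))`. A self-contained sketch of the equivalence, using a throwaway model rather than the repository's `Document`:

```python
# Minimal sketch of the 1.x -> 2.0 single-row fetch conversion shown above.
from sqlalchemy import Column, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Doc(Base):
    __tablename__ = "docs"
    id = Column(String, primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Doc(id="doc-1"))
    session.commit()

    # Legacy 1.x style: returns the first row or None.
    legacy = session.query(Doc).where(Doc.id == "doc-1").first()

    # 2.0 style used in this diff: scalar() returns the first column of the
    # first row (here, the Doc instance) or None; limit(1) keeps the emitted
    # SQL equivalent to .first().
    modern = session.scalar(select(Doc).where(Doc.id == "doc-1").limit(1))

    assert legacy is modern  # same identity-map instance either way
```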
@@ -438,7 +444,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Verify only existing documents were updated
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in existing_document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.PARSING
             assert updated_document.processing_started_at is not None

@@ -449,7 +455,7 @@ class TestDuplicateDocumentIndexingTasks:
         assert len(processed_documents) == 2  # Only existing documents

     def _test_duplicate_document_indexing_task_indexing_runner_exception(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test handling of IndexingRunner exceptions.
@@ -485,12 +491,12 @@ class TestDuplicateDocumentIndexingTasks:
         # Verify documents were still updated to parsing status before the exception
         # Re-query documents from database since _duplicate_document_indexing_task closes the session
         for doc_id in document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.PARSING
             assert updated_document.processing_started_at is not None

     def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test billing validation for sandbox plan batch upload limit.
@@ -543,7 +549,7 @@ class TestDuplicateDocumentIndexingTasks:
         # Assert: Verify error handling
         # Re-query documents from database since _duplicate_document_indexing_task uses a different session
         for doc_id in document_ids:
-            updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
+            updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1))
             assert updated_document.indexing_status == IndexingStatus.ERROR
             assert updated_document.error is not None
             assert "batch upload" in updated_document.error.lower()
@@ -553,7 +559,7 @@ class TestDuplicateDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called()

     def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded(
-        self, db_session_with_containers, mock_external_service_dependencies
+        self, db_session_with_containers: Session, mock_external_service_dependencies
     ):
         """
         Test billing validation for vector space limit.
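The "re-query" comments above exist because the task under test commits through its own session, so ORM instances already loaded by the test session go stale. A minimal sketch of that effect, with a throwaway model and a temp-file SQLite database standing in for the real setup:

```python
# Why these tests re-select instead of reusing loaded instances: a second
# session's commit is invisible to the first session's identity map until
# the objects are expired or re-queried.
import os
import tempfile

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Record(Base):
    __tablename__ = "records"
    id = Column(Integer, primary_key=True)
    status = Column(String)


db_path = os.path.join(tempfile.mkdtemp(), "demo.db")
engine = create_engine(f"sqlite:///{db_path}")
Base.metadata.create_all(engine)

test_session = Session(engine)
test_session.add(Record(id=1, status="waiting"))
test_session.commit()
record = test_session.scalar(select(Record).limit(1))

# The "task" writes through its own session, as the Celery task does.
with Session(engine) as task_session:
    task_session.get(Record, 1).status = "parsing"
    task_session.commit()

assert record.status == "waiting"  # stale copy in the test session's identity map
test_session.expire_all()          # or re-select, as these tests do
assert record.status == "parsing"  # attribute access now reloads from the database
```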
@@ -585,7 +591,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert: Verify error handling # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "limit" in updated_document.error.lower() @@ -595,7 +601,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def test_duplicate_document_indexing_task_with_empty_document_list( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of empty document list. @@ -621,7 +627,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) def test_deprecated_duplicate_document_indexing_task_delegates_to_core( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that deprecated duplicate_document_indexing_task delegates to core function. @@ -649,12 +655,12 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_normal_duplicate_document_indexing_task_with_tenant_queue( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test normal_duplicate_document_indexing_task with tenant isolation queue. @@ -692,12 +698,12 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_priority_duplicate_document_indexing_task_with_tenant_queue( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test priority_duplicate_document_indexing_task with tenant isolation queue. 
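The tenant-queue tests above patch `TenantIsolatedTaskQueue` with `autospec=True`. A small sketch of what that buys, using hypothetical `Queue`/`enqueue` stand-ins: an autospec'd mock enforces the real signatures, so signature drift fails loudly instead of passing silently.

```python
# Sketch of the autospec pattern used above; Queue/enqueue are stand-ins.
from unittest.mock import patch


class Queue:
    def push(self, item: str) -> None:
        raise NotImplementedError  # real behavior not needed for the sketch


def enqueue(item: str) -> None:
    Queue().push(item)


@patch(f"{__name__}.Queue", autospec=True)
def test_enqueue_uses_queue(mock_queue_class):
    # autospec=True means push("x", extra=1) would raise TypeError here.
    enqueue("x")
    mock_queue_class.assert_called_once_with()
    mock_queue_class.return_value.push.assert_called_once_with("x")


test_enqueue_uses_queue()
```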
@@ -736,12 +742,12 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant queue wrapper processes next queued tasks. @@ -788,7 +794,7 @@ class TestDuplicateDocumentIndexingTasks: mock_queue.delete_task_key.assert_not_called() def test_successful_duplicate_document_indexing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test successful duplicate document indexing flow.""" self._test_duplicate_document_indexing_task_success( @@ -796,7 +802,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when dataset is not found.""" self._test_duplicate_document_indexing_task_dataset_not_found( @@ -804,7 +810,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing with billing enabled and sandbox plan.""" self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( @@ -812,7 +818,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_with_billing_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when billing limit is exceeded.""" self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( @@ -820,7 +826,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_runner_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when IndexingRunner raises an error.""" self._test_duplicate_document_indexing_task_indexing_runner_exception( @@ -828,7 +834,7 @@ class TestDuplicateDocumentIndexingTasks: ) def _test_duplicate_document_indexing_task_document_is_paused( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when document is paused.""" # Arrange @@ -851,7 +857,7 @@ class TestDuplicateDocumentIndexingTasks: # Assert for doc_id in document_ids: - updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + updated_document = 
db_session_with_containers.scalar(select(Document).where(Document.id == doc_id).limit(1)) assert updated_document.is_paused is True assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.display_status == "paused" @@ -859,7 +865,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_duplicate_document_indexing_document_is_paused( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when document is paused.""" self._test_duplicate_document_indexing_task_document_is_paused( @@ -867,7 +873,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_cleans_old_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test that duplicate document indexing cleans old segments.""" self._test_duplicate_document_indexing_task_with_segment_cleanup( diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py index 177af266fb..a697878bb6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -29,7 +30,7 @@ class TestMailChangeMailTask: "get_email_i18n_service": mock_get_email_i18n_service, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -72,7 +73,7 @@ class TestMailChangeMailTask: return account def test_send_change_mail_task_success_old_email_phase( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email task execution for old_email phase. @@ -103,7 +104,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_task_success_new_email_phase( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email task execution for new_email phase. @@ -134,7 +135,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email task when mail service is not initialized. @@ -159,7 +160,7 @@ class TestMailChangeMailTask: mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_not_called() def test_send_change_mail_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email task when email service raises an exception. 
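These mail tests all lean on a fixture that patches the email service and hands each test a dict of mock handles. A compact sketch of that shape, with a hypothetical `send_raw` standing in for the real service call:

```python
# Sketch of the fixture-of-mocks shape behind mock_mail_dependencies and
# mock_external_service_dependencies; send_raw/notify are stand-ins.
from unittest.mock import patch

import pytest


def send_raw(to: str) -> None:
    raise RuntimeError("real network call; tests must patch this")


def notify(to: str) -> None:
    send_raw(to)


@pytest.fixture
def mock_mail_dependencies():
    with patch(f"{__name__}.send_raw") as mock_send:
        # A dict lets one fixture expose several patched handles at once.
        yield {"send_raw": mock_send}


def test_notify_sends_mail(mock_mail_dependencies):
    notify("user@example.com")
    mock_mail_dependencies["send_raw"].assert_called_once_with("user@example.com")
```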
@@ -191,7 +192,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_completed_notification_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email completed notification task execution. @@ -224,7 +225,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_completed_notification_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email completed notification task when mail service is not initialized. @@ -247,7 +248,7 @@ class TestMailChangeMailTask: mock_external_service_dependencies["email_i18n_service"].send_email.assert_not_called() def test_send_change_mail_completed_notification_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email completed notification task when email service raises an exception. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index c0ddc27286..8e9da6aaaa 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -14,6 +14,8 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -36,14 +38,14 @@ class TestSendEmailCodeLoginMailTask: """ @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" from extensions.ext_redis import redis_client # Clear all test data - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -70,7 +72,7 @@ class TestSendEmailCodeLoginMailTask: "email_service_instance": mock_email_service_instance, } - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account for testing. @@ -97,7 +99,7 @@ class TestSendEmailCodeLoginMailTask: return account - def _create_test_tenant_and_account(self, db_session_with_containers, fake=None): + def _create_test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test tenant and account for testing. 
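The cleanup fixture above shows the other two recurring conversions in this diff: `Query.delete()` becomes `session.execute(delete(...))` and `Query.count()` becomes `session.scalar(select(func.count())...)`. A runnable sketch of both on a throwaway model:

```python
# Minimal sketch of the bulk-delete and count conversions used above.
from sqlalchemy import Column, Integer, create_engine, delete, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Row(Base):
    __tablename__ = "rows"
    id = Column(Integer, primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Row(), Row()])
    session.commit()

    # Legacy 1.x bulk delete: session.query(Row).delete()
    result = session.execute(delete(Row))
    session.commit()
    assert result.rowcount == 2  # rows affected by the DELETE

    # Legacy 1.x count: session.query(Row).count()
    assert session.scalar(select(func.count()).select_from(Row)) == 0
```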
@@ -137,7 +139,7 @@ class TestSendEmailCodeLoginMailTask: return account, tenant def test_send_email_code_login_mail_task_success_english( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending in English. @@ -181,7 +183,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_success_chinese( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending in Chinese. @@ -220,7 +222,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_success_multiple_languages( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending with multiple languages. @@ -260,7 +262,7 @@ class TestSendEmailCodeLoginMailTask: assert call_args[1]["template_context"]["code"] == test_codes[i] def test_send_email_code_login_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task when mail service is not initialized. @@ -298,7 +300,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.assert_not_called() def test_send_email_code_login_mail_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task when email service raises an exception. @@ -345,7 +347,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_invalid_parameters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with invalid parameters. @@ -387,7 +389,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.assert_called_once() def test_send_email_code_login_mail_task_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with edge cases and boundary conditions. @@ -450,7 +452,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_database_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with database integration. @@ -496,7 +498,7 @@ class TestSendEmailCodeLoginMailTask: assert account.status == "active" def test_send_email_code_login_mail_task_redis_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with Redis integration. 
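Several scenarios above hinge on the task's early-return guard when the mail extension is uninitialized. A minimal sketch of that guard and how the tests flip it, with `mail` as a hypothetical stand-in for the extension object:

```python
# Sketch of the "mail not initialized" guard these tests exercise.
from unittest.mock import MagicMock

mail = MagicMock()  # stand-in for the mail extension
send_log: list[str] = []


def send_code_mail(to: str) -> None:
    if not mail.is_inited():
        return  # early return: no service lookup, nothing sent
    send_log.append(to)


mail.is_inited.return_value = False
send_code_mail("user@example.com")
assert send_log == []  # guard triggered, nothing sent

mail.is_inited.return_value = True
send_code_mail("user@example.com")
assert send_log == ["user@example.com"]
```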
@@ -540,7 +542,7 @@ class TestSendEmailCodeLoginMailTask: redis_client.delete(cache_key) def test_send_email_code_login_mail_task_error_handling_comprehensive( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test comprehensive error handling for email code login mail task. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index a16f3ff773..f505361727 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,16 +3,15 @@ from datetime import UTC, datetime from unittest.mock import patch import pytest -from graphon.enums import WorkflowExecutionStatus -from graphon.nodes.human_input.entities import HumanInputNodeData -from graphon.runtime import GraphRuntimeState, VariablePool +from sqlalchemy import delete +from sqlalchemy.orm import Session from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -20,6 +19,9 @@ from core.workflow.human_input_compat import ( MemberRecipient, ) from extensions.ext_storage import storage +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import HumanInputDelivery, HumanInputForm, HumanInputFormRecipient @@ -30,14 +32,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @pytest.fixture(autouse=True) def cleanup_database(db_session_with_containers): - db_session_with_containers.query(HumanInputFormRecipient).delete() - db_session_with_containers.query(HumanInputDelivery).delete() - db_session_with_containers.query(HumanInputForm).delete() - db_session_with_containers.query(WorkflowPause).delete() - db_session_with_containers.query(WorkflowRun).delete() - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(HumanInputFormRecipient)) + db_session_with_containers.execute(delete(HumanInputDelivery)) + db_session_with_containers.execute(delete(HumanInputForm)) + db_session_with_containers.execute(delete(WorkflowPause)) + db_session_with_containers.execute(delete(WorkflowRun)) + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() @@ -171,7 +173,9 @@ def _create_workflow_pause_state( db_session_with_containers.commit() -def 
test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers): +def test_dispatch_human_input_email_task_integration( + monkeypatch: pytest.MonkeyPatch, db_session_with_containers: Session +): tenant, account = _create_workspace_member(db_session_with_containers) workflow_run_id = str(uuid.uuid4()) workflow_id = str(uuid.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py index 1a20b6deec..f8e54ea9e6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from tasks.mail_inner_task import send_inner_email_task @@ -51,7 +52,7 @@ class TestMailInnerTask: }, } - def test_send_inner_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful email sending with valid data. @@ -90,7 +91,9 @@ class TestMailInnerTask: html_content="Test email content", ) - def test_send_inner_email_single_recipient(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_single_recipient( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending with single recipient. @@ -126,7 +129,9 @@ class TestMailInnerTask: html_content="Test email content", ) - def test_send_inner_email_empty_substitutions(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_empty_substitutions( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending with empty substitutions. @@ -163,7 +168,7 @@ class TestMailInnerTask: ) def test_send_inner_email_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending when mail service is not initialized. @@ -193,7 +198,7 @@ class TestMailInnerTask: mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() def test_send_inner_email_template_rendering_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending when template rendering fails. @@ -222,7 +227,9 @@ class TestMailInnerTask: # Verify no email service calls due to exception mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() - def test_send_inner_email_service_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_service_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending when email service fails. 
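Like the `Session` annotations, typing the built-in fixture as `pytest.MonkeyPatch` (seen in the human-input test above) is tooling-only. A tiny sketch:

```python
# Sketch of the typed monkeypatch usage; APP_REGION is a hypothetical setting.
import os

import pytest


def read_region() -> str:
    return os.environ.get("APP_REGION", "default")


def test_read_region(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("APP_REGION", "eu-west-1")  # undone automatically after the test
    assert read_region() == "eu-west-1"
```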
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index 212fbd26cd..c8c7a4d961 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -17,6 +17,8 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy import delete, select +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from libs.email_i18n import EmailType @@ -41,12 +43,12 @@ class TestMailInviteMemberTask: """ @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" # Clear all test data - db_session_with_containers.query(TenantAccountJoin).delete() - db_session_with_containers.query(Tenant).delete() - db_session_with_containers.query(Account).delete() + db_session_with_containers.execute(delete(TenantAccountJoin)) + db_session_with_containers.execute(delete(Tenant)) + db_session_with_containers.execute(delete(Account)) db_session_with_containers.commit() # Clear Redis cache @@ -77,7 +79,7 @@ class TestMailInviteMemberTask: "config": mock_config, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -146,7 +148,7 @@ class TestMailInviteMemberTask: redis_client.setex(cache_key, 24 * 60 * 60, json.dumps(invitation_data)) # 24 hours return token - def _create_pending_account_for_invitation(self, db_session_with_containers, email, tenant): + def _create_pending_account_for_invitation(self, db_session_with_containers: Session, email, tenant): """ Helper method to create a pending account for invitation testing. @@ -184,7 +186,9 @@ class TestMailInviteMemberTask: return account - def test_send_invite_member_mail_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_invite_member_mail_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful invitation email sending with all parameters. @@ -230,7 +234,7 @@ class TestMailInviteMemberTask: assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" def test_send_invite_member_mail_different_languages( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test invitation email sending with different language codes. @@ -262,7 +266,7 @@ class TestMailInviteMemberTask: assert call_args[1]["language_code"] == language def test_send_invite_member_mail_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test behavior when mail service is not initialized. 
@@ -291,7 +295,7 @@ class TestMailInviteMemberTask: mock_email_service.send_email.assert_not_called() def test_send_invite_member_mail_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when email service raises an exception. @@ -321,7 +325,7 @@ class TestMailInviteMemberTask: assert "Send invite member mail to %s failed" in error_call def test_send_invite_member_mail_template_context_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test template context contains all required fields for email rendering. @@ -367,7 +371,7 @@ class TestMailInviteMemberTask: assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" def test_send_invite_member_mail_integration_with_redis_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test integration with Redis token validation. @@ -406,7 +410,7 @@ class TestMailInviteMemberTask: assert invitation_data["workspace_id"] == tenant.id def test_send_invite_member_mail_with_special_characters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending with special characters in names and workspace names. @@ -448,7 +452,7 @@ class TestMailInviteMemberTask: assert template_context["workspace_name"] == workspace_name def test_send_invite_member_mail_real_database_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test real database integration with actual invitation flow. @@ -491,16 +495,16 @@ class TestMailInviteMemberTask: assert tenant.name is not None # Verify tenant relationship exists - tenant_join = ( - db_session_with_containers.query(TenantAccountJoin) - .filter_by(tenant_id=tenant.id, account_id=pending_account.id) - .first() + tenant_join = db_session_with_containers.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == pending_account.id) + .limit(1) ) assert tenant_join is not None assert tenant_join.role == TenantAccountRole.NORMAL def test_send_invite_member_mail_token_lifecycle_management( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test token lifecycle management and validation. 
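The invite-member hunk above also converts `filter_by(...)` keyword filters into explicit `where()` comparisons, which type checkers can verify against the mapped columns. A sketch of the same conversion on a stand-in model:

```python
# Sketch of the filter_by -> where conversion shown above.
from sqlalchemy import Column, Integer, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Membership(Base):
    __tablename__ = "memberships"
    id = Column(Integer, primary_key=True)
    tenant_id = Column(Integer)
    account_id = Column(Integer)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Membership(id=1, tenant_id=10, account_id=20))
    session.commit()

    # Legacy: session.query(Membership).filter_by(tenant_id=10, account_id=20).first()
    row = session.scalar(
        select(Membership)
        .where(Membership.tenant_id == 10, Membership.account_id == 20)
        .limit(1)
    )
    assert row is not None and row.id == 1
```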
diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py index e08b099480..176645a4ab 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py @@ -11,6 +11,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -44,7 +45,7 @@ class TestMailOwnerTransferTask: "get_email_service": mock_get_email_service, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create test account and tenant for testing. @@ -86,7 +87,9 @@ class TestMailOwnerTransferTask: return account, tenant - def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies): + def test_send_owner_transfer_confirm_task_success( + self, db_session_with_containers: Session, mock_mail_dependencies + ): """ Test successful owner transfer confirmation email sending. @@ -127,7 +130,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace def test_send_owner_transfer_confirm_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test owner transfer confirmation email when mail service is not initialized. @@ -158,7 +161,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_owner_transfer_confirm_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in owner transfer confirmation email. @@ -192,7 +195,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_called_once() def test_send_old_owner_transfer_notify_email_task_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test successful old owner transfer notification email sending. @@ -234,7 +237,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email def test_send_old_owner_transfer_notify_email_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test old owner transfer notification email when mail service is not initialized. @@ -265,7 +268,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_old_owner_transfer_notify_email_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in old owner transfer notification email. 
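The owner-transfer assertions that follow read recorded keyword arguments back off the mock. A short sketch of the `call_args` access pattern those tests use:

```python
# Sketch of inspecting a recorded call's kwargs on a MagicMock.
from unittest.mock import MagicMock

service = MagicMock()
service.send_email(
    language_code="en-US",
    to="owner@example.com",
    template_context={"WorkspaceName": "Acme"},
)

call_args = service.send_email.call_args
assert call_args[1]["language_code"] == "en-US"  # index 1 is the kwargs dict
assert call_args.kwargs["template_context"]["WorkspaceName"] == "Acme"  # named access
```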
@@ -299,7 +302,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_called_once() def test_send_new_owner_transfer_notify_email_task_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test successful new owner transfer notification email sending. @@ -338,7 +341,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace def test_send_new_owner_transfer_notify_email_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test new owner transfer notification email when mail service is not initialized. @@ -367,7 +370,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_new_owner_transfer_notify_email_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in new owner transfer notification email. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py index cced6f7780..071971f324 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py @@ -9,6 +9,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist @@ -35,7 +36,7 @@ class TestMailRegisterTask: "get_email_service": mock_get_email_service, } - def test_send_email_register_mail_task_success(self, db_session_with_containers, mock_mail_dependencies): + def test_send_email_register_mail_task_success(self, db_session_with_containers: Session, mock_mail_dependencies): """Test successful email registration mail sending.""" fake = Faker() language = "en-US" @@ -56,7 +57,7 @@ class TestMailRegisterTask: ) def test_send_email_register_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test email registration task when mail service is not initialized.""" mock_mail_dependencies["mail"].is_inited.return_value = False @@ -66,7 +67,9 @@ class TestMailRegisterTask: mock_mail_dependencies["get_email_service"].assert_not_called() mock_mail_dependencies["email_service"].send_email.assert_not_called() - def test_send_email_register_mail_task_exception_handling(self, db_session_with_containers, mock_mail_dependencies): + def test_send_email_register_mail_task_exception_handling( + self, db_session_with_containers: Session, mock_mail_dependencies + ): """Test email registration task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") @@ -79,7 +82,7 @@ class TestMailRegisterTask: mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) def test_send_email_register_mail_task_when_account_exist_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): 
"""Test successful email registration mail sending when account exists.""" fake = Faker() @@ -105,7 +108,7 @@ class TestMailRegisterTask: ) def test_send_email_register_mail_task_when_account_exist_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test account exist email task when mail service is not initialized.""" mock_mail_dependencies["mail"].is_inited.return_value = False @@ -118,7 +121,7 @@ class TestMailRegisterTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_email_register_mail_task_when_account_exist_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test account exist email task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index f01fcc1742..5eea985fdc 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -4,12 +4,13 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from flask import Flask from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Pipeline from models.workflow import Workflow from tasks.rag_pipeline.priority_rag_pipeline_run_task import ( @@ -69,14 +70,14 @@ class TestRagPipelineRunTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -725,7 +726,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_run_single_rag_pipeline_task_success( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test successful run_single_rag_pipeline_task execution. @@ -760,7 +761,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_run_single_rag_pipeline_task_entity_validation_error( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test run_single_rag_pipeline_task with invalid entity data. 
@@ -805,7 +806,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_run_single_rag_pipeline_task_database_entity_not_found( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test run_single_rag_pipeline_task with non-existent database entities. diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 96cf9cebf5..03c02ea341 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,11 +2,13 @@ import uuid from unittest.mock import ANY, call, patch import pytest -from graphon.variables.segments import StringSegment -from graphon.variables.types import SegmentType +from sqlalchemy import delete, func, select +from sqlalchemy.orm import Session from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole @@ -20,11 +22,11 @@ from tasks.remove_app_and_related_data_task import ( @pytest.fixture(autouse=True) def cleanup_database(db_session_with_containers): - db_session_with_containers.query(WorkflowDraftVariable).delete() - db_session_with_containers.query(WorkflowDraftVariableFile).delete() - db_session_with_containers.query(UploadFile).delete() - db_session_with_containers.query(App).delete() - db_session_with_containers.query(Tenant).delete() + db_session_with_containers.execute(delete(WorkflowDraftVariable)) + db_session_with_containers.execute(delete(WorkflowDraftVariableFile)) + db_session_with_containers.execute(delete(UploadFile)) + db_session_with_containers.execute(delete(App)) + db_session_with_containers.execute(delete(Tenant)) db_session_with_containers.commit() @@ -116,7 +118,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id: class TestDeleteDraftVariablesBatch: - def test_delete_draft_variables_batch_success(self, db_session_with_containers): + def test_delete_draft_variables_batch_success(self, db_session_with_containers: Session): """Test successful deletion of draft variables in batches.""" _, app1 = _create_tenant_and_app(db_session_with_containers) _, app2 = _create_tenant_and_app(db_session_with_containers) @@ -127,21 +129,21 @@ class TestDeleteDraftVariablesBatch: result = delete_draft_variables_batch(app1.id, batch_size=100) assert result == 150 - app1_remaining = db_session_with_containers.query(WorkflowDraftVariable).where( - WorkflowDraftVariable.app_id == app1.id + app1_remaining_count = db_session_with_containers.scalar( + select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app1.id) ) - app2_remaining = db_session_with_containers.query(WorkflowDraftVariable).where( - WorkflowDraftVariable.app_id == app2.id + app2_remaining_count = db_session_with_containers.scalar( + select(func.count()).select_from(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app2.id) ) - assert app1_remaining.count() == 0 - assert app2_remaining.count() == 100 + assert 
app1_remaining_count == 0 + assert app2_remaining_count == 100 - def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers): + def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers: Session): """Test deletion when no draft variables exist for the app.""" result = delete_draft_variables_batch(str(uuid.uuid4()), 1000) assert result == 0 - assert db_session_with_containers.query(WorkflowDraftVariable).count() == 0 + assert db_session_with_containers.scalar(select(func.count()).select_from(WorkflowDraftVariable)) == 0 @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data") @patch("tasks.remove_app_and_related_data_task.logger") @@ -175,7 +177,7 @@ class TestDeleteDraftVariableOffloadData: """Test the Offload data cleanup functionality.""" @patch("extensions.ext_storage.storage") - def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers): + def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers: Session): """Test successful deletion of offload data.""" tenant, app = _create_tenant_and_app(db_session_with_containers) offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=3) @@ -190,12 +192,16 @@ class TestDeleteDraftVariableOffloadData: expected_storage_calls = [call(storage_key) for storage_key in upload_file_keys] mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True) - remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where( - WorkflowDraftVariableFile.id.in_(file_ids) + remaining_var_files_count = db_session_with_containers.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(file_ids)) ) - remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)) - assert remaining_var_files.count() == 0 - assert remaining_upload_files.count() == 0 + remaining_upload_files_count = db_session_with_containers.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) + assert remaining_var_files_count == 0 + assert remaining_upload_files_count == 0 @patch("extensions.ext_storage.storage") @patch("tasks.remove_app_and_related_data_task.logging") @@ -217,9 +223,13 @@ class TestDeleteDraftVariableOffloadData: assert result == 1 mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", storage_keys[0]) - remaining_var_files = db_session_with_containers.query(WorkflowDraftVariableFile).where( - WorkflowDraftVariableFile.id.in_(file_ids) + remaining_var_files_count = db_session_with_containers.scalar( + select(func.count()) + .select_from(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(file_ids)) ) - remaining_upload_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)) - assert remaining_var_files.count() == 0 - assert remaining_upload_files.count() == 0 + remaining_upload_files_count = db_session_with_containers.scalar( + select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids)) + ) + assert remaining_var_files_count == 0 + assert remaining_upload_files_count == 0 diff --git a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py index 34a1941c39..6365207661 100644 --- 
a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py +++ b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py @@ -1,12 +1,14 @@ from pathlib import Path +import pytest + from extensions.storage.opendal_storage import OpenDALStorage class TestOpenDALFsDefaultRoot: """Test that OpenDALStorage with scheme='fs' works correctly when no root is provided.""" - def test_fs_without_root_uses_default(self, tmp_path, monkeypatch): + def test_fs_without_root_uses_default(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """When no root is specified, the default 'storage' should be used and passed to the Operator.""" # Change to tmp_path so the default "storage" dir is created there monkeypatch.chdir(tmp_path) @@ -25,7 +27,7 @@ class TestOpenDALFsDefaultRoot: # Cleanup storage.delete("test_default_root.txt") - def test_fs_with_explicit_root(self, tmp_path): + def test_fs_with_explicit_root(self, tmp_path: Path): """When root is explicitly provided, it should be used.""" custom_root = str(tmp_path / "custom_storage") storage = OpenDALStorage(scheme="fs", root=custom_root) @@ -38,7 +40,7 @@ class TestOpenDALFsDefaultRoot: # Cleanup storage.delete("test_explicit_root.txt") - def test_fs_with_env_var_root(self, tmp_path, monkeypatch): + def test_fs_with_env_var_root(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs.""" env_root = str(tmp_path / "env_storage") monkeypatch.setenv("OPENDAL_FS_ROOT", env_root) diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 159ab51304..6402e7da2b 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,16 +24,16 @@ from dataclasses import dataclass from datetime import timedelta import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus -from sqlalchemy import delete, select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session, selectinload, sessionmaker from extensions.ext_storage import storage +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models import Account from models import WorkflowPause as WorkflowPauseModel -from models.account import Tenant, TenantAccountJoin, TenantAccountRole +from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.model import UploadFile from models.workflow import Workflow, WorkflowRun from repositories.sqlalchemy_api_workflow_run_repository import ( @@ -175,13 +175,13 @@ class TestWorkflowPauseIntegration: """Comprehensive integration tests for workflow pause functionality.""" @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers): + def setup_test_data(self, db_session_with_containers: Session): """Set up test data for each test method using TestContainers.""" # Create test tenant and account tenant = Tenant( name="Test Tenant", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -190,7 +190,7 @@ class TestWorkflowPauseIntegration: email="test@example.com", name="Test User", 
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         db_session_with_containers.add(account)
         db_session_with_containers.commit()
@@ -679,9 +682,12 @@ class TestWorkflowPauseIntegration:
         # Verify only 3 were deleted
         remaining_count = (
-            self.session.query(WorkflowPauseModel)
-            .filter(WorkflowPauseModel.id.in_([pe.id for pe in pause_entities]))
-            .count()
+            self.session.scalar(
+                select(func.count(WorkflowPauseModel.id)).where(
+                    WorkflowPauseModel.id.in_([pe.id for pe in pause_entities])
+                )
+            )
+            or 0
         )
 
         assert remaining_count == 2
@@ -693,7 +696,7 @@
 
         tenant2 = Tenant(
             name="Test Tenant 2",
-            status="normal",
+            status=TenantStatus.NORMAL,
         )
         self.session.add(tenant2)
         self.session.commit()
@@ -702,7 +705,7 @@
             email="test2@example.com",
             name="Test User 2",
             interface_language="en-US",
-            status="active",
+            status=AccountStatus.ACTIVE,
         )
         self.session.add(account2)
         self.session.commit()
diff --git a/api/tests/test_containers_integration_tests/trigger/conftest.py b/api/tests/test_containers_integration_tests/trigger/conftest.py
index e3832fb2ef..272bee9630 100644
--- a/api/tests/test_containers_integration_tests/trigger/conftest.py
+++ b/api/tests/test_containers_integration_tests/trigger/conftest.py
@@ -11,6 +11,7 @@ from collections.abc import Generator
 from typing import Any
 
 import pytest
+from sqlalchemy import delete
 from sqlalchemy.orm import Session
 
 from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -40,9 +41,9 @@ def tenant_and_account(db_session_with_containers: Session) -> Generator[tuple[T
     yield tenant, account
 
     # Cleanup
-    db_session_with_containers.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
-    db_session_with_containers.query(Account).filter_by(id=account.id).delete()
-    db_session_with_containers.query(Tenant).filter_by(id=tenant.id).delete()
+    db_session_with_containers.execute(delete(TenantAccountJoin).where(TenantAccountJoin.tenant_id == tenant.id))
+    db_session_with_containers.execute(delete(Account).where(Account.id == account.id))
+    db_session_with_containers.execute(delete(Tenant).where(Tenant.id == tenant.id))
     db_session_with_containers.commit()
 
 
@@ -93,14 +94,14 @@ def app_model(
     )
 
     from models.workflow import Workflow
 
-    db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app.id).delete()
-    db_session_with_containers.query(WorkflowSchedulePlan).filter_by(app_id=app.id).delete()
-    db_session_with_containers.query(WorkflowWebhookTrigger).filter_by(app_id=app.id).delete()
-    db_session_with_containers.query(WorkflowPluginTrigger).filter_by(app_id=app.id).delete()
-    db_session_with_containers.query(AppTrigger).filter_by(app_id=app.id).delete()
-    db_session_with_containers.query(TriggerSubscription).filter_by(tenant_id=tenant.id).delete()
-    db_session_with_containers.query(Workflow).filter_by(app_id=app.id).delete()
-    db_session_with_containers.query(App).filter_by(id=app.id).delete()
+    db_session_with_containers.execute(delete(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app.id))
+    db_session_with_containers.execute(delete(WorkflowSchedulePlan).where(WorkflowSchedulePlan.app_id == app.id))
+    db_session_with_containers.execute(delete(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.app_id == app.id))
+    db_session_with_containers.execute(delete(WorkflowPluginTrigger).where(WorkflowPluginTrigger.app_id == app.id))
+    db_session_with_containers.execute(delete(AppTrigger).where(AppTrigger.app_id == app.id))
+    db_session_with_containers.execute(delete(TriggerSubscription).where(TriggerSubscription.tenant_id == tenant.id))
+    db_session_with_containers.execute(delete(Workflow).where(Workflow.app_id == app.id))
+    db_session_with_containers.execute(delete(App).where(App.id == app.id))
     db_session_with_containers.commit()
diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
index 7539bae685..9c20118e27 100644
--- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
+++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py
@@ -10,7 +10,7 @@ from typing import Any
 import pytest
 from flask import Flask, Response
 from flask.testing import FlaskClient
-from graphon.enums import BuiltinNodeTypes
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from configs import dify_config
@@ -24,6 +24,7 @@ from core.trigger.debug import event_selectors
 from core.trigger.debug.event_bus import TriggerDebugEventBus
 from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller
 from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key
+from graphon.enums import BuiltinNodeTypes
 from libs.datetime_utils import naive_utc_now
 from models.account import Account, Tenant
 from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus
@@ -227,7 +228,9 @@ def test_webhook_trigger_creates_trigger_log(
     assert response.status_code == 200
 
     db_session_with_containers.expire_all()
-    logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
+    logs = db_session_with_containers.scalars(
+        select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
+    ).all()
     assert logs, "Webhook trigger should create trigger log"
 
@@ -602,16 +605,18 @@
     )
 
     # Mock quota to avoid rate limiting
-    from enums import quota_type
+    from services import quota_service
 
-    monkeypatch.setattr(quota_type.QuotaType.TRIGGER, "consume", lambda _tenant_id: quota_type.unlimited())
+    monkeypatch.setattr(quota_service.QuotaService, "reserve", lambda *_args, **_kwargs: quota_service.unlimited())
 
     # Execute schedule trigger
     workflow_schedule_tasks.run_schedule_trigger(plan.id)
 
     # Verify WorkflowTriggerLog was created
     db_session_with_containers.expire_all()
-    logs = db_session_with_containers.query(WorkflowTriggerLog).filter_by(app_id=app_model.id).all()
+    logs = db_session_with_containers.scalars(
+        select(WorkflowTriggerLog).where(WorkflowTriggerLog.app_id == app_model.id)
+    ).all()
     assert logs, "Schedule trigger should create WorkflowTriggerLog"
     assert logs[0].trigger_type == AppTriggerType.TRIGGER_SCHEDULE
     assert logs[0].root_node_id == schedule_node_id
@@ -786,11 +791,12 @@
     # Verify database records exist
     db_session_with_containers.expire_all()
-    plugin_triggers = (
-        db_session_with_containers.query(WorkflowPluginTrigger)
-        .filter_by(app_id=app_model.id, node_id=plugin_node_id)
-        .all()
-    )
+    plugin_triggers = db_session_with_containers.scalars(
+        select(WorkflowPluginTrigger).where(
+            WorkflowPluginTrigger.app_id == app_model.id,
+            WorkflowPluginTrigger.node_id == plugin_node_id,
+        )
+    ).all()
     assert plugin_triggers, "WorkflowPluginTrigger record should exist"
     assert plugin_triggers[0].provider_id == provider_id
     assert plugin_triggers[0].event_name == "test_event"
diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py
index 19a41b6186..a5086b4c5d 100644
--- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py
+++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py
@@ -1,12 +1,14 @@
 from textwrap import dedent
 
+from flask import Flask
+
 from .test_utils import CodeExecutorTestMixin
 
 
 class TestJavaScriptCodeExecutor(CodeExecutorTestMixin):
     """Test class for JavaScript code executor functionality."""
 
-    def test_javascript_plain(self, flask_app_with_containers):
+    def test_javascript_plain(self, flask_app_with_containers: Flask):
         """Test basic JavaScript code execution with console.log output"""
         CodeExecutor, CodeLanguage = self.code_executor_imports
 
@@ -14,7 +16,7 @@
         result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
         assert result_message == "Hello World\n"
 
-    def test_javascript_json(self, flask_app_with_containers):
+    def test_javascript_json(self, flask_app_with_containers: Flask):
         """Test JavaScript code execution with JSON output"""
         CodeExecutor, CodeLanguage = self.code_executor_imports
 
@@ -25,7 +27,7 @@
         result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code)
         assert result == '{"Hello":"World"}\n'
 
-    def test_javascript_with_code_template(self, flask_app_with_containers):
+    def test_javascript_with_code_template(self, flask_app_with_containers: Flask):
         """Test JavaScript workflow code template execution with inputs"""
         CodeExecutor, CodeLanguage = self.code_executor_imports
         JavascriptCodeProvider, _ = self.javascript_imports
@@ -37,7 +39,7 @@
         )
         assert result == {"result": "HelloWorld"}
 
-    def test_javascript_get_runner_script(self, flask_app_with_containers):
+    def test_javascript_get_runner_script(self, flask_app_with_containers: Flask):
         """Test JavaScript template transformer runner script generation"""
         _, NodeJsTemplateTransformer = self.javascript_imports
 
diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
index ddb079f00c..8b4c3c3d4a 100644
--- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
+++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py
@@ -1,12 +1,14 @@
 import base64
 
+from flask import Flask
+
 from .test_utils import CodeExecutorTestMixin
 
 
 class TestJinja2CodeExecutor(CodeExecutorTestMixin):
     """Test class for Jinja2 code executor functionality."""
 
-    def test_jinja2(self, flask_app_with_containers):
+    def test_jinja2(self, flask_app_with_containers: Flask):
         """Test basic Jinja2 template execution with variable substitution"""
         CodeExecutor, CodeLanguage = self.code_executor_imports
         _, Jinja2TemplateTransformer = self.jinja2_imports
@@ -25,7 +27,7 @@
         )
         assert result == "<>Hello World<>\n"
 
-    def test_jinja2_with_code_template(self, flask_app_with_containers):
+    def test_jinja2_with_code_template(self, flask_app_with_containers: Flask):
         """Test Jinja2 workflow code template execution with inputs"""
         CodeExecutor, CodeLanguage = self.code_executor_imports
 
@@ -34,7 +36,7 @@
         )
         assert result == {"result": "Hello World"}
 
-    def test_jinja2_get_runner_script(self, flask_app_with_containers):
+    def test_jinja2_get_runner_script(self, flask_app_with_containers: Flask):
         """Test Jinja2 template transformer runner script generation"""
         _, Jinja2TemplateTransformer = self.jinja2_imports
 
@@ -43,7 +45,7 @@
         assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1
         assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2
 
-    def test_jinja2_template_with_special_characters(self, flask_app_with_containers):
+    def test_jinja2_template_with_special_characters(self, flask_app_with_containers: Flask):
         """
         Test that templates with special characters (quotes, newlines) render correctly.
         This is a regression test for issue #26818 where textarea pre-fill values
diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py
index 6d93df2472..0de41e1312 100644
--- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py
+++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py
@@ -1,12 +1,14 @@
 from textwrap import dedent
 
+from flask import Flask
+
 from .test_utils import CodeExecutorTestMixin
 
 
 class TestPython3CodeExecutor(CodeExecutorTestMixin):
     """Test class for Python3 code executor functionality."""
 
-    def test_python3_plain(self, flask_app_with_containers):
+    def test_python3_plain(self, flask_app_with_containers: Flask):
         """Test basic Python3 code execution with print output"""
         CodeExecutor, CodeLanguage = self.code_executor_imports
 
@@ -14,7 +16,7 @@
         result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
         assert result == "Hello World\n"
 
-    def test_python3_json(self, flask_app_with_containers):
+    def test_python3_json(self, flask_app_with_containers: Flask):
         """Test Python3 code execution with JSON output"""
         CodeExecutor, CodeLanguage = self.code_executor_imports
 
@@ -25,7 +27,7 @@
         result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code)
         assert result == '{"Hello": "World"}\n'
 
-    def test_python3_with_code_template(self, flask_app_with_containers):
+    def test_python3_with_code_template(self, flask_app_with_containers: Flask):
         """Test Python3 workflow code template execution with inputs"""
         CodeExecutor, CodeLanguage = self.code_executor_imports
         Python3CodeProvider, _ = self.python3_imports
@@ -37,7 +39,7 @@
         )
         assert result == {"result": "HelloWorld"}
 
-    def test_python3_get_runner_script(self, flask_app_with_containers):
+    def test_python3_get_runner_script(self, flask_app_with_containers: Flask):
         """Test Python3 template transformer runner script generation"""
         _, Python3TemplateTransformer = self.python3_imports
 
diff --git a/api/tests/unit_tests/commands/test_generate_swagger_specs.py b/api/tests/unit_tests/commands/test_generate_swagger_specs.py
new file mode 100644
index 0000000000..e77e875081
--- /dev/null
+++ b/api/tests/unit_tests/commands/test_generate_swagger_specs.py
@@ -0,0 +1,37 @@
+"""Unit tests for the standalone Swagger export helper."""
+
+import importlib.util
+import json
+import sys
+from pathlib import Path
+
+
+def _load_generate_swagger_specs_module():
+    api_dir = Path(__file__).resolve().parents[3]
+    script_path = api_dir / "dev" / "generate_swagger_specs.py"
+
+    spec = importlib.util.spec_from_file_location("generate_swagger_specs", script_path)
+    assert spec
+    assert spec.loader
+
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[spec.name] = module
+    spec.loader.exec_module(module)  # type: ignore[attr-defined]
+    return module
+
+
+def test_generate_specs_writes_console_web_and_service_swagger_files(tmp_path):
+    module = _load_generate_swagger_specs_module()
+
+    written_paths = module.generate_specs(tmp_path)
+
+    assert [path.name for path in written_paths] == [
+        "console-swagger.json",
+        "web-swagger.json",
+        "service-swagger.json",
+    ]
+
+    for path in written_paths:
+        payload = json.loads(path.read_text(encoding="utf-8"))
+        assert payload["swagger"] == "2.0"
+        assert "paths" in payload
diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py
index d6933e2180..bad246a4bb 100644
--- a/api/tests/unit_tests/configs/test_dify_config.py
+++ b/api/tests/unit_tests/configs/test_dify_config.py
@@ -145,7 +145,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
 
 def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
-    """Test that DB_EXTRAS options are properly merged with default timezone setting"""
+    """Test that DB_EXTRAS options are merged with the default timezone startup option."""
     # Set environment variables
     monkeypatch.setenv("DB_TYPE", "postgresql")
     monkeypatch.setenv("DB_USERNAME", "postgres")
@@ -158,15 +158,28 @@
     # Create config
     config = DifyConfig()
 
-    # Get engine options
-    engine_options = config.SQLALCHEMY_ENGINE_OPTIONS
-
-    # Verify options contains both search_path and timezone
-    options = engine_options["connect_args"]["options"]
+    options = config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"]["options"]
     assert "search_path=myschema" in options
     assert "timezone=UTC" in options
 
+
+def test_db_session_timezone_override_can_disable_app_level_timezone_injection(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("DB_TYPE", "postgresql")
+    monkeypatch.setenv("DB_USERNAME", "postgres")
+    monkeypatch.setenv("DB_PASSWORD", "postgres")
+    monkeypatch.setenv("DB_HOST", "localhost")
+    monkeypatch.setenv("DB_PORT", "5432")
+    monkeypatch.setenv("DB_DATABASE", "dify")
+    monkeypatch.setenv("DB_EXTRAS", "options=-c search_path=myschema")
+    monkeypatch.setenv("DB_SESSION_TIMEZONE_OVERRIDE", "")
+
+    config = DifyConfig()
+
+    assert config.SQLALCHEMY_ENGINE_OPTIONS["connect_args"] == {
+        "options": "-c search_path=myschema",
+    }
+
+
 def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch):
     os.environ.clear()
 
@@ -223,6 +236,41 @@ def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.
         _ = DifyConfig().normalized_pubsub_redis_url
 
+
+def test_dify_config_exposes_redis_key_prefix_default(monkeypatch: pytest.MonkeyPatch):
+    os.environ.clear()
+
+    monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
+    monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_TYPE", "postgresql")
+    monkeypatch.setenv("DB_USERNAME", "postgres")
+    monkeypatch.setenv("DB_PASSWORD", "postgres")
+    monkeypatch.setenv("DB_HOST", "localhost")
+    monkeypatch.setenv("DB_PORT", "5432")
+    monkeypatch.setenv("DB_DATABASE", "dify")
+
+    config = DifyConfig(_env_file=None)
+
+    assert config.REDIS_KEY_PREFIX == ""
+
+
+def test_dify_config_reads_redis_key_prefix_from_env(monkeypatch: pytest.MonkeyPatch):
+    os.environ.clear()
+
+    monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
+    monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
+    monkeypatch.setenv("DB_TYPE", "postgresql")
+    monkeypatch.setenv("DB_USERNAME", "postgres")
+    monkeypatch.setenv("DB_PASSWORD", "postgres")
+    monkeypatch.setenv("DB_HOST", "localhost")
+    monkeypatch.setenv("DB_PORT", "5432")
+    monkeypatch.setenv("DB_DATABASE", "dify")
+    monkeypatch.setenv("REDIS_KEY_PREFIX", "enterprise-a")
+
+    config = DifyConfig(_env_file=None)
+
+    assert config.REDIS_KEY_PREFIX == "enterprise-a"
+
+
 @pytest.mark.parametrize(
     ("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
     [
diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py
index 55873b06a8..7174530e97 100644
--- a/api/tests/unit_tests/conftest.py
+++ b/api/tests/unit_tests/conftest.py
@@ -121,33 +121,32 @@ def _configure_session_factory(_unit_test_engine):
     configure_session_factory(_unit_test_engine, expire_on_commit=False)
 
 
-def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account):
+def setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_owner):
     """
-    Helper to set up the mock DB execute chain for tenant/account authentication.
+    Helper to stub the tenant-owner execute result for service API app authentication.
 
-    This configures the mock to return (tenant, account) for the
-    db.session.execute(select(...).join().join().where()).one_or_none()
-    query used by validate_app_token decorator.
+    The validate_app_token decorator currently resolves the active tenant owner
+    via db.session.execute(select(Tenant, Account)...).one_or_none().
 
     Args:
         mock_db: The mocked db object
        mock_tenant: Mock tenant object to return
-        mock_account: Mock account object to return
+        mock_owner: Mock owner object to return from the execute result
     """
-    mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account)
+    mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_owner)
 
 
-def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta):
+def setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_tenant_account_join):
     """
-    Helper to set up the mock DB execute chain for dataset tenant authentication.
+    Helper to stub the tenant-owner execute result for dataset token authentication.
 
-    This configures the mock to return (tenant, tenant_account) for the
-    db.session.execute(select(...).where().where().where().where()).one_or_none()
-    query used by validate_dataset_token decorator.
+    The validate_dataset_token decorator currently resolves the owner mapping via
+    db.session.execute(select(Tenant, TenantAccountJoin)...).one_or_none(), and
+    then loads the Account separately via db.session.get(...).
 
     Args:
         mock_db: The mocked db object
         mock_tenant: Mock tenant object to return
-        mock_ta: Mock tenant account object to return
+        mock_tenant_account_join: Mock tenant-account join object to return
     """
-    mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
+    mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_tenant_account_join)
diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py
index 9f1ff9b40f..bfa4048191 100644
--- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py
+++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py
@@ -208,8 +208,6 @@ class TestAnnotationImportServiceValidation:
 
         file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
 
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
-
         with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
             mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
 
@@ -230,8 +228,6 @@
 
         file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
 
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
-
         with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
             mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
 
@@ -248,8 +244,6 @@
         csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
         file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
 
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
-
         with (
             patch("services.annotation_service.current_account_with_tenant") as mock_auth,
            patch("services.annotation_service.pd.read_csv", side_effect=ParserError("malformed CSV")),
@@ -269,8 +263,6 @@
 
         file = FileStorage(stream=io.BytesIO(csv_content.encode()), filename="test.csv", content_type="text/csv")
 
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_app
-
         with patch("services.annotation_service.current_account_with_tenant") as mock_auth:
             mock_auth.return_value = (MagicMock(id="user_id"), "tenant_id")
 
diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py
new file mode 100644
index 0000000000..9c4678aed3
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py
@@ -0,0 +1,139 @@
+"""Unit tests for console app import endpoints."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
+from controllers.console.app import app_import as app_import_module
+from services.app_dsl_service import ImportStatus
+
+
+def _unwrap(func):
+    bound_self = getattr(func, "__self__", None)
+    while hasattr(func, "__wrapped__"):
+        func = func.__wrapped__
+    if bound_self is not None:
+        return func.__get__(bound_self, bound_self.__class__)
+    return func
+
+
+class _Result:
+    def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
+        self.status = status
+        self.app_id = app_id
+
+    def model_dump(self, mode: str = "json"):
+        return {"status": self.status, "app_id": self.app_id}
+
+
+def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
+    features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
+    monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
+
+
+def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+    fake_session = MagicMock()
+    fake_session.__enter__.return_value = fake_session
+    fake_session.__exit__.return_value = None
+    monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
+    monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
+    return fake_session
+
+
+class TestAppImportApi:
+    @pytest.fixture
+    def api(self):
+        return app_import_module.AppImportApi()
+
+    def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
+        method = _unwrap(api.post)
+
+        _install_features(monkeypatch, enabled=False)
+        session = _mock_session(monkeypatch)
+        monkeypatch.setattr(
+            app_import_module.AppDslService,
+            "import_app",
+            lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
+        )
+        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
+
+        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
+            response, status = method()
+
+        session.rollback.assert_called_once_with()
+        session.commit.assert_not_called()
+        assert status == 400
+        assert response["status"] == ImportStatus.FAILED
+
+    def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
+        method = _unwrap(api.post)
+
+        _install_features(monkeypatch, enabled=False)
+        session = _mock_session(monkeypatch)
+        monkeypatch.setattr(
+            app_import_module.AppDslService,
+            "import_app",
+            lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
+        )
+        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
+
+        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
+            response, status = method()
+
+        session.commit.assert_called_once_with()
+        session.rollback.assert_not_called()
+        assert status == 202
+        assert response["status"] == ImportStatus.PENDING
+
+    def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
+        method = _unwrap(api.post)
+
+        _install_features(monkeypatch, enabled=True)
+        session = _mock_session(monkeypatch)
+        monkeypatch.setattr(
+            app_import_module.AppDslService,
+            "import_app",
+            lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
+        )
+        update_access = MagicMock()
+        monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
+        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
+
+        with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
+            response, status = method()
+
+        session.commit.assert_called_once_with()
+        session.rollback.assert_not_called()
+        update_access.assert_called_once_with("app-123", "private")
+        assert status == 200
+        assert response["status"] == ImportStatus.COMPLETED
+
+
+class TestAppImportConfirmApi:
+    @pytest.fixture
+    def api(self):
+        return app_import_module.AppImportConfirmApi()
+
+    def test_import_confirm_returns_failed_status_and_rolls_back(
+        self, api, app, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        method = _unwrap(api.post)
+
+        session = _mock_session(monkeypatch)
+        monkeypatch.setattr(
+            app_import_module.AppDslService,
+            "confirm_import",
+            lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
+        )
+        monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
+
+        with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
+            response, status = method(import_id="import-1")
+
+        session.rollback.assert_called_once_with()
+        session.commit.assert_not_called()
+        assert status == 400
+        assert response["status"] == ImportStatus.FAILED
diff --git a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py
index 2ac3dc037d..80e7c41a9e 100644
--- a/api/tests/unit_tests/controllers/console/app/test_app_response_models.py
+++ b/api/tests/unit_tests/controllers/console/app/test_app_response_models.py
@@ -10,6 +10,8 @@ from typing import Any
 
 import pytest
 from flask.views import MethodView
+from pydantic import ValidationError
+from werkzeug.datastructures import MultiDict
 
 # kombu references MethodView as a global when importing celery/kombu pools.
 if not hasattr(builtins, "MethodView"):
@@ -138,12 +140,15 @@ def app_models(app_module):
 def patch_signed_url(monkeypatch, app_module):
     """Ensure icon URL generation uses a deterministic helper for tests."""
 
-    def _fake_signed_url(key: str | None) -> str | None:
-        if not key:
+    def _fake_build_icon_url(_icon_type, key: str | None) -> str | None:
+        if key is None:
+            return None
+        icon_type = str(_icon_type).lower()
+        if icon_type != "image":
             return None
         return f"signed:{key}"
 
-    monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
+    monkeypatch.setattr(app_module, "build_icon_url", _fake_build_icon_url)
 
 
 def _ts(hour: int = 12) -> datetime:
@@ -171,6 +176,101 @@ def _dummy_workflow():
     )
 
 
+def test_app_list_query_normalizes_orpc_bracket_tag_ids(app_module):
+    first_tag_id = "8c4ef3d1-58a1-4d94-8a1c-1c171d889e08"
+    second_tag_id = "3c39395b-6d1f-4030-8b17-eaa7cc85221c"
+    query_args = MultiDict(
+        [
+            ("page", "1"),
+            ("limit", "30"),
+            ("tag_ids[1]", second_tag_id),
+            ("tag_ids[0]", first_tag_id),
+        ]
+    )
+
+    normalized = app_module._normalize_app_list_query_args(query_args)
+    query = app_module.AppListQuery.model_validate(normalized)
+
+    assert query.tag_ids == [first_tag_id, second_tag_id]
+
+
+def test_app_list_query_preserves_regular_query_params(app_module):
+    query_args = MultiDict(
+        [
+            ("page", "2"),
+            ("limit", "50"),
+            ("mode", "chat"),
+            ("name", "Sales Copilot"),
+            ("is_created_by_me", "true"),
+        ]
+    )
+
+    normalized = app_module._normalize_app_list_query_args(query_args)
+    query = app_module.AppListQuery.model_validate(normalized)
+
+    assert normalized == {
+        "page": "2",
+        "limit": "50",
+        "mode": "chat",
+        "name": "Sales Copilot",
+        "is_created_by_me": "true",
+    }
+    assert query.page == 2
+    assert query.limit == 50
+    assert query.mode == "chat"
+    assert query.name == "Sales Copilot"
+    assert query.is_created_by_me is True
+    assert query.tag_ids is None
+
+
+def test_app_list_query_normalizes_empty_bracket_tag_ids_to_none(app_module):
+    query_args = MultiDict(
+        [
+            ("tag_ids[0]", ""),
+            ("tag_ids[1]", " "),
+        ]
+    )
+
+    normalized = app_module._normalize_app_list_query_args(query_args)
+    query = app_module.AppListQuery.model_validate(normalized)
+
+    assert normalized == {"tag_ids": ["", " "]}
+    assert query.tag_ids is None
+
+
+def test_app_list_query_rejects_invalid_bracket_tag_id(app_module):
+    normalized = app_module._normalize_app_list_query_args(MultiDict([("tag_ids[0]", "not-a-uuid")]))
+
+    with pytest.raises(ValidationError):
+        app_module.AppListQuery.model_validate(normalized)
+
+
+def test_app_list_query_sorts_bracket_tag_ids_by_index(app_module):
+    first_tag_id = "8c4ef3d1-58a1-4d94-8a1c-1c171d889e08"
+    second_tag_id = "3c39395b-6d1f-4030-8b17-eaa7cc85221c"
+    third_tag_id = "9d5ec0f7-4f2b-4e7f-9c13-1e7a034d0eb1"
+    query_args = MultiDict(
+        [
+            ("tag_ids[2]", third_tag_id),
+            ("tag_ids[1]", second_tag_id),
+            ("tag_ids[0]", first_tag_id),
+        ]
+    )
+
+    normalized = app_module._normalize_app_list_query_args(query_args)
+    query = app_module.AppListQuery.model_validate(normalized)
+
+    assert query.tag_ids == [first_tag_id, second_tag_id, third_tag_id]
+
+
+def test_app_list_query_rejects_flat_tag_ids(app_module):
+    tag_id = "8c4ef3d1-58a1-4d94-8a1c-1c171d889e08"
+    normalized = app_module._normalize_app_list_query_args(MultiDict([("tag_ids", tag_id)]))
+
+    with pytest.raises(ValidationError):
+        app_module.AppListQuery.model_validate(normalized)
+
+
 def test_app_partial_serialization_uses_aliases(app_models):
     AppPartial = app_models.AppPartial
     created_at = _ts()
diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py
index c52bc02420..2d218dac7e 100644
--- a/api/tests/unit_tests/controllers/console/app/test_audio.py
+++ b/api/tests/unit_tests/controllers/console/app/test_audio.py
@@ -4,7 +4,6 @@ import io
 from types import SimpleNamespace
 
 import pytest
-from graphon.model_runtime.errors.invoke import InvokeError
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import InternalServerError
 
@@ -21,6 +20,7 @@ from controllers.console.app.error import (
     UnsupportedAudioTypeError,
 )
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
+from graphon.model_runtime.errors.invoke import InvokeError
 from services.audio_service import AudioService
 from services.errors.app_model_config import AppModelConfigBrokenError
 from services.errors.audio import (
diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py
index 11b3b3470d..24b7e39f73 100644
--- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py
+++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py
@@ -33,12 +33,17 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch:
     monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
 
     paginate_result = MagicMock()
+    paginate_result.page = 1
+    paginate_result.per_page = 20
+    paginate_result.total = 0
+    paginate_result.has_next = False
+    paginate_result.items = []
     monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
 
     with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
         response = method(app_model=SimpleNamespace(id="app-1"))
 
-    assert response is paginate_result
+    assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
 
 
 def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -71,12 +76,17 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
     monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
 
     paginate_result = MagicMock()
+    paginate_result.page = 1
+    paginate_result.per_page = 20
+    paginate_result.total = 0
+    paginate_result.has_next = False
+    paginate_result.items = []
     monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
 
     with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
         response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
 
-    assert response is paginate_result
+    assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
 
 
 def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py b/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py
deleted file mode 100644
index f588ab261d..0000000000
--- a/api/tests/unit_tests/controllers/console/app/test_conversation_read_timestamp.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from datetime import datetime
-from types import SimpleNamespace
-from unittest.mock import MagicMock, patch
-
-from controllers.console.app.conversation import _get_conversation
-
-
-def test_get_conversation_mark_read_keeps_updated_at_unchanged():
-    app_model = SimpleNamespace(id="app-id")
-    account = SimpleNamespace(id="account-id")
-    conversation = MagicMock()
-    conversation.id = "conversation-id"
-
-    with (
-        patch(
-            "controllers.console.app.conversation.current_account_with_tenant",
-            return_value=(account, None),
-            autospec=True,
-        ),
-        patch(
-            "controllers.console.app.conversation.naive_utc_now",
-            return_value=datetime(2026, 2, 9, 0, 0, 0),
-            autospec=True,
-        ),
-        patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
-    ):
-        mock_session.scalar.return_value = conversation
-
-        _get_conversation(app_model, "conversation-id")
-
-        statement = mock_session.execute.call_args[0][0]
-        compiled = statement.compile()
-        sql_text = str(compiled).lower()
-        compact_sql_text = sql_text.replace(" ", "")
-        params = compiled.params
-
-        assert "updated_at=current_timestamp" not in compact_sql_text
-        assert "updated_at=conversations.updated_at" in compact_sql_text
-        assert "read_at=:read_at" in compact_sql_text
-        assert "read_account_id=:read_account_id" in compact_sql_text
-        assert params["read_at"] == datetime(2026, 2, 9, 0, 0, 0)
-        assert params["read_account_id"] == "account-id"
diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py
new file mode 100644
index 0000000000..1a412aff29
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+from contextlib import nullcontext
+from datetime import UTC, datetime
+from types import SimpleNamespace
+
+import pytest
+from pydantic import ValidationError
+
+from controllers.console.app import conversation_variables as conversation_variables_module
+from graphon.variables.types import SegmentType
+
+
+def _unwrap(func):
+    bound_self = getattr(func, "__self__", None)
+    while hasattr(func, "__wrapped__"):
+        func = func.__wrapped__
+    if bound_self is not None:
+        return func.__get__(bound_self, bound_self.__class__)
+    return func
+
+
+def test_get_conversation_variables_returns_paginated_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
+    api = conversation_variables_module.ConversationVariablesApi()
+    method = _unwrap(api.get)
+
+    created_at = datetime(2026, 1, 1, tzinfo=UTC)
+    updated_at = datetime(2026, 1, 2, tzinfo=UTC)
+    row = SimpleNamespace(
+        created_at=created_at,
+        updated_at=updated_at,
+        to_variable=lambda: SimpleNamespace(
+            model_dump=lambda: {
+                "id": "var-1",
+                "name": "my_var",
+                "value_type": "string",
+                "value": "value",
+                "description": "desc",
+            }
+        ),
+    )
+    session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
+    monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
+    monkeypatch.setattr(
+        conversation_variables_module,
+        "sessionmaker",
+        lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
+    )
+
+    with app.test_request_context(
+        "/console/api/apps/app-1/conversation-variables",
+        method="GET",
+        query_string={"conversation_id": "conv-1"},
+    ):
+        response = method(app_model=SimpleNamespace(id="app-1"))
+
+    assert response["page"] == 1
+    assert response["limit"] == 100
+    assert response["total"] == 1
+    assert response["has_more"] is False
+    assert response["data"][0]["id"] == "var-1"
+    assert response["data"][0]["created_at"] == int(created_at.timestamp())
+    assert response["data"][0]["updated_at"] == int(updated_at.timestamp())
+
+
+def test_get_conversation_variables_normalizes_value_type_and_value(app, monkeypatch: pytest.MonkeyPatch) -> None:
+    api = conversation_variables_module.ConversationVariablesApi()
+    method = _unwrap(api.get)
+
+    row = SimpleNamespace(
+        created_at=None,
+        updated_at=None,
+        to_variable=lambda: SimpleNamespace(
+            model_dump=lambda: {
+                "id": "var-2",
+                "name": "my_var_2",
+                "value_type": SegmentType.INTEGER,
+                "value": 42,
+                "description": None,
+            }
+        ),
+    )
+    session = SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [row]))
+    monkeypatch.setattr(conversation_variables_module, "db", SimpleNamespace(engine=object()))
+    monkeypatch.setattr(
+        conversation_variables_module,
+        "sessionmaker",
+        lambda *_args, **_kwargs: SimpleNamespace(begin=lambda: nullcontext(session)),
+    )
+
+    with app.test_request_context(
+        "/console/api/apps/app-1/conversation-variables",
+        method="GET",
+        query_string={"conversation_id": "conv-1"},
+    ):
+        response = method(app_model=SimpleNamespace(id="app-1"))
+
+    assert response["data"][0]["value_type"] == "number"
+    assert response["data"][0]["value"] == "42"
+
+
+def test_get_conversation_variables_requires_conversation_id(app) -> None:
+    api = conversation_variables_module.ConversationVariablesApi()
+    method = _unwrap(api.get)
+
+    with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"):
+        with pytest.raises(ValidationError):
+            method(app_model=SimpleNamespace(id="app-1"))
diff --git a/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py
new file mode 100644
index 0000000000..1af15d8dc6
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/app/test_mcp_server_response.py
@@ -0,0 +1,138 @@
+import datetime
+from types import SimpleNamespace
+from unittest.mock import PropertyMock, patch
+
+from flask import Flask
+
+from controllers.console import console_ns
+from controllers.console.app.mcp_server import AppMCPServerController, AppMCPServerResponse
+
+
+def unwrap(func):
+    while hasattr(func, "__wrapped__"):
+        func = func.__wrapped__
+    return func
+
+
+class _ValidatedResponse:
+    def __init__(self, payload):
+        self._payload = payload
+
+    def model_dump(self, mode="json"):
+        return self._payload
+
+
+class TestAppMCPServerResponse:
+    def test_parameters_json_string_parsed(self):
+        data = {
+            "id": "s1",
+            "name": "test",
+            "server_code": "code",
+            "description": "desc",
+            "status": "active",
+            "parameters": '{"key": "value"}',
+        }
+        resp = AppMCPServerResponse.model_validate(data)
+        assert resp.parameters == {"key": "value"}
+
+    def test_parameters_invalid_json_returns_original(self):
+        data = {
+            "id": "s1",
+            "name": "test",
+            "server_code": "code",
+            "description": "desc",
+            "status": "active",
+            "parameters": "not-valid-json",
+        }
+        resp = AppMCPServerResponse.model_validate(data)
+        assert resp.parameters == "not-valid-json"
+
+    def test_parameters_dict_passthrough(self):
+        data = {
+            "id": "s1",
+            "name": "test",
+            "server_code": "code",
+            "description": "desc",
+            "status": "active",
+            "parameters": {"already": "parsed"},
+        }
+        resp = AppMCPServerResponse.model_validate(data)
+        assert resp.parameters == {"already": "parsed"}
+
+    def test_parameters_json_array_parsed(self):
+        data = {
+            "id": "s1",
+            "name": "test",
+            "server_code": "code",
+            "description": "desc",
+            "status": "active",
+            "parameters": '["a", "b"]',
+        }
+        resp = AppMCPServerResponse.model_validate(data)
+        assert resp.parameters == ["a", "b"]
+
+    def test_timestamps_normalized(self):
+        dt = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC)
+        data = {
+            "id": "s1",
+            "name": "test",
+            "server_code": "code",
+            "description": "desc",
+            "status": "active",
+            "parameters": {},
+            "created_at": dt,
+            "updated_at": dt,
+        }
+        resp = AppMCPServerResponse.model_validate(data)
+        assert resp.created_at == int(dt.timestamp())
+        assert resp.updated_at == int(dt.timestamp())
+
+    def test_timestamps_none(self):
+        data = {
+            "id": "s1",
+            "name": "test",
+            "server_code": "code",
+            "description": "desc",
+            "status": "active",
+            "parameters": {},
+        }
+        resp = AppMCPServerResponse.model_validate(data)
+        assert resp.created_at is None
+        assert resp.updated_at is None
+
+
+class TestAppMCPServerController:
+    def test_get_returns_empty_dict_when_server_missing(self):
+        api = AppMCPServerController()
+        method = unwrap(api.get)
+
+        with patch("controllers.console.app.mcp_server.db.session.scalar", return_value=None):
+            response = method(api, app_model=SimpleNamespace(id="app-1"))
+
+        assert response == {}
+
+    def test_post_returns_201(self):
+        api = AppMCPServerController()
+        method = unwrap(api.post)
+        payload = {"parameters": {"timeout": 30}}
+        app = Flask(__name__)
+        app.config["TESTING"] = True
+
+        with (
+            app.test_request_context("/", json=payload),
+            patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
+            patch("controllers.console.app.mcp_server.current_account_with_tenant", return_value=(None, "tenant-1")),
+            patch("controllers.console.app.mcp_server.db.session.add"),
+            patch("controllers.console.app.mcp_server.db.session.commit"),
+            patch("controllers.console.app.mcp_server.AppMCPServer.generate_server_code", return_value="server-code"),
+            patch(
+                "controllers.console.app.mcp_server.AppMCPServerResponse.model_validate",
+                return_value=_ValidatedResponse({"id": "server-1"}),
+            ),
+        ):
+            response, status_code = method(
+                api, app_model=SimpleNamespace(id="app-1", name="Demo App", description="App description")
+            )
+
+        assert response == {"id": "server-1"}
+        assert status_code == 201
diff --git a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py
index a76e958829..c984dbef5d 100644
--- a/api/tests/unit_tests/controllers/console/app/test_message_api.py
+++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from datetime import UTC, datetime
+
 import pytest
 
 from controllers.console.app import message as message_module
@@ -120,3 +122,24 @@ def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> N
     response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
     assert len(response.data) == 2
     assert response.data[0] == "What is AI?"
+
+
+def test_message_detail_response_normalizes_aliases_and_timestamp(app, monkeypatch: pytest.MonkeyPatch) -> None:
+    """Test MessageDetailResponse normalizes alias fields and datetime timestamps."""
+    created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+    response = message_module.MessageDetailResponse.model_validate(
+        {
+            "id": "550e8400-e29b-41d4-a716-446655440000",
+            "conversation_id": "550e8400-e29b-41d4-a716-446655440001",
+            "inputs": {"foo": "bar"},
+            "query": "hello",
+            "re_sign_file_url_answer": "world",
+            "from_source": "user",
+            "status": "normal",
+            "created_at": created_at,
+            "message_metadata_dict": {"token_usage": 3},
+        }
+    )
+    assert response.answer == "world"
+    assert response.metadata == {"token_usage": 3}
+    assert response.created_at == int(created_at.timestamp())
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py
index 3607636880..7c470eb9a8 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
+import json
 from datetime import datetime
 from types import SimpleNamespace
 from unittest.mock import Mock
 
 import pytest
-from graphon.file import File, FileTransferMethod, FileType
 from werkzeug.exceptions import HTTPException, NotFound
 
 from controllers.console.app import workflow as workflow_module
 from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
+from graphon.file import File, FileTransferMethod, FileType
 
 
 def _unwrap(func):
@@ -30,7 +31,7 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
     file_list = [
         File(
             tenant_id="t1",
-            type=FileType.IMAGE,
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.REMOTE_URL,
             remote_url="http://u",
         )
@@ -258,6 +259,63 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
     assert exc.value.description == "invalid workflow graph"
 
 
+def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None:
+    api = workflow_module.PublishedAllWorkflowApi()
+    handler = _unwrap(api.get)
+
+    session_state = {"open": False}
+
+    class _SessionContext:
+        def __enter__(self):
+            session_state["open"] = True
+            return object()
+
+        def __exit__(self, exc_type, exc, tb):
+            session_state["open"] = False
+            return False
+
+    class _SessionMaker:
+        def begin(self):
+            return _SessionContext()
+
+    class _Workflow:
+        @property
+        def id(self):
+            assert session_state["open"] is True
+            return "w1"
+
+    monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
+    monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
+    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
+    monkeypatch.setattr(
+        workflow_module,
+        "WorkflowService",
+        lambda: SimpleNamespace(
+            get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False),
+        ),
+    )
+
+    def _fake_marshal(items, fields):
+        assert session_state["open"] is True
+        return [{"id": item.id} for item in items]
+
+    monkeypatch.setattr(workflow_module, "marshal", _fake_marshal)
+
+    with app.test_request_context(
+        "/apps/app/workflows",
+        method="GET",
+        query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
+    ):
+        response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
+
+    assert response == {
+        "items": [{"id": "w1"}],
+        "page": 1,
+        "limit": 10,
+        "has_more": False,
+    }
+
+
 def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setattr(
         workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
@@ -290,3 +348,120 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk
     ):
         with pytest.raises(NotFound):
             handler(api, app_model=SimpleNamespace(id="app"))
+
+
+def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: pytest.MonkeyPatch) -> None:
+    app_id_1 = "11111111-1111-1111-1111-111111111111"
+    app_id_2 = "22222222-2222-2222-2222-222222222222"
+    signed_avatar_url = "https://files.example.com/signed/avatar-1"
+    sign_avatar = Mock(return_value=signed_avatar_url)
+    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
+    monkeypatch.setattr(
+        workflow_module,
+        "WorkflowService",
+        lambda: SimpleNamespace(get_accessible_app_ids=lambda app_ids, tenant_id: {app_id_1}),
+    )
+    monkeypatch.setattr(workflow_module.file_helpers, "get_signed_file_url", sign_avatar)
+
+    redis_pipeline = Mock()
+    redis_pipeline.execute.return_value = [
+        {
+            b"sid-1": json.dumps(
+                {
+                    "user_id": "u-1",
+                    "username": "Alice",
+                    "avatar": "avatar-file-id",
+                    "sid": "sid-1",
+                }
+            )
+        }
+    ]
+    workflow_module.redis_client.pipeline.return_value = redis_pipeline
+
+    api = workflow_module.WorkflowOnlineUsersApi()
+    handler = _unwrap(api.post)
+
+    with app.test_request_context(
+        "/apps/workflows/online-users",
+        method="POST",
+        json={"app_ids": [app_id_1, app_id_2]},
+    ):
+        response = handler(api)
+
+    assert response == {
+        "data": [
+            {
+                "app_id": app_id_1,
+                "users": [
+                    {
+                        "user_id": "u-1",
+                        "username": "Alice",
+                        "avatar": signed_avatar_url,
+                        "sid": "sid-1",
+                    }
+                ],
+            }
+        ]
+    }
+    workflow_module.redis_client.pipeline.assert_called_once_with(transaction=False)
+    redis_pipeline.hgetall.assert_called_once_with(f"{workflow_module.WORKFLOW_ONLINE_USERS_PREFIX}{app_id_1}")
+    redis_pipeline.execute.assert_called_once_with()
+    sign_avatar.assert_called_once_with("avatar-file-id")
+
+
+def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.MonkeyPatch) -> None:
+    app_ids = [f"wf-{index}" for index in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE + 1)]
+    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
+    monkeypatch.setattr(
+        workflow_module,
+        "WorkflowService",
+        lambda: SimpleNamespace(get_accessible_app_ids=lambda app_ids, tenant_id: set(app_ids)),
+    )
+
+    first_pipeline = Mock()
+    first_pipeline.execute.return_value = [{} for _ in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE)]
+    second_pipeline = Mock()
+    second_pipeline.execute.return_value = [{}]
+    workflow_module.redis_client.pipeline.side_effect = [first_pipeline, second_pipeline]
+
+    api = workflow_module.WorkflowOnlineUsersApi()
+    handler = _unwrap(api.post)
+
+    with app.test_request_context(
+        "/apps/workflows/online-users",
+        method="POST",
+        json={"app_ids": app_ids},
+    ):
+        response = handler(api)
+
+    assert len(response["data"]) == len(app_ids)
+    assert workflow_module.redis_client.pipeline.call_count == 2
+    assert first_pipeline.hgetall.call_count == workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE
+    assert second_pipeline.hgetall.call_count == 1
+
+
+def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
+    accessible_app_ids = Mock(return_value=set())
+    monkeypatch.setattr(
+        workflow_module,
+        "WorkflowService",
+        lambda: SimpleNamespace(get_accessible_app_ids=accessible_app_ids),
+    )
+
+    excessive_ids = [f"wf-{index}" for index in range(workflow_module.MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS + 1)]
+
+    api = workflow_module.WorkflowOnlineUsersApi()
+    handler = _unwrap(api.post)
+
+    with app.test_request_context(
+        "/apps/workflows/online-users",
+        method="POST",
+        json={"app_ids": excessive_ids},
+    ):
+        with pytest.raises(HTTPException) as exc:
+            handler(api)
+
+    assert exc.value.code == 400
+    assert "Maximum" in exc.value.description
+    accessible_app_ids.assert_not_called()
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py
new file mode 100644
index 0000000000..a9853f25b0
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow_app_log_api.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+from datetime import UTC, datetime
+
+from controllers.console.app import workflow_app_log as workflow_app_log_module
+from graphon.enums import WorkflowExecutionStatus
+
+
+def test_workflow_app_log_query_parses_bool_and_datetime():
+    query = workflow_app_log_module.WorkflowAppLogQuery.model_validate(
+        {
+            "detail": "true",
+            "created_at__before": "2026-01-02T03:04:05Z",
+            "page": "2",
+            "limit": "10",
+        }
+    )
+
+    assert query.detail is True
+    assert query.created_at__before == datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+    assert query.page == 2
+    assert query.limit == 10
+
+
+def test_workflow_app_log_pagination_response_normalizes_nested_fields():
+    created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+    response = workflow_app_log_module.WorkflowAppLogPaginationResponse.model_validate(
+        {
+            "page": 1,
+            "limit": 20,
+            "total": 1,
+            "has_more": False,
+            "data": [
+                {
+                    "id": "log-1",
+                    "workflow_run": {
+                        "id": "run-1",
+                        "status": WorkflowExecutionStatus.SUCCEEDED,
+                        "created_at": created_at,
+                        "finished_at": created_at,
+                    },
+                    "details": {"trigger_metadata": {}},
+                    "created_by_account": {"id": "acc-1", "name": "acc", "email": "acc@example.com"},
+                    "created_at": created_at,
+                }
+            ],
+        }
+    ).model_dump(mode="json")
+
+    assert response["data"][0]["workflow_run"]["status"] == "succeeded"
+    assert response["data"][0]["workflow_run"]["created_at"] == int(created_at.timestamp())
+    assert response["data"][0]["created_at"] == int(created_at.timestamp())
+
+
+def test_workflow_archived_log_pagination_response_normalizes_nested_fields():
+    created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+    response = workflow_app_log_module.WorkflowArchivedLogPaginationResponse.model_validate(
+        {
+            "page": 1,
+            "limit": 20,
+            "total": 1,
+            "has_more": False,
+            "data": [
+                {
+                    "id": "archived-1",
+                    "workflow_run": {
+                        "id": "run-1",
+                        "status": WorkflowExecutionStatus.FAILED,
+                    },
+                    "trigger_metadata": {"type": "trigger-plugin"},
+                    "created_by_end_user": {
+                        "id": "eu-1",
+                        "type": "anonymous",
+                        "is_anonymous": True,
+                        "session_id": "session-1",
+                    },
+                    "created_at": created_at,
+                }
+            ],
+        }
+    ).model_dump(mode="json")
+
+    assert response["data"][0]["workflow_run"]["status"] == "failed"
+    assert response["data"][0]["created_at"] == int(created_at.timestamp())
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py
new file mode 100644
index 0000000000..85afcf0e60
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py
@@ -0,0 +1,201 @@
+from __future__ import annotations
+
+from contextlib import nullcontext
+from dataclasses import dataclass
+from datetime import datetime
+from types import SimpleNamespace
+from unittest.mock import MagicMock, PropertyMock, patch
+
+import pytest
+from flask import Flask
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import console_ns
+from controllers.console import wraps as console_wraps
+from controllers.console.app import workflow_comment as workflow_comment_module
+from controllers.console.app import wraps as app_wraps
+from libs import login as login_lib
+from models.account import Account, AccountStatus, TenantAccountRole
+
+
+def _make_account(role: TenantAccountRole) -> Account:
+    account = Account(name="tester", email="tester@example.com")
+    account.status = AccountStatus.ACTIVE
+    account.role = role
+    account.id = "account-123"  # type: ignore[assignment]
+    account._current_tenant = SimpleNamespace(id="tenant-123")  # type: ignore[attr-defined]
+    account._get_current_object = lambda: account  # type: ignore[attr-defined]
+    return account
+
+
+def _make_app() -> SimpleNamespace:
+    return SimpleNamespace(id="app-123", tenant_id="tenant-123", status="normal", mode="workflow")
+
+
+def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None:
+    monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
+    monkeypatch.setattr(login_lib, "current_user", account)
+    monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
+    monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
+    monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
+    monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
+    monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
+    monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
+    monkeypatch.setattr(workflow_comment_module, "current_user", account)
+
+
+def _patch_write_services(monkeypatch: pytest.MonkeyPatch) -> None:
+    for method_name in (
+        "create_comment",
+        "update_comment",
+        "delete_comment",
+        "resolve_comment",
+        "validate_comment_access",
+        "create_reply",
+        "update_reply",
+        "delete_reply",
+    ):
+        monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, method_name, MagicMock())
+
+
+def _patch_payload(payload: dict[str, object] | None):
+    if payload is None:
+        return nullcontext()
+    return patch.object(
+        type(console_ns),
+        "payload",
+        new_callable=PropertyMock,
+        return_value=payload,
+    )
+
+
+@dataclass(frozen=True)
+class WriteCase:
+    resource_cls: type
+    method_name: str
+    path: str
+    kwargs: dict[str, str]
+    payload: dict[str, object] | None = None
+
+
+@pytest.mark.parametrize(
+    "case",
+    [
+        WriteCase(
+            resource_cls=workflow_comment_module.WorkflowCommentListApi,
+            method_name="post",
+            path="/console/api/apps/app-123/workflow/comments",
+            kwargs={"app_id": "app-123"},
+            payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []},
+        ),
+        WriteCase(
+            resource_cls=workflow_comment_module.WorkflowCommentDetailApi,
+            method_name="put",
+            path="/console/api/apps/app-123/workflow/comments/comment-1",
+            kwargs={"app_id": "app-123", "comment_id": "comment-1"},
+            payload={"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []},
+        ),
+        WriteCase(
+            resource_cls=workflow_comment_module.WorkflowCommentDetailApi,
+            method_name="delete",
+            path="/console/api/apps/app-123/workflow/comments/comment-1",
+            kwargs={"app_id": "app-123", "comment_id": "comment-1"},
+        ),
+        WriteCase(
+            resource_cls=workflow_comment_module.WorkflowCommentResolveApi,
+            method_name="post",
+            path="/console/api/apps/app-123/workflow/comments/comment-1/resolve",
+            kwargs={"app_id": "app-123", "comment_id": "comment-1"},
+        ),
+        WriteCase(
+            resource_cls=workflow_comment_module.WorkflowCommentReplyApi,
+            method_name="post",
+            path="/console/api/apps/app-123/workflow/comments/comment-1/replies",
+            kwargs={"app_id": "app-123", "comment_id": "comment-1"},
+            payload={"content": "reply", "mentioned_user_ids": []},
+        ),
+        WriteCase(
+            resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi,
+            method_name="put",
+            path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1",
+            kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"},
+            payload={"content": "reply", "mentioned_user_ids": []},
+        ),
+        WriteCase(
+            resource_cls=workflow_comment_module.WorkflowCommentReplyDetailApi,
+            method_name="delete",
+            path="/console/api/apps/app-123/workflow/comments/comment-1/replies/reply-1",
+            kwargs={"app_id": "app-123", "comment_id": "comment-1", "reply_id": "reply-1"},
+        ),
+    ],
+)
+def test_write_endpoints_require_edit_permission(app: Flask, monkeypatch: pytest.MonkeyPatch, case: WriteCase) -> None:
+    app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
+    account = _make_account(TenantAccountRole.NORMAL)
+    app_model = _make_app()
+    _patch_console_guards(monkeypatch, account, app_model)
+    _patch_write_services(monkeypatch)
+
+    with app.test_request_context(case.path, method=case.method_name.upper(), json=case.payload):
+        with _patch_payload(case.payload):
+            handler = getattr(case.resource_cls(), case.method_name)
+            with pytest.raises(Forbidden):
+                handler(**case.kwargs)
+
+
+def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
+    app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
+    account = _make_account(TenantAccountRole.EDITOR)
+    app_model = _make_app()
+    _patch_console_guards(monkeypatch, account, app_model)
+
+    create_comment_mock = MagicMock(return_value={"id": "comment-1"})
+    monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "create_comment", create_comment_mock)
+    payload = {"content": "hello", "position_x": 1.0, "position_y": 2.0, "mentioned_user_ids": []}
+
+    with app.test_request_context("/console/api/apps/app-123/workflow/comments", method="POST", json=payload):
+        with _patch_payload(payload):
+            result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123")
+
+    if isinstance(result, tuple):
+        response = result[0]
+    else:
+        response = result
+    assert response["id"] == "comment-1"
+    create_comment_mock.assert_called_once_with(
+        tenant_id="tenant-123",
+        app_id="app-123",
+        created_by="account-123",
+        content="hello",
+        position_x=1.0,
+        position_y=2.0,
+        mentioned_user_ids=[],
+    )
+
+
+def test_update_comment_omits_mentions_when_payload_does_not_include_them(
+    app: Flask, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
+    account = _make_account(TenantAccountRole.EDITOR)
+    app_model = _make_app()
+    _patch_console_guards(monkeypatch, account, app_model)
+
+    update_comment_mock = MagicMock(return_value={"id": "comment-1", "updated_at": datetime(2024, 1, 1, 12, 0, 0)})
+    monkeypatch.setattr(workflow_comment_module.WorkflowCommentService, "update_comment", update_comment_mock)
+    payload = {"content": "hello", "position_x": 10.0, "position_y": 20.0}
+
+    with app.test_request_context("/console/api/apps/app-123/workflow/comments/comment-1", method="PUT", json=payload):
+        with _patch_payload(payload):
+            workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1")
+
+    update_comment_mock.assert_called_once_with(
+        tenant_id="tenant-123",
+        app_id="app-123",
+        comment_id="comment-1",
+        user_id="account-123",
+        content="hello",
+        position_x=10.0,
+        position_y=20.0,
+        mentioned_user_ids=None,
+    )
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
index e11102acb1..c4a8148446 100644
--- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py
@@ -6,14 +6,14 @@ from unittest.mock import Mock
 
 import pytest
 from flask import Flask
-from graphon.entities.pause_reason import HumanInputRequired
-from graphon.enums import WorkflowExecutionStatus
-from graphon.nodes.human_input.entities import FormInput, UserAction
-from graphon.nodes.human_input.enums import FormInputType
 
 from controllers.console import wraps as console_wraps
 from controllers.console.app import workflow_run as workflow_run_module
 from controllers.web.error import NotFoundError
+from graphon.entities.pause_reason import HumanInputRequired
+from graphon.enums import WorkflowExecutionStatus
+from graphon.nodes.human_input.entities import FormInput, UserAction
+from graphon.nodes.human_input.enums import FormInputType
 from libs import login as login_lib
 from models.account import Account, AccountStatus, TenantAccountRole
 from models.workflow import WorkflowRun
diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py
new file mode 100644
index 0000000000..5363aa154f
--- /dev/null
+++ b/api/tests/unit_tests/controllers/console/app/test_workflow_trigger_api.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from types import SimpleNamespace
+
+from controllers.console.app import workflow_trigger as workflow_trigger_module
+
+
+def test_parser_models_validate():
+    parser = workflow_trigger_module.Parser(node_id="node-1")
+
enable_parser = workflow_trigger_module.ParserEnable(
+        trigger_id="550e8400-e29b-41d4-a716-446655440000", enable_trigger=True
+    )
+
+    assert parser.node_id == "node-1"
+    assert enable_parser.enable_trigger is True
+
+
+def test_workflow_trigger_response_serializes_datetime():
+    created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+    trigger = SimpleNamespace(
+        id="trigger-1",
+        trigger_type="trigger-plugin",
+        title="Trigger",
+        node_id="node-1",
+        provider_name="provider",
+        icon="https://example.com/icon",
+        status="enabled",
+        created_at=created_at,
+        updated_at=created_at,
+    )
+
+    payload = workflow_trigger_module.WorkflowTriggerResponse.model_validate(trigger, from_attributes=True).model_dump(
+        mode="json"
+    )
+    assert payload["id"] == "trigger-1"
+    assert payload["created_at"] == "2026-01-02T03:04:05Z"
+    assert payload["updated_at"] == "2026-01-02T03:04:05Z"
+
+
+def test_webhook_trigger_response_serializes_datetime():
+    created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+    webhook = {
+        "id": "webhook-1",
+        "webhook_id": "whk-1",
+        "webhook_url": "https://example.com/hook",
+        "webhook_debug_url": "https://example.com/hook/debug",
+        "node_id": "node-1",
+        "created_at": created_at,
+    }
+
+    payload = workflow_trigger_module.WebhookTriggerResponse.model_validate(webhook).model_dump(mode="json")
+    assert payload["webhook_id"] == "whk-1"
+    assert payload["created_at"] == "2026-01-02T03:04:05Z"
diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
index 740da1f1df..62fa82e339 100644
--- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
+++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py
@@ -1,11 +1,10 @@
 import uuid
 from collections import OrderedDict
 from typing import Any, NamedTuple
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
 
 import pytest
 from flask_restx import marshal
-from graphon.variables.types import SegmentType
 
 from controllers.console.app.workflow_draft_variable import (
     _WORKFLOW_DRAFT_VARIABLE_FIELDS,
@@ -16,6 +15,7 @@ from controllers.console.app.workflow_draft_variable import (
 )
 from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from factories.variable_factory import build_segment
+from graphon.variables.types import SegmentType
 from libs.datetime_utils import naive_utc_now
 from libs.uuid_utils import uuidv7
 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
@@ -29,15 +29,18 @@ class TestWorkflowDraftVariableFields:
     def test_serialize_full_content(self):
        """Test that _serialize_full_content uses pre-loaded relationships."""
-        # Create mock objects with relationships pre-loaded
-        mock_variable_file = MagicMock(spec=WorkflowDraftVariableFile)
-        mock_variable_file.size = 100000
-        mock_variable_file.length = 50
-        mock_variable_file.value_type = SegmentType.OBJECT
-        mock_variable_file.upload_file_id = "test-upload-file-id"
-
-        mock_variable = MagicMock(spec=WorkflowDraftVariable)
-        mock_variable.file_id = "test-file-id"
-        mock_variable.variable_file = mock_variable_file
+        # Build real model instances with relationships pre-loaded
+        mock_variable = WorkflowDraftVariable(
+            file_id="test-file-id",
+            variable_file=WorkflowDraftVariableFile(
+                size=100000,
+                length=50,
+                value_type=SegmentType.OBJECT,
+                upload_file_id="test-upload-file-id",
+                tenant_id=str(uuid.uuid4()),
+                app_id=str(uuid.uuid4()),
+                user_id=str(uuid.uuid4()),
+            ),
+        )
 
         # Mock the file 
helpers with patch("controllers.console.app.workflow_draft_variable.file_helpers", autospec=True) as mock_file_helpers: @@ -84,7 +87,7 @@ class TestWorkflowDraftVariableFields: expected_without_value: OrderedDict[str, Any] = OrderedDict( { - "id": str(conv_var.id), + "id": conv_var.id, "type": conv_var.get_variable_type().value, "name": "conv_var", "description": "", @@ -117,7 +120,7 @@ class TestWorkflowDraftVariableFields: expected_without_value = OrderedDict( { - "id": str(sys_var.id), + "id": sys_var.id, "type": sys_var.get_variable_type().value, "name": "sys_var", "description": "", @@ -149,7 +152,7 @@ class TestWorkflowDraftVariableFields: expected_without_value: OrderedDict[str, Any] = OrderedDict( { - "id": str(node_var.id), + "id": node_var.id, "type": node_var.get_variable_type().value, "name": "node_var", "description": "", @@ -180,19 +183,22 @@ class TestWorkflowDraftVariableFields: node_var.id = str(uuid.uuid4()) node_var.last_edited_at = naive_utc_now() variable_file = WorkflowDraftVariableFile( - id=str(uuidv7()), upload_file_id=str(uuid.uuid4()), size=1024, length=10, value_type=SegmentType.ARRAY_STRING, + tenant_id=str(uuidv7()), + app_id=str(uuidv7()), + user_id=str(uuidv7()), ) + variable_file.id = str(uuidv7()) node_var.variable_file = variable_file node_var.file_id = variable_file.id expected_without_value: OrderedDict[str, Any] = OrderedDict( { - "id": str(node_var.id), - "type": node_var.get_variable_type().value, + "id": node_var.id, + "type": node_var.get_variable_type(), "name": "node_var", "description": "", "selector": ["test_node", "node_var"], @@ -235,7 +241,7 @@ class TestWorkflowDraftVariableList: node_var.id = str(uuid.uuid4()) node_var_dict = OrderedDict( { - "id": str(node_var.id), + "id": node_var.id, "type": node_var.get_variable_type().value, "name": "test_var", "description": "", @@ -314,8 +320,8 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", filename="test.jpg", @@ -370,8 +376,8 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( - id="test_file_id", - type=FileType.IMAGE, + file_id="test_file_id", + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", filename="test.jpg", diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index d3e864a75a..0fb0ebc330 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -67,7 +67,7 @@ class TestActivateCheckApi: assert response["data"]["email"] == "invitee@example.com" @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_invalid_invitation_token(self, mock_get_invitation, app): + def test_check_invalid_invitation_token(self, mock_get_invitation, app: Flask): """ Test checking invalid invitation token. 
@@ -185,7 +185,7 @@ class TestActivateApi:
         mock_db,
         mock_revoke_token,
         mock_get_invitation,
-        app,
+        app: Flask,
         mock_invitation,
         mock_account,
     ):
@@ -227,7 +227,7 @@ class TestActivateApi:
         mock_db.session.commit.assert_called_once()
 
     @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
-    def test_activation_with_invalid_token(self, mock_get_invitation, app):
+    def test_activation_with_invalid_token(self, mock_get_invitation, app: Flask):
         """
         Test account activation with invalid token.
 
@@ -263,7 +263,7 @@ class TestActivateApi:
         mock_db,
         mock_revoke_token,
         mock_get_invitation,
-        app,
+        app: Flask,
         mock_invitation,
         mock_account,
     ):
@@ -312,7 +312,7 @@ class TestActivateApi:
         mock_db,
         mock_revoke_token,
         mock_get_invitation,
-        app,
+        app: Flask,
         mock_invitation,
         mock_account,
         language,
@@ -358,7 +358,7 @@ class TestActivateApi:
         mock_db,
         mock_revoke_token,
         mock_get_invitation,
-        app,
+        app: Flask,
         mock_invitation,
     ):
         """
@@ -398,7 +398,7 @@ class TestActivateApi:
         mock_db,
         mock_revoke_token,
         mock_get_invitation,
-        app,
+        app: Flask,
         mock_invitation,
     ):
         """
@@ -438,7 +438,7 @@ class TestActivateApi:
         mock_db,
         mock_revoke_token,
         mock_get_invitation,
-        app,
+        app: Flask,
         mock_invitation,
         mock_account,
     ):
diff --git a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py
index cb4fe40944..17bee94c52 100644
--- a/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py
+++ b/api/tests/unit_tests/controllers/console/auth/test_authentication_security.py
@@ -43,7 +43,6 @@ class TestAuthenticationSecurity:
         mock_is_rate_limit.return_value = False
         mock_get_invitation.return_value = None
         mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
-        mock_db.session.query.return_value.first.return_value = MagicMock()  # Mock setup exists
         mock_features.return_value.is_allow_register = True
 
         # Act
@@ -76,7 +75,6 @@ class TestAuthenticationSecurity:
         mock_is_rate_limit.return_value = False
         mock_get_invitation.return_value = None
         mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Wrong password")
-        mock_db.session.query.return_value.first.return_value = MagicMock()  # Mock setup exists
 
         # Act
         with self.app.test_request_context(
@@ -109,7 +107,6 @@ class TestAuthenticationSecurity:
         mock_is_rate_limit.return_value = False
         mock_get_invitation.return_value = None
         mock_authenticate.side_effect = services.errors.account.AccountPasswordError("Invalid email or password.")
-        mock_db.session.query.return_value.first.return_value = MagicMock()  # Mock setup exists
         mock_features.return_value.is_allow_register = False
 
         # Act
@@ -135,7 +132,5 @@ class TestAuthenticationSecurity:
     def test_reset_password_with_existing_account(self, mock_send_email, mock_get_user, mock_features, mock_db):
         """Test that reset password returns success with token for existing accounts."""
-        # Mock the setup check
-        mock_db.session.query.return_value.first.return_value = MagicMock()  # Mock setup exists
 
         # Test with existing account
         mock_get_user.return_value = MagicMock(email="existing@example.com")
diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py
index 9929a71120..102af9b250 100644
--- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py
+++ 
b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -65,7 +65,6 @@ class TestEmailCodeLoginSendEmailApi: - IP rate limiting is checked """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = mock_account mock_send_email.return_value = "email_token_123" @@ -98,7 +97,6 @@ class TestEmailCodeLoginSendEmailApi: - Registration is allowed by system features """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = None mock_get_features.return_value.is_allow_register = True @@ -130,7 +128,6 @@ class TestEmailCodeLoginSendEmailApi: - Registration is blocked by system features """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = None mock_get_features.return_value.is_allow_register = False @@ -143,7 +140,7 @@ class TestEmailCodeLoginSendEmailApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit") - def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app: Flask): """ Test email code sending blocked by IP rate limit. @@ -152,7 +149,6 @@ class TestEmailCodeLoginSendEmailApi: - Prevents spam and abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = True # Act & Assert @@ -164,7 +160,7 @@ class TestEmailCodeLoginSendEmailApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.login.AccountService.get_user_through_email") - def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app): + def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app: Flask): """ Test email code sending to frozen account. 
@@ -172,7 +168,6 @@ class TestEmailCodeLoginSendEmailApi: - AccountInFreezeError is raised for frozen accounts """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.side_effect = AccountRegisterError("Account frozen") @@ -200,7 +195,7 @@ class TestEmailCodeLoginSendEmailApi: mock_get_user, mock_is_ip_limit, mock_db, - app, + app: Flask, mock_account, language_input, expected_language, @@ -213,7 +208,6 @@ class TestEmailCodeLoginSendEmailApi: - Defaults to en-US when not specified """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_ip_limit.return_value = False mock_get_user.return_value = mock_account mock_send_email.return_value = "token" @@ -273,7 +267,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -286,7 +280,6 @@ class TestEmailCodeLoginApi: - User is logged in with token pair """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [MagicMock()] @@ -322,7 +315,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -335,7 +328,6 @@ class TestEmailCodeLoginApi: - User is logged in after account creation """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "newuser@example.com", "code": "123456"} mock_get_user.return_value = None mock_create_account.return_value = mock_account @@ -361,7 +353,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app): + def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app: Flask): """ Test email code login with invalid token. @@ -369,7 +361,6 @@ class TestEmailCodeLoginApi: - InvalidTokenError is raised for invalid/expired tokens """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = None # Act & Assert @@ -384,7 +375,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app): + def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app: Flask): """ Test email code login with mismatched email. @@ -392,7 +383,6 @@ class TestEmailCodeLoginApi: - InvalidEmailError is raised when email doesn't match token """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "original@example.com", "code": "123456"} # Act & Assert @@ -407,7 +397,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app): + def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app: Flask): """ Test email code login with incorrect code. 
@@ -415,7 +405,6 @@ class TestEmailCodeLoginApi: - EmailCodeError is raised for wrong verification code """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} # Act & Assert @@ -442,7 +431,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ @@ -453,7 +442,6 @@ class TestEmailCodeLoginApi: - User is added as owner of new workspace """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] @@ -486,7 +474,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ @@ -496,7 +484,6 @@ class TestEmailCodeLoginApi: - WorkspacesLimitExceeded is raised when limit reached """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] @@ -528,7 +515,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ @@ -538,7 +525,6 @@ class TestEmailCodeLoginApi: - NotAllowedCreateWorkspace is raised when creation disabled """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_get_data.return_value = {"email": "test@example.com", "code": "123456"} mock_get_user.return_value = mock_account mock_get_tenants.return_value = [] diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 560971206f..ace2ce5706 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -9,23 +9,25 @@ This module tests the core authentication endpoints including: """ import base64 -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from flask import Flask from flask_restx import Api +from werkzeug.exceptions import Unauthorized from controllers.console.auth.error import ( AuthenticationFailedError, EmailPasswordLoginLimitError, InvalidEmailError, ) -from controllers.console.auth.login import LoginApi, LogoutApi +from controllers.console.auth.login import EmailCodeLoginApi, LoginApi, LogoutApi from controllers.console.error import ( AccountBannedError, AccountInFreezeError, WorkspacesLimitExceeded, ) +from services.entities.auth_entities import LoginFailureReason from services.errors.account import AccountLoginError, AccountPasswordError @@ -34,6 +36,11 @@ def encode_password(password: str) -> str: return base64.b64encode(password.encode("utf-8")).decode() +def encode_code(code: str) -> str: + """Helper to encode verification code as Base64 for testing.""" + return base64.b64encode(code.encode("utf-8")).decode() + + class TestLoginApi: """Test cases for the LoginApi endpoint.""" @@ -45,12 +52,12 @@ class TestLoginApi: return app @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return Api(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client.""" api.add_resource(LoginApi, "/login") return app.test_client() 
@@ -90,7 +97,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -103,7 +110,6 @@ class TestLoginApi: - Rate limit is reset after successful login """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.return_value = mock_account @@ -135,14 +141,14 @@ class TestLoginApi: @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit") def test_successful_login_with_valid_invitation( self, - mock_reset_rate_limit, + mock_reset_rate_limit: Mock, mock_login, mock_get_tenants, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -155,7 +161,6 @@ class TestLoginApi: - Authentication proceeds with invitation token """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = {"data": {"email": "test@example.com"}} mock_authenticate.return_value = mock_account @@ -183,7 +188,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): + def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): """ Test login rejection when rate limit is exceeded. @@ -192,22 +197,26 @@ class TestLoginApi: - EmailPasswordLoginLimitError is raised when limit exceeded """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = True mock_get_invitation.return_value = None # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} - ): - login_api = LoginApi() - with pytest.raises(EmailPasswordLoginLimitError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} + ): + login_api = LoginApi() + with pytest.raises(EmailPasswordLoginLimitError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "test@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.LOGIN_RATE_LIMITED @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True) @patch("controllers.console.auth.login.BillingService.is_email_in_freeze") - def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app): + def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask): """ Test login rejection for frozen accounts. 
@@ -216,16 +225,20 @@ class TestLoginApi: - AccountInFreezeError is raised for frozen accounts """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_frozen.return_value = True # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} - ): - login_api = LoginApi() - with pytest.raises(AccountInFreezeError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} + ): + login_api = LoginApi() + with pytest.raises(AccountInFreezeError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "frozen@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_IN_FREEZE @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -240,7 +253,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, ): """ Test login failure with invalid credentials. @@ -251,20 +264,25 @@ class TestLoginApi: - Generic error message prevents user enumeration """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = AccountPasswordError("Invalid password") # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": encode_password("WrongPass123!")} - ): - login_api = LoginApi() - with pytest.raises(AuthenticationFailedError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", + method="POST", + json={"email": "test@example.com", "password": encode_password("WrongPass123!")}, + ): + login_api = LoginApi() + with pytest.raises(AuthenticationFailedError): + login_api.post() mock_add_rate_limit.assert_called_once_with("test@example.com") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "test@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -272,7 +290,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") def test_login_fails_for_banned_account( - self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app + self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask ): """ Test login rejection for banned accounts. 
@@ -282,18 +300,24 @@ class TestLoginApi: - Login is prevented even with valid credentials """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = AccountLoginError("Account is banned") # Act & Assert - with app.test_request_context( - "/login", method="POST", json={"email": "banned@example.com", "password": encode_password("ValidPass123!")} - ): - login_api = LoginApi() - with pytest.raises(AccountBannedError): - login_api.post() + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/login", + method="POST", + json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}, + ): + login_api = LoginApi() + with pytest.raises(AccountBannedError): + login_api.post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "banned@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -304,14 +328,14 @@ class TestLoginApi: @patch("controllers.console.auth.login.FeatureService.get_system_features") def test_login_fails_when_no_workspace_and_limit_exceeded( self, - mock_get_features, - mock_get_tenants, - mock_authenticate, - mock_get_invitation, - mock_is_rate_limit, - mock_db, - app, - mock_account, + mock_get_features: MagicMock, + mock_get_tenants: MagicMock, + mock_authenticate: MagicMock, + mock_get_invitation: MagicMock, + mock_is_rate_limit: MagicMock, + mock_db: MagicMock, + app: Flask, + mock_account: MagicMock, ): """ Test login failure when user has no workspace and workspace limit exceeded. @@ -321,7 +345,6 @@ class TestLoginApi: - User cannot login without an assigned workspace """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.return_value = mock_account @@ -344,7 +367,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): + def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): """ Test login failure when invitation email doesn't match login email. 
@@ -353,7 +376,6 @@ class TestLoginApi: - Security check prevents invitation token abuse """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = {"data": {"email": "invited@example.com"}} @@ -390,12 +412,11 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): """Test that login retries with lowercase email when uppercase lookup fails.""" - mock_db.session.query.return_value.first.return_value = MagicMock() mock_is_rate_limit.return_value = False mock_get_invitation.return_value = None mock_authenticate.side_effect = [AccountPasswordError("Invalid"), mock_account] @@ -417,6 +438,35 @@ class TestLoginApi: mock_add_rate_limit.assert_not_called() mock_reset_rate_limit.assert_called_once_with("upper@example.com") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") + @patch("controllers.console.auth.login.AccountService.revoke_email_code_login_token") + @patch("controllers.console.auth.login._get_account_with_case_fallback") + def test_email_code_login_logs_banned_account( + self, + mock_get_account, + mock_revoke_token, + mock_get_token_data, + mock_db, + app: Flask, + ): + mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} + mock_get_account.side_effect = Unauthorized("Account is banned.") + + with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(AccountBannedError): + EmailCodeLoginApi().post() + + mock_revoke_token.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED + class TestLogoutApi: """Test cases for the LogoutApi endpoint.""" @@ -441,7 +491,7 @@ class TestLogoutApi: @patch("controllers.console.auth.login.AccountService.logout") @patch("controllers.console.auth.login.flask_login.logout_user") def test_successful_logout( - self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app, mock_account + self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app: Flask, mock_account ): """ Test successful logout flow. @@ -453,7 +503,6 @@ class TestLogoutApi: - Success response is returned """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() mock_current_account.return_value = (mock_account, MagicMock()) # Act @@ -469,7 +518,7 @@ class TestLogoutApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.current_account_with_tenant") @patch("controllers.console.auth.login.flask_login") - def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app): + def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app: Flask): """ Test logout for anonymous (not logged in) user. 
@@ -479,7 +528,6 @@ class TestLogoutApi: - Success response is returned """ # Arrange - mock_db.session.query.return_value.first.return_value = MagicMock() # Create a mock anonymous user that will pass isinstance check anonymous_user = MagicMock() mock_flask_login.AnonymousUserMixin = type("AnonymousUserMixin", (), {}) diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index d010f60866..22974ca416 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -28,12 +28,12 @@ class TestRefreshTokenApi: return app @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return Api(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client.""" api.add_resource(RefreshTokenApi, "/refresh-token") return app.test_client() @@ -74,7 +74,7 @@ class TestRefreshTokenApi: assert response.json["result"] == "success" @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) - def test_refresh_fails_without_token(self, mock_extract_token, app): + def test_refresh_fails_without_token(self, mock_extract_token, app: Flask): """ Test token refresh failure when no refresh token provided. @@ -98,7 +98,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh failure with invalid refresh token. @@ -123,7 +123,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh failure with expired refresh token. @@ -148,7 +148,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh with empty string token. 
diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index c80758c857..defa9064fd 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -46,11 +46,10 @@ class TestPartnerTenants: patch("libs.login.dify_config.LOGIN_DISABLED", False), patch("libs.login.check_csrf_token") as mock_csrf, ): - mock_db.session.query.return_value.first.return_value = MagicMock() # Mock setup exists mock_csrf.return_value = None yield {"db": mock_db, "csrf": mock_csrf} - def test_put_success(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_success(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test successful partner tenants bindings sync.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -80,7 +79,7 @@ class TestPartnerTenants: mock_account.id, "partner-key-123", click_id ) - def test_put_invalid_partner_key_base64(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_invalid_partner_key_base64(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that invalid base64 partner_key raises BadRequest.""" # Arrange invalid_partner_key = "invalid-base64-!@#$" @@ -105,7 +104,7 @@ class TestPartnerTenants: resource.put(invalid_partner_key) assert "Invalid partner_key" in str(exc_info.value) - def test_put_missing_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_missing_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that missing click_id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -129,7 +128,9 @@ class TestPartnerTenants: with pytest.raises(BadRequest): resource.put(partner_key_encoded) - def test_put_billing_service_json_decode_error(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_billing_service_json_decode_error( + self, app: Flask, mock_account, mock_billing_service, mock_decorators + ): """Test handling of billing service JSON decode error. 
When billing service returns non-200 status code with invalid JSON response, @@ -175,7 +176,7 @@ class TestPartnerTenants: assert isinstance(exc_info.value, json.JSONDecodeError) assert "Expecting value" in str(exc_info.value) - def test_put_empty_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty click_id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -200,7 +201,7 @@ class TestPartnerTenants: resource.put(partner_key_encoded) assert "Invalid partner information" in str(exc_info.value) - def test_put_empty_partner_key_after_decode(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_partner_key_after_decode(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty partner_key after decode raises BadRequest.""" # Arrange # Base64 encode an empty string @@ -226,7 +227,7 @@ class TestPartnerTenants: resource.put(empty_partner_key_encoded) assert "Invalid partner information" in str(exc_info.value) - def test_put_empty_user_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_user_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty user id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9c9f8da87c..9c5b5ec256 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -18,6 +18,7 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService @@ -29,7 +30,7 @@ def unwrap(func): class TestDatasourcePluginOAuthAuthorizationUrl: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -61,7 +62,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: assert response.status_code == 200 - def test_get_no_oauth_config(self, app): + def test_get_no_oauth_config(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -80,7 +81,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: with pytest.raises(ValueError): method(api, "notion") - def test_get_without_credential_id_sets_cookie(self, app): + def test_get_without_credential_id_sets_cookie(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -115,7 +116,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: class TestDatasourceOAuthCallback: - def 
test_callback_success_new_credential(self, app): + def test_callback_success_new_credential(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -157,7 +158,7 @@ class TestDatasourceOAuthCallback: assert response.status_code == 302 - def test_callback_missing_context(self, app): + def test_callback_missing_context(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -165,7 +166,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(Forbidden): method(api, "notion") - def test_callback_invalid_context(self, app): + def test_callback_invalid_context(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -180,7 +181,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(Forbidden): method(api, "notion") - def test_callback_oauth_config_not_found(self, app): + def test_callback_oauth_config_not_found(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -202,7 +203,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(NotFound): method(api, "notion") - def test_callback_reauthorize_existing_credential(self, app): + def test_callback_reauthorize_existing_credential(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -245,7 +246,7 @@ class TestDatasourceOAuthCallback: assert response.status_code == 302 assert "/oauth-callback" in response.location - def test_callback_context_id_from_cookie(self, app): + def test_callback_context_id_from_cookie(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -289,7 +290,7 @@ class TestDatasourceOAuthCallback: class TestDatasourceAuth: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -312,7 +313,7 @@ class TestDatasourceAuth: assert status == 200 - def test_post_invalid_credentials(self, app): + def test_post_invalid_credentials(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -334,7 +335,7 @@ class TestDatasourceAuth: with pytest.raises(ValueError): method(api, "notion") - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasourceAuth() method = unwrap(api.get) @@ -355,7 +356,7 @@ class TestDatasourceAuth: assert status == 200 assert response["result"] - def test_post_missing_credentials(self, app): + def test_post_missing_credentials(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -372,7 +373,7 @@ class TestDatasourceAuth: with pytest.raises(ValueError): method(api, "notion") - def test_get_empty_list(self, app): + def test_get_empty_list(self, app: Flask): api = DatasourceAuth() method = unwrap(api.get) @@ -395,7 +396,7 @@ class TestDatasourceAuth: class TestDatasourceAuthDeleteApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasourceAuthDeleteApi() method = unwrap(api.post) @@ -418,7 +419,7 @@ class TestDatasourceAuthDeleteApi: assert status == 200 - def test_delete_missing_credential_id(self, app): + def test_delete_missing_credential_id(self, app: Flask): api = DatasourceAuthDeleteApi() method = unwrap(api.post) @@ -437,7 +438,7 @@ class TestDatasourceAuthDeleteApi: class TestDatasourceAuthUpdateApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -460,7 +461,7 @@ class TestDatasourceAuthUpdateApi: assert status == 201 - def test_update_with_credentials_none(self, app): + def 
test_update_with_credentials_none(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -484,7 +485,7 @@ class TestDatasourceAuthUpdateApi: update_mock.assert_called_once() assert status == 201 - def test_update_name_only(self, app): + def test_update_name_only(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -507,7 +508,7 @@ class TestDatasourceAuthUpdateApi: assert status == 201 - def test_update_with_empty_credentials_dict(self, app): + def test_update_with_empty_credentials_dict(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -533,7 +534,7 @@ class TestDatasourceAuthUpdateApi: class TestDatasourceAuthListApi: - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = DatasourceAuthListApi() method = unwrap(api.get) @@ -553,7 +554,7 @@ class TestDatasourceAuthListApi: assert status == 200 - def test_auth_list_empty(self, app): + def test_auth_list_empty(self, app: Flask): api = DatasourceAuthListApi() method = unwrap(api.get) @@ -574,7 +575,7 @@ class TestDatasourceAuthListApi: assert status == 200 assert response["result"] == [] - def test_hardcode_list_empty(self, app): + def test_hardcode_list_empty(self, app: Flask): api = DatasourceHardCodeAuthListApi() method = unwrap(api.get) @@ -597,7 +598,7 @@ class TestDatasourceAuthListApi: class TestDatasourceHardCodeAuthListApi: - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = DatasourceHardCodeAuthListApi() method = unwrap(api.get) @@ -619,7 +620,7 @@ class TestDatasourceHardCodeAuthListApi: class TestDatasourceAuthOauthCustomClient: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -642,7 +643,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.delete) @@ -662,7 +663,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_post_empty_payload(self, app): + def test_post_empty_payload(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -685,7 +686,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_post_disabled_flag(self, app): + def test_post_disabled_flag(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -714,7 +715,7 @@ class TestDatasourceAuthOauthCustomClient: class TestDatasourceAuthDefaultApi: - def test_set_default_success(self, app): + def test_set_default_success(self, app: Flask): api = DatasourceAuthDefaultApi() method = unwrap(api.post) @@ -737,7 +738,7 @@ class TestDatasourceAuthDefaultApi: assert status == 200 - def test_default_missing_id(self, app): + def test_default_missing_id(self, app: Flask): api = DatasourceAuthDefaultApi() method = unwrap(api.post) @@ -756,7 +757,7 @@ class TestDatasourceAuthDefaultApi: class TestDatasourceUpdateProviderNameApi: - def test_update_name_success(self, app): + def test_update_name_success(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) @@ -779,7 +780,7 @@ class TestDatasourceUpdateProviderNameApi: assert status == 200 - def test_update_name_too_long(self, app): + def test_update_name_too_long(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) @@ -799,7 +800,7 @@ class 
TestDatasourceUpdateProviderNameApi: with pytest.raises(ValueError): method(api, "notion") - def test_update_name_missing_credential_id(self, app): + def test_update_name_missing_credential_id(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py index 7a8ccde55a..d4c6a775ec 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden from controllers.console import console_ns @@ -25,7 +26,7 @@ class TestDataSourceContentPreviewApi: "credential_id": "cred-1", } - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -66,7 +67,7 @@ class TestDataSourceContentPreviewApi: assert status == 200 assert response == preview_result - def test_post_forbidden_non_account_user(self, app): + def test_post_forbidden_non_account_user(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -85,7 +86,7 @@ class TestDataSourceContentPreviewApi: with pytest.raises(Forbidden): method(api, pipeline, "node-1") - def test_post_invalid_payload(self, app): + def test_post_invalid_payload(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -108,7 +109,7 @@ class TestDataSourceContentPreviewApi: with pytest.raises(ValueError): method(api, pipeline, "node-1") - def test_post_without_credential_id(self, app): + def test_post_without_credential_id(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index 6ef8ccfdbd..63950736c5 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Response -from graphon.variables.types import SegmentType from controllers.console import console_ns from controllers.console.app.error import DraftWorkflowNotExist @@ -16,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor ) from controllers.web.error import InvalidArgumentError, NotFoundError from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from graphon.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 8555900f4e..e28d68ee5a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -2,6 +2,7 @@ import datetime from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, Forbidden, 
NotFound import services @@ -58,7 +59,7 @@ class TestDatasetList: user.is_dataset_editor = True return user - def test_get_success_basic(self, app): + def test_get_success_basic(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -93,7 +94,7 @@ class TestDatasetList: assert resp["total"] == 1 assert resp["data"][0]["embedding_available"] is True - def test_get_with_ids_filter(self, app): + def test_get_with_ids_filter(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -128,7 +129,7 @@ class TestDatasetList: assert status == 200 assert resp["total"] == 2 - def test_get_with_tag_ids(self, app): + def test_get_with_tag_ids(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -161,7 +162,7 @@ class TestDatasetList: assert status == 200 - def test_embedding_available_false(self, app): + def test_embedding_available_false(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -203,7 +204,7 @@ class TestDatasetList: assert resp["data"][0]["embedding_available"] is False - def test_partial_members_permission(self, app): + def test_partial_members_permission(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -242,7 +243,7 @@ class TestDatasetList: class TestDatasetListApiPost: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -290,7 +291,7 @@ class TestDatasetListApiPost: assert status == 201 - def test_post_forbidden(self, app): + def test_post_forbidden(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -310,7 +311,7 @@ class TestDatasetListApiPost: with pytest.raises(Forbidden): method(api) - def test_post_duplicate_name(self, app): + def test_post_duplicate_name(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -335,7 +336,7 @@ class TestDatasetListApiPost: with pytest.raises(DatasetNameDuplicateError): method(api) - def test_post_invalid_payload_missing_name(self, app): + def test_post_invalid_payload_missing_name(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -343,7 +344,7 @@ class TestDatasetListApiPost: with pytest.raises(ValueError): method(api) - def test_post_invalid_indexing_technique(self, app): + def test_post_invalid_indexing_technique(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -356,7 +357,7 @@ class TestDatasetListApiPost: with pytest.raises(ValueError, match="Invalid indexing technique"): method(api) - def test_post_invalid_provider(self, app): + def test_post_invalid_provider(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -371,7 +372,7 @@ class TestDatasetListApiPost: class TestDatasetApiGet: - def test_get_success_basic(self, app): + def test_get_success_basic(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -427,7 +428,7 @@ class TestDatasetApiGet: assert status == 200 assert data["embedding_available"] is True - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -448,7 +449,7 @@ class TestDatasetApiGet: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -475,7 +476,7 @@ class TestDatasetApiGet: with pytest.raises(Forbidden, match="no access"): method(api, dataset_id) - def 
test_get_high_quality_embedding_unavailable(self, app): + def test_get_high_quality_embedding_unavailable(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -530,7 +531,7 @@ class TestDatasetApiGet: assert data["embedding_available"] is False - def test_get_partial_members_permission(self, app): + def test_get_partial_members_permission(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -590,7 +591,7 @@ class TestDatasetApiGet: class TestDatasetApiPatch: - def test_patch_success_basic(self, app): + def test_patch_success_basic(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -659,7 +660,7 @@ class TestDatasetApiPatch: assert status == 200 assert result["partial_member_list"] == [] - def test_patch_dataset_not_found(self, app): + def test_patch_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -674,7 +675,7 @@ class TestDatasetApiPatch: with pytest.raises(NotFound, match="Dataset not found"): method(api, "missing") - def test_patch_permission_denied(self, app): + def test_patch_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -704,7 +705,7 @@ class TestDatasetApiPatch: with pytest.raises(Forbidden): method(api, dataset_id) - def test_patch_partial_members_update(self, app): + def test_patch_partial_members_update(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -773,7 +774,7 @@ class TestDatasetApiPatch: assert result["partial_member_list"] == payload["partial_member_list"] - def test_patch_clear_partial_members(self, app): + def test_patch_clear_partial_members(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -843,7 +844,7 @@ class TestDatasetApiPatch: class TestDatasetApiDelete: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -874,7 +875,7 @@ class TestDatasetApiDelete: assert status == 204 assert result == {"result": "success"} - def test_delete_forbidden_no_permission(self, app): + def test_delete_forbidden_no_permission(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -893,7 +894,7 @@ class TestDatasetApiDelete: with pytest.raises(Forbidden): method(api, dataset_id) - def test_delete_dataset_not_found(self, app): + def test_delete_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -917,7 +918,7 @@ class TestDatasetApiDelete: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_delete_dataset_in_use(self, app): + def test_delete_dataset_in_use(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -943,7 +944,7 @@ class TestDatasetApiDelete: class TestDatasetUseCheckApi: - def test_get_use_check_true(self, app): + def test_get_use_check_true(self, app: Flask): api = DatasetUseCheckApi() method = unwrap(api.get) @@ -962,7 +963,7 @@ class TestDatasetUseCheckApi: assert status == 200 assert result == {"is_using": True} - def test_get_use_check_false(self, app): + def test_get_use_check_false(self, app: Flask): api = DatasetUseCheckApi() method = unwrap(api.get) @@ -983,7 +984,7 @@ class TestDatasetUseCheckApi: class TestDatasetQueryApi: - def test_get_queries_success(self, app): + def test_get_queries_success(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1027,7 +1028,7 @@ class TestDatasetQueryApi: assert response["has_more"] is False assert len(response["data"]) == 2 - def test_get_queries_dataset_not_found(self, 
app): + def test_get_queries_dataset_not_found(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1049,7 +1050,7 @@ class TestDatasetQueryApi: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_get_queries_permission_denied(self, app): + def test_get_queries_permission_denied(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1078,7 +1079,7 @@ class TestDatasetQueryApi: with pytest.raises(Forbidden): method(api, dataset_id) - def test_get_queries_pagination_has_more(self, app): + def test_get_queries_pagination_has_more(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1152,7 +1153,7 @@ class TestDatasetIndexingEstimateApi: "dataset_id": None, } - def test_post_success_upload_file(self, app): + def test_post_success_upload_file(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) @@ -1193,7 +1194,7 @@ class TestDatasetIndexingEstimateApi: assert status == 200 assert response == {"tokens": 100} - def test_post_file_not_found(self, app): + def test_post_file_not_found(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) @@ -1223,7 +1224,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(NotFound): method(api) - def test_post_llm_bad_request_error(self, app): + def test_post_llm_bad_request_error(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1258,7 +1259,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(ProviderNotInitializeError): method(api) - def test_post_provider_token_not_init(self, app): + def test_post_provider_token_not_init(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1293,7 +1294,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(ProviderNotInitializeError): method(api) - def test_post_generic_exception(self, app): + def test_post_generic_exception(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1330,7 +1331,7 @@ class TestDatasetIndexingEstimateApi: class TestDatasetRelatedAppListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1368,7 +1369,7 @@ class TestDatasetRelatedAppListApi: assert response["total"] == 2 assert response["data"] == [app1, app2] - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1386,7 +1387,7 @@ class TestDatasetRelatedAppListApi: with pytest.raises(NotFound): method(api, "dataset-1") - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1410,7 +1411,7 @@ class TestDatasetRelatedAppListApi: with pytest.raises(Forbidden): method(api, "dataset-1") - def test_get_filters_none_apps(self, app): + def test_get_filters_none_apps(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1449,7 +1450,7 @@ class TestDatasetRelatedAppListApi: class TestDatasetIndexingStatusApi: - def test_get_success_with_documents(self, app): + def test_get_success_with_documents(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1490,7 +1491,7 @@ class TestDatasetIndexingStatusApi: assert item["completed_segments"] == 3 assert 
item["total_segments"] == 3 - def test_get_success_no_documents(self, app): + def test_get_success_no_documents(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1510,7 +1511,7 @@ class TestDatasetIndexingStatusApi: assert status == 200 assert response == {"data": []} - def test_segment_counts_different_values(self, app): + def test_segment_counts_different_values(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1550,12 +1551,22 @@ class TestDatasetIndexingStatusApi: class TestDatasetApiKeyApi: - def test_get_api_keys_success(self, app): + def test_get_api_keys_success(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.get) mock_key_1 = MagicMock(spec=ApiToken) + mock_key_1.id = "key-1" + mock_key_1.type = "dataset" + mock_key_1.token = "ds-abc" + mock_key_1.last_used_at = None + mock_key_1.created_at = None mock_key_2 = MagicMock(spec=ApiToken) + mock_key_2.id = "key-2" + mock_key_2.type = "dataset" + mock_key_2.token = "ds-def" + mock_key_2.last_used_at = None + mock_key_2.created_at = None with ( app.test_request_context("/"), @@ -1570,13 +1581,26 @@ class TestDatasetApiKeyApi: ): response = method(api) - assert "items" in response - assert response["items"] == [mock_key_1, mock_key_2] + assert "data" in response + assert len(response["data"]) == 2 + assert response["data"][0]["id"] == "key-1" + assert response["data"][0]["token"] == "ds-abc" + assert response["data"][1]["id"] == "key-2" + assert response["data"][1]["token"] == "ds-def" - def test_post_create_api_key_success(self, app): + def test_post_create_api_key_success(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.post) + mock_token = MagicMock() + mock_token.id = "new-key-id" + mock_token.last_used_at = None + mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) + + mock_api_token_cls = MagicMock() + mock_api_token_cls.return_value = mock_token + mock_api_token_cls.generate_api_key.return_value = "dataset-abc123" + with ( app.test_request_context("/"), patch( @@ -1588,8 +1612,8 @@ class TestDatasetApiKeyApi: return_value=3, ), patch( - "controllers.console.datasets.datasets.ApiToken.generate_api_key", - return_value="dataset-abc123", + "controllers.console.datasets.datasets.ApiToken", + mock_api_token_cls, ), patch( "controllers.console.datasets.datasets.db.session.add", @@ -1603,11 +1627,13 @@ class TestDatasetApiKeyApi: response, status = method(api) assert status == 200 - assert isinstance(response, ApiToken) - assert response.token == "dataset-abc123" - assert response.type == "dataset" + assert isinstance(response, dict) + assert response["id"] == "new-key-id" + assert response["token"] == "dataset-abc123" + assert response["type"] == "dataset" + assert response["created_at"] is not None - def test_post_exceed_max_keys(self, app): + def test_post_exceed_max_keys(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.post) @@ -1633,7 +1659,7 @@ class TestDatasetApiKeyApi: class TestDatasetApiDeleteApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasetApiDeleteApi() method = unwrap(api.delete) @@ -1663,7 +1689,7 @@ class TestDatasetApiDeleteApi: assert status == 204 assert response["result"] == "success" - def test_delete_key_not_found(self, app): + def test_delete_key_not_found(self, app: Flask): api = DatasetApiDeleteApi() method = unwrap(api.delete) @@ -1683,7 +1709,7 @@ class TestDatasetApiDeleteApi: class TestDatasetEnableApiApi: - def 
test_enable_api(self, app): + def test_enable_api(self, app: Flask): api = DatasetEnableApiApi() method = unwrap(api.post) @@ -1699,7 +1725,7 @@ class TestDatasetEnableApiApi: assert status == 200 assert response["result"] == "success" - def test_disable_api(self, app): + def test_disable_api(self, app: Flask): api = DatasetEnableApiApi() method = unwrap(api.post) @@ -1717,7 +1743,7 @@ class TestDatasetEnableApiApi: class TestDatasetApiBaseUrlApi: - def test_get_api_base_url_from_config(self, app): + def test_get_api_base_url_from_config(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) @@ -1732,7 +1758,7 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "https://example.com/v1" - def test_get_api_base_url_from_request(self, app): + def test_get_api_base_url_from_request(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) @@ -1747,9 +1773,24 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "http://localhost:5000/v1" + def test_get_api_base_url_no_double_v1(self, app: Flask): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + "https://example.com/v1", + ), + ): + response = method(api) + + assert response["api_base_url"] == "https://example.com/v1" + class TestDatasetRetrievalSettingApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRetrievalSettingApi() method = unwrap(api.get) @@ -1770,7 +1811,7 @@ class TestDatasetRetrievalSettingApi: class TestDatasetRetrievalSettingMockApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRetrievalSettingMockApi() method = unwrap(api.get) @@ -1787,7 +1828,7 @@ class TestDatasetRetrievalSettingMockApi: class TestDatasetErrorDocs: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetErrorDocs() method = unwrap(api.get) @@ -1810,7 +1851,7 @@ class TestDatasetErrorDocs: assert status == 200 assert response["total"] == 1 - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetErrorDocs() method = unwrap(api.get) @@ -1826,7 +1867,7 @@ class TestDatasetErrorDocs: class TestDatasetPermissionUserListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetPermissionUserListApi() method = unwrap(api.get) @@ -1857,7 +1898,7 @@ class TestDatasetPermissionUserListApi: assert status == 200 assert response["data"] == users - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetPermissionUserListApi() method = unwrap(api.get) @@ -1883,7 +1924,7 @@ class TestDatasetPermissionUserListApi: class TestDatasetAutoDisableLogApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetAutoDisableLogApi() method = unwrap(api.get) @@ -1906,7 +1947,7 @@ class TestDatasetAutoDisableLogApi: assert status == 200 assert response == logs - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetAutoDisableLogApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index ce2278de4f..ff9e1736d2 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ 
b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -1,6 +1,8 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound import services @@ -215,24 +217,30 @@ class TestDatasetDocumentListApi: method = unwrap(api.post) payload = {"indexing_technique": "economy"} + created_dataset = SimpleNamespace(id="ds-1", name="Dataset", indexing_technique="economy") + created_document = SimpleNamespace(id="doc-1", name="Document", doc_metadata_details=None) with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=created_dataset, + ), patch( "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", - return_value=([MagicMock()], "batch-1"), + return_value=([created_document], "batch-1"), ), ): response = method(api, "ds-1") assert "documents" in response - def test_post_forbidden(self, app): + def test_post_forbidden(self, app: Flask): api = DatasetDocumentListApi() method = unwrap(api.post) @@ -388,7 +396,7 @@ class TestDocumentDownloadApi: class TestDocumentProcessingApi: - def test_processing_forbidden_when_not_editor(self, app): + def test_processing_forbidden_when_not_editor(self, app: Flask): api = DocumentProcessingApi() method = unwrap(api.patch) @@ -1178,7 +1186,7 @@ class TestDocumentPermissionCases: "preview": [], } - def test_document_tenant_mismatch(self, app): + def test_document_tenant_mismatch(self, app: Flask): api = DocumentApi() method = unwrap(api.get) @@ -1246,7 +1254,7 @@ class TestDocumentPermissionCases: assert status == 200 assert response["mode"] == "custom" - def test_process_rule_permission_denied(self, app): + def test_process_rule_permission_denied(self, app: Flask): api = GetProcessRuleApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index 693b06e95b..412edb9dfe 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound import services @@ -82,7 +83,7 @@ def test_get_segment_with_summary(monkeypatch): class TestDatasetDocumentSegmentListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -132,7 +133,7 @@ class TestDatasetDocumentSegmentListApi: assert status == 200 - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -150,7 +151,7 @@ class TestDatasetDocumentSegmentListApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -176,7 +177,7 @@ class TestDatasetDocumentSegmentListApi: class TestDatasetDocumentSegmentApi: - 
def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -221,7 +222,7 @@ class TestDatasetDocumentSegmentApi: assert status == 200 assert response["result"] == "success" - def test_patch_document_indexing_in_progress(self, app): + def test_patch_document_indexing_in_progress(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -264,7 +265,7 @@ class TestDatasetDocumentSegmentApi: with pytest.raises(InvalidActionError): method(api, "ds-1", "doc-1", "disable") - def test_patch_llm_bad_request(self, app): + def test_patch_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -308,7 +309,7 @@ class TestDatasetDocumentSegmentApi: with pytest.raises(ProviderNotInitializeError): method(api, "ds-1", "doc-1", "enable") - def test_patch_provider_token_not_init(self, app): + def test_patch_provider_token_not_init(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -354,7 +355,7 @@ class TestDatasetDocumentSegmentApi: class TestDatasetDocumentSegmentAddApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -413,7 +414,7 @@ class TestDatasetDocumentSegmentAddApi: assert status == 200 assert response["data"]["id"] == "seg-1" - def test_post_llm_bad_request(self, app): + def test_post_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -452,7 +453,7 @@ class TestDatasetDocumentSegmentAddApi: with pytest.raises(ProviderNotInitializeError): method(api, "ds-1", "doc-1") - def test_post_provider_token_not_init(self, app): + def test_post_provider_token_not_init(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -493,7 +494,7 @@ class TestDatasetDocumentSegmentAddApi: class TestDatasetDocumentSegmentUpdateApi: - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = DatasetDocumentSegmentUpdateApi() method = unwrap(api.patch) @@ -551,7 +552,7 @@ class TestDatasetDocumentSegmentUpdateApi: assert status == 200 assert "data" in response - def test_patch_llm_bad_request(self, app): + def test_patch_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentUpdateApi() method = unwrap(api.patch) @@ -596,7 +597,7 @@ class TestDatasetDocumentSegmentUpdateApi: class TestDatasetDocumentSegmentBatchImportApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -638,7 +639,7 @@ class TestDatasetDocumentSegmentBatchImportApi: assert status == 200 assert response["job_status"] == "waiting" - def test_post_dataset_not_found(self, app): + def test_post_dataset_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -659,7 +660,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_document_not_found(self, app): + def test_post_document_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -684,7 +685,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_upload_file_not_found(self, app): + def test_post_upload_file_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() 
method = unwrap(api.post) @@ -713,7 +714,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_invalid_file_type(self, app): + def test_post_invalid_file_type(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -745,7 +746,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(ValueError): method(api, "ds-1", "doc-1") - def test_post_async_task_failure(self, app): + def test_post_async_task_failure(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -783,7 +784,7 @@ class TestDatasetDocumentSegmentBatchImportApi: assert status == 500 assert "error" in response - def test_get_job_not_found_in_redis(self, app): + def test_get_job_not_found_in_redis(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.get) @@ -799,7 +800,7 @@ class TestDatasetDocumentSegmentBatchImportApi: class TestChildChunkAddApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = ChildChunkAddApi() method = unwrap(api.post) @@ -852,7 +853,7 @@ class TestChildChunkAddApi: assert status == 200 assert response["data"]["id"] == "cc-1" - def test_post_child_chunk_indexing_error(self, app): + def test_post_child_chunk_indexing_error(self, app: Flask): api = ChildChunkAddApi() method = unwrap(api.post) @@ -897,7 +898,7 @@ class TestChildChunkAddApi: class TestChildChunkUpdateApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = ChildChunkUpdateApi() method = unwrap(api.delete) @@ -941,7 +942,7 @@ class TestChildChunkUpdateApi: assert status == 204 assert response["result"] == "success" - def test_delete_child_chunk_index_error(self, app): + def test_delete_child_chunk_index_error(self, app: Flask): api = ChildChunkUpdateApi() method = unwrap(api.delete) @@ -984,7 +985,7 @@ class TestChildChunkUpdateApi: class TestSegmentListAdvancedCases: - def test_segment_list_with_keyword_filter(self, app): + def test_segment_list_with_keyword_filter(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1035,7 +1036,7 @@ class TestSegmentListAdvancedCases: assert status == 200 assert response["total"] == 1 - def test_segment_list_permission_denied(self, app): + def test_segment_list_permission_denied(self, app: Flask): """Test segment list with permission denied""" api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1058,7 +1059,7 @@ class TestSegmentListAdvancedCases: with pytest.raises(Forbidden): method(api, "ds-1", "doc-1") - def test_segment_list_dataset_not_found(self, app): + def test_segment_list_dataset_not_found(self, app: Flask): """Test segment list with dataset not found""" api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1079,7 +1080,7 @@ class TestSegmentListAdvancedCases: class TestSegmentOperationCases: - def test_segment_add_with_provider_token_error(self, app): + def test_segment_add_with_provider_token_error(self, app: Flask): """Test segment add with provider token not initialized""" api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -1117,7 +1118,7 @@ class TestSegmentOperationCases: with pytest.raises(ProviderTokenNotInitError): method(api, "ds-1", "doc-1") - def test_batch_import_with_document_not_found(self, app): + def test_batch_import_with_document_not_found(self, app: Flask): """Test batch import with document not found""" api = 
DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1146,7 +1147,7 @@ class TestSegmentOperationCases: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_batch_import_with_invalid_file(self, app): + def test_batch_import_with_invalid_file(self, app: Flask): """Test batch import with invalid file type""" api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1181,7 +1182,7 @@ class TestSegmentOperationCases: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_batch_import_with_async_task_failure(self, app): + def test_batch_import_with_async_task_failure(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1226,7 +1227,7 @@ class TestSegmentOperationCases: assert status == 500 assert "error" in response - def test_batch_import_get_job_not_found(self, app): + def test_batch_import_get_job_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 514bbbe040..7254bf7670 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -57,7 +57,7 @@ def mock_auth(monkeypatch, current_user): class TestExternalApiTemplateListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ExternalApiTemplateListApi() method = unwrap(api.get) @@ -78,7 +78,7 @@ class TestExternalApiTemplateListApi: assert resp["total"] == 1 assert resp["data"][0]["id"] == "1" - def test_post_forbidden(self, app, current_user): + def test_post_forbidden(self, app: Flask, current_user): current_user.is_dataset_editor = False api = ExternalApiTemplateListApi() method = unwrap(api.post) @@ -93,7 +93,7 @@ class TestExternalApiTemplateListApi: with pytest.raises(Forbidden): method(api) - def test_post_duplicate_name(self, app): + def test_post_duplicate_name(self, app: Flask): api = ExternalApiTemplateListApi() method = unwrap(api.post) @@ -114,7 +114,7 @@ class TestExternalApiTemplateListApi: class TestExternalApiTemplateApi: - def test_get_not_found(self, app): + def test_get_not_found(self, app: Flask): api = ExternalApiTemplateApi() method = unwrap(api.get) @@ -129,7 +129,7 @@ class TestExternalApiTemplateApi: with pytest.raises(NotFound): method(api, "api-id") - def test_delete_forbidden(self, app, current_user): + def test_delete_forbidden(self, app: Flask, current_user): current_user.has_edit_permission = False current_user.is_dataset_operator = False @@ -142,7 +142,7 @@ class TestExternalApiTemplateApi: class TestExternalApiUseCheckApi: - def test_get_scopes_usage_check_to_current_tenant(self, app): + def test_get_scopes_usage_check_to_current_tenant(self, app: Flask): api = ExternalApiUseCheckApi() method = unwrap(api.get) @@ -162,7 +162,7 @@ class TestExternalApiUseCheckApi: class TestExternalDatasetCreateApi: - def test_create_success(self, app): + def test_create_success(self, app: Flask): api = ExternalDatasetCreateApi() method = unwrap(api.post) @@ -206,7 +206,7 @@ class TestExternalDatasetCreateApi: assert status == 201 - def test_create_forbidden(self, app, current_user): + def test_create_forbidden(self, app: Flask, current_user): current_user.is_dataset_editor = False api = ExternalDatasetCreateApi() method = unwrap(api.post) @@ -226,7 +226,7 @@ class TestExternalDatasetCreateApi: class 
TestExternalKnowledgeHitTestingApi: - def test_hit_testing_dataset_not_found(self, app): + def test_hit_testing_dataset_not_found(self, app: Flask): api = ExternalKnowledgeHitTestingApi() method = unwrap(api.post) @@ -241,7 +241,7 @@ class TestExternalKnowledgeHitTestingApi: with pytest.raises(NotFound): method(api, "dataset-id") - def test_hit_testing_success(self, app): + def test_hit_testing_success(self, app: Flask): api = ExternalKnowledgeHitTestingApi() method = unwrap(api.post) @@ -266,7 +266,7 @@ class TestExternalKnowledgeHitTestingApi: class TestBedrockRetrievalApi: - def test_bedrock_retrieval(self, app): + def test_bedrock_retrieval(self, app: Flask): api = BedrockRetrievalApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index 726c0a5cf3..09ed2aaf69 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -99,6 +99,57 @@ class TestHitTestingApi: assert "records" in result assert result["records"] == [] + def test_hit_testing_success_with_optional_record_fields(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + } + records = [ + { + "segment": None, + "child_chunks": [], + "score": None, + "tsne_position": None, + "files": [], + "summary": None, + } + ] + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": payload["query"], "records": records}, + ), + ): + result = method(api, dataset_id) + + assert result["query"] == payload["query"] + assert result["records"] == records + def test_hit_testing_dataset_not_found(self, app, dataset_id): api = HitTestingApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index 710c9be684..d29b34beb2 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -21,6 +20,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService @@ -134,6 +134,42 @@ class TestPerformHitTesting: assert result["query"] == "hello" assert result["records"] == [] + def test_success_normalizes_legacy_query_and_nullable_list_fields(self, dataset): + response = { + "query": {"content": "hello"}, + "records": [ + { + "segment": {"id": "segment-1", "keywords": None}, + "child_chunks": None, + "files": 
None, + "score": 0.8, + } + ], + } + + with ( + patch.object( + HitTestingService, + "retrieve", + return_value=response, + ), + patch( + "controllers.console.datasets.hit_testing_base.marshal", + return_value=response["records"], + ), + ): + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + assert result["query"] == "hello" + assert result["records"] == [ + { + "segment": {"id": "segment-1", "keywords": []}, + "child_chunks": [], + "files": [], + "score": 0.8, + } + ] + def test_index_not_initialized(self, dataset): with patch.object( HitTestingService, diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py index de834c2d4d..0105aacd65 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -269,7 +269,7 @@ class TestDatasetMetadataApi: class TestDatasetMetadataBuiltInFieldApi: - def test_get_built_in_fields(self, app): + def test_get_built_in_fields(self, app: Flask): api = DatasetMetadataBuiltInFieldApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index 66c9ba48c5..b4b57022e2 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -2,7 +2,6 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import controllers.console.explore.audio as audio_module @@ -20,6 +19,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py index c8f674f515..d1cb6b6a03 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_banner.py +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -1,6 +1,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +from flask import Flask + import controllers.console.explore.banner as banner_module from models.enums import BannerStatus @@ -12,7 +14,7 @@ def unwrap(func): class TestBannerApi: - def test_get_banners_with_requested_language(self, app): + def test_get_banners_with_requested_language(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) @@ -41,7 +43,7 @@ class TestBannerApi: } ] - def test_get_banners_fallback_to_en_us(self, app): + def test_get_banners_fallback_to_en_us(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) @@ -76,7 +78,7 @@ class TestBannerApi: } ] - def test_get_banners_default_language_en_us(self, app): + def test_get_banners_default_language_en_us(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 2e4ca4f2a4..3d41489435 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,7 +1,7 @@ 
from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError +from flask import Flask from werkzeug.exceptions import InternalServerError, NotFound import controllers.console.explore.message as module @@ -22,6 +22,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, @@ -54,7 +55,7 @@ def make_message(): class TestMessageListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -96,7 +97,7 @@ class TestMessageListApi: with pytest.raises(NotChatAppError): method(installed_app) - def test_conversation_not_exists(self, app): + def test_conversation_not_exists(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -118,7 +119,7 @@ class TestMessageListApi: with pytest.raises(NotFound): method(installed_app) - def test_first_message_not_exists(self, app): + def test_first_message_not_exists(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -142,7 +143,7 @@ class TestMessageListApi: class TestMessageFeedbackApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = module.MessageFeedbackApi() method = unwrap(api.post) @@ -161,7 +162,7 @@ class TestMessageFeedbackApi: assert result["result"] == "success" - def test_message_not_exists(self, app): + def test_message_not_exists(self, app: Flask): api = module.MessageFeedbackApi() method = unwrap(api.post) @@ -182,7 +183,7 @@ class TestMessageFeedbackApi: class TestMessageMoreLikeThisApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -221,7 +222,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(NotCompletionAppError): method(installed_app, "mid") - def test_more_like_this_disabled(self, app): + def test_more_like_this_disabled(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -243,7 +244,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(AppMoreLikeThisDisabledError): method(installed_app, "mid") - def test_message_not_exists_more_like_this(self, app): + def test_message_not_exists_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -265,7 +266,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(NotFound): method(installed_app, "mid") - def test_provider_not_init_more_like_this(self, app): + def test_provider_not_init_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -287,7 +288,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(ProviderNotInitializeError): method(installed_app, "mid") - def test_quota_exceeded_more_like_this(self, app): + def test_quota_exceeded_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -309,7 +310,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(ProviderQuotaExceededError): method(installed_app, "mid") - def test_model_not_support_more_like_this(self, app): + def test_model_not_support_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -331,7 +332,7 @@ class TestMessageMoreLikeThisApi: with 
pytest.raises(ProviderModelCurrentlyNotSupportError): method(installed_app, "mid") - def test_invoke_error_more_like_this(self, app): + def test_invoke_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -353,7 +354,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(CompletionRequestError): method(installed_app, "mid") - def test_unexpected_error_more_like_this(self, app): + def test_unexpected_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 02c7507ea7..557fded37e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,6 +1,9 @@ from unittest.mock import MagicMock, patch +from flask import Flask + import controllers.console.explore.recommended_app as module +from models.model import AppMode, IconType def unwrap(func): @@ -10,7 +13,7 @@ def unwrap(func): class TestRecommendedAppListApi: - def test_get_with_language_param(self, app): + def test_get_with_language_param(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -30,7 +33,7 @@ class TestRecommendedAppListApi: service_mock.assert_called_once_with("en-US") assert result == result_data - def test_get_fallback_to_user_language(self, app): + def test_get_fallback_to_user_language(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -50,7 +53,7 @@ class TestRecommendedAppListApi: service_mock.assert_called_once_with("fr-FR") assert result == result_data - def test_get_fallback_to_default_language(self, app): + def test_get_fallback_to_default_language(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -72,7 +75,7 @@ class TestRecommendedAppListApi: class TestRecommendedAppApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.RecommendedAppApi() method = unwrap(api.get) @@ -90,3 +93,48 @@ class TestRecommendedAppApi: service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111") assert result == result_data + + +class TestRecommendedAppResponseModels: + def test_recommended_app_info_response_computes_icon_url(self): + with patch.object(module, "build_icon_url", return_value="https://signed/icon.png"): + payload = module.RecommendedAppInfoResponse.model_validate( + { + "id": "app-1", + "name": "App", + "mode": AppMode.CHAT, + "icon": "icon.png", + "icon_type": IconType.IMAGE, + "icon_background": "#fff", + } + ).model_dump(mode="json") + + assert payload["icon_url"] == "https://signed/icon.png" + + def test_recommended_app_list_response_serialization(self): + response = module.RecommendedAppListResponse.model_validate( + { + "recommended_apps": [ + { + "app": { + "id": "app-1", + "name": "App", + "mode": "chat", + "icon": "icon.png", + "icon_type": "emoji", + "icon_background": "#fff", + }, + "app_id": "app-1", + "description": "desc", + "category": "cat", + "position": 1, + "is_listed": True, + "can_trial": False, + } + ], + "categories": ["cat"], + } + ).model_dump(mode="json") + + assert response["recommended_apps"][0]["app_id"] == "app-1" + assert response["categories"] == ["cat"] diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py 
b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py index bb7cdd55c4..71241890e9 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, PropertyMock, patch from uuid import uuid4 import pytest +from flask import Flask from werkzeug.exceptions import NotFound import controllers.console.explore.saved_message as module @@ -42,7 +43,7 @@ def payload_patch(): class TestSavedMessageListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.SavedMessageListApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 04beb31389..14f00e6295 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.model_runtime.errors.invoke import InvokeError +from flask import Flask from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -26,6 +26,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode @@ -88,13 +89,13 @@ def valid_parameters(): class TestTrialAppWorkflowRunApi: - def test_not_workflow_app(self, app): + def test_not_workflow_app(self, app: Flask): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(MagicMock(mode=AppMode.CHAT)) + method(api, MagicMock(mode=AppMode.CHAT)) def test_success(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -106,7 +107,7 @@ class TestTrialAppWorkflowRunApi: patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(trial_app_workflow) + result = method(api, trial_app_workflow) assert result is not None @@ -124,7 +125,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_quota_exceeded(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -140,7 +141,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_model_not_support(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -156,7 +157,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_invoke_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -172,7 +173,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(CompletionRequestError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_rate_limit_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -188,7 +189,7 @@ class 
TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InvokeRateLimitHttpError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_value_error(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -204,7 +205,7 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ValueError): - method(trial_app_workflow) + method(api, trial_app_workflow) def test_workflow_generic_exception(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowRunApi() @@ -220,11 +221,11 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InternalServerError): - method(trial_app_workflow) + method(api, trial_app_workflow) class TestTrialChatApi: - def test_not_chat_app(self, app): + def test_not_chat_app(self, app: Flask): api = module.TrialChatApi() method = unwrap(api.post) @@ -408,7 +409,7 @@ class TestTrialChatApi: class TestTrialCompletionApi: - def test_not_completion_app(self, app): + def test_not_completion_app(self, app: Flask): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -560,13 +561,13 @@ class TestTrialCompletionApi: class TestTrialMessageSuggestedQuestionApi: - def test_not_chat_app(self, app): + def test_not_chat_app(self, app: Flask): api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) with app.test_request_context("/"): with pytest.raises(NotChatAppError): - method(api, MagicMock(mode="completion"), str(uuid4())) + method(MagicMock(mode="completion"), str(uuid4())) def test_success(self, app, trial_app_chat, account): api = module.TrialMessageSuggestedQuestionApi() @@ -581,7 +582,7 @@ class TestTrialMessageSuggestedQuestionApi: return_value=["q1", "q2"], ), ): - result = method(api, trial_app_chat, str(uuid4())) + result = method(trial_app_chat, str(uuid4())) assert result == {"data": ["q1", "q2"]} @@ -599,7 +600,7 @@ class TestTrialMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - method(api, trial_app_chat, str(uuid4())) + method(trial_app_chat, str(uuid4())) class TestTrialAppParameterApi: @@ -931,7 +932,7 @@ class TestTrialAppWorkflowTaskStopApi: with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(trial_app_chat, str(uuid4())) + method(api, trial_app_chat, str(uuid4())) def test_success(self, app, trial_app_workflow, account): api = module.TrialAppWorkflowTaskStopApi() @@ -944,7 +945,7 @@ class TestTrialAppWorkflowTaskStopApi: patch.object(module.AppQueueManager, "set_stop_flag_no_user_check") as mock_set_flag, patch.object(module.GraphEngineManager, "send_stop_command") as mock_send_cmd, ): - result = method(trial_app_workflow, task_id) + result = method(api, trial_app_workflow, task_id) assert result == {"result": "success"} mock_set_flag.assert_called_once_with(task_id) @@ -952,7 +953,7 @@ class TestTrialAppWorkflowTaskStopApi: class TestTrialSitApi: - def test_no_site(self, app): + def test_no_site(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) app_model = MagicMock() @@ -963,7 +964,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_archived_tenant(self, app): + def test_archived_tenant(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) @@ -978,7 +979,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_success(self, app): + def test_success(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py 
b/api/tests/unit_tests/controllers/console/tag/test_tags.py index e89b89c8b1..8b47da25fb 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -1,13 +1,15 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, PropertyMock, patch import pytest from flask import Flask from werkzeug.exceptions import Forbidden +import controllers.console.tag.tags as module from controllers.console import console_ns from controllers.console.tag.tags import ( - TagBindingCreateApi, - TagBindingDeleteApi, + TagBindingCollectionApi, + TagBindingRemoveApi, TagListApi, TagUpdateDeleteApi, ) @@ -71,7 +73,7 @@ def payload_patch(): class TestTagListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = TagListApi() method = unwrap(api.get) @@ -83,13 +85,20 @@ class TestTagListApi: ), patch( "controllers.console.tag.tags.TagService.get_tags", - return_value=[{"id": "1", "name": "tag"}], + return_value=[ + SimpleNamespace( + id="1", + name="tag", + type=TagType.KNOWLEDGE, + binding_count=1, + ) + ], ), ): result, status = method(api) assert status == 200 - assert isinstance(result, list) + assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}] def test_post_success(self, app, admin_user, tag, payload_patch): api = TagListApi() @@ -113,8 +122,9 @@ class TestTagListApi: assert status == 200 assert result["name"] == "test-tag" + assert result["binding_count"] == "0" - def test_post_forbidden(self, app, readonly_user, payload_patch): + def test_post_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagListApi() method = unwrap(api.post) @@ -158,9 +168,9 @@ class TestTagUpdateDeleteApi: result, status = method(api, "tag-1") assert status == 200 - assert result["binding_count"] == 3 + assert result["binding_count"] == "3" - def test_patch_forbidden(self, app, readonly_user, payload_patch): + def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagUpdateDeleteApi() method = unwrap(api.patch) @@ -195,9 +205,9 @@ class TestTagUpdateDeleteApi: assert status == 204 -class TestTagBindingCreateApi: +class TestTagBindingCollectionApi: def test_create_success(self, app, admin_user, payload_patch): - api = TagBindingCreateApi() + api = TagBindingCollectionApi() method = unwrap(api.post) payload = { @@ -221,8 +231,8 @@ class TestTagBindingCreateApi: assert status == 200 assert result["result"] == "success" - def test_create_forbidden(self, app, readonly_user, payload_patch): - api = TagBindingCreateApi() + def test_create_forbidden(self, app: Flask, readonly_user, payload_patch): + api = TagBindingCollectionApi() method = unwrap(api.post) with app.test_request_context("/", json={}): @@ -237,13 +247,13 @@ class TestTagBindingCreateApi: method(api) -class TestTagBindingDeleteApi: +class TestTagBindingRemoveApi: def test_remove_success(self, app, admin_user, payload_patch): - api = TagBindingDeleteApi() + api = TagBindingRemoveApi() method = unwrap(api.post) payload = { - "tag_id": "tag-1", + "tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "knowledge", } @@ -260,11 +270,13 @@ class TestTagBindingDeleteApi: result, status = method(api) delete_mock.assert_called_once() + delete_payload = delete_mock.call_args.args[0] + assert delete_payload.tag_ids == ["tag-1", "tag-2"] assert status == 200 assert result["result"] == "success" - def test_remove_forbidden(self, app, readonly_user, payload_patch): - api = 
TagBindingDeleteApi() + def test_remove_forbidden(self, app: Flask, readonly_user, payload_patch): + api = TagBindingRemoveApi() method = unwrap(api.post) with app.test_request_context("/", json={}): @@ -277,3 +289,43 @@ class TestTagBindingDeleteApi: ): with pytest.raises(Forbidden): method(api) + + +class TestTagResponseModel: + def test_tag_response_normalizes_enum_type(self): + payload = module.TagResponse.model_validate( + {"id": "tag-1", "name": "tag", "type": TagType.KNOWLEDGE, "binding_count": 1} + ).model_dump(mode="json") + + assert payload["type"] == "knowledge" + assert payload["binding_count"] == "1" + + +class TestTagBindingRouteMetadata: + def test_write_routes_are_not_deprecated(self): + assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True + assert TagBindingRemoveApi.post.__apidoc__.get("deprecated") is not True + + def test_write_routes_have_stable_operation_ids(self): + assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding" + assert TagBindingRemoveApi.post.__apidoc__["id"] == "remove_tag_bindings" + + def test_write_routes_are_registered(self): + route_map = { + resource.__name__: urls + for resource, urls, _route_doc, _kwargs in console_ns.resources + if resource.__name__ + in { + "TagBindingCollectionApi", + "TagBindingRemoveApi", + } + } + + assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",) + assert route_map["TagBindingRemoveApi"] == ("/tag-bindings/remove",) + + def test_legacy_write_routes_are_not_registered(self): + urls = {url for _resource, resource_urls, _route_doc, _kwargs in console_ns.resources for url in resource_urls} + + assert "/tag-bindings/create" not in urls + assert "/tag-bindings/" not in urls diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py index 5df9daa7f8..eebc6f9d60 100644 --- a/api/tests/unit_tests/controllers/console/test_files.py +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -82,7 +82,7 @@ def mock_file_service(mock_db): class TestFileApiGet: - def test_get_upload_config(self, app): + def test_get_upload_config(self, app: Flask): api = FileApi() get_method = unwrap(api.get) @@ -290,7 +290,7 @@ class TestFilePreviewApi: class TestFileSupportTypeApi: - def test_get_supported_types(self, app): + def test_get_supported_types(self, app: Flask): api = FileSupportTypeApi() get_method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/test_human_input_form.py b/api/tests/unit_tests/controllers/console/test_human_input_form.py index 232b6eee79..ebf803cac9 100644 --- a/api/tests/unit_tests/controllers/console/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/console/test_human_input_form.py @@ -122,6 +122,35 @@ def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) handler(api, form_token="token") +def test_post_form_rejects_webapp_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.STANDALONE_WEB_APP) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_by_token(self, _token): + return form + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), "tenant-1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", 
+
+    api = ConsoleHumanInputFormApi()
+    handler = _unwrap(api.post)
+
+    with app.test_request_context(
+        "/console/api/form/human_input/token",
+        method="POST",
+        json={"inputs": {"content": "ok"}, "action": "approve"},
+    ):
+        with pytest.raises(NotFoundError):
+            handler(api, form_token="token")
+
+
 def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
     submit_mock = Mock()
     form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py
index 7f9fe9cbf9..4b4f968c8f 100644
--- a/api/tests/unit_tests/controllers/console/test_workspace_account.py
+++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py
@@ -11,7 +11,7 @@ from controllers.console.workspace.account import (
     ChangeEmailSendEmailApi,
     CheckEmailUnique,
 )
-from models import Account
+from models import Account, AccountStatus
 from services.account_service import AccountService

@@ -24,16 +24,12 @@ def app():
     return app


-def _mock_wraps_db(mock_db):
-    mock_db.session.query.return_value.first.return_value = MagicMock()
-
-
 def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
     tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
     account = Account(name=account_id, email=email)
     account.email = email
     account.id = account_id
-    account.status = "active"
+    account.status = AccountStatus.ACTIVE
     account._current_tenant = tenant_obj
     return account

@@ -62,13 +58,15 @@ class TestChangeEmailSend:
         mock_get_change_data,
         mock_current_account,
         mock_db,
-        app,
+        app: Flask,
     ):
-        _mock_wraps_db(mock_db)
         mock_features.return_value = SimpleNamespace(enable_change_email=True)
         mock_account = _build_account("current@example.com", "acc1")
         mock_current_account.return_value = (mock_account, None)
-        mock_get_change_data.return_value = {"email": "current@example.com"}
+        mock_get_change_data.return_value = {
+            "email": "current@example.com",
+            AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+        }
         mock_send_email.return_value = "token-abc"

         with app.test_request_context(
@@ -85,12 +83,54 @@ class TestChangeEmailSend:
             email="new@example.com",
             old_email="current@example.com",
             language="en-US",
-            phase="new_email",
+            phase=AccountService.CHANGE_EMAIL_PHASE_NEW,
         )
         mock_extract_ip.assert_called_once()
         mock_is_ip_limit.assert_called_once_with("127.0.0.1")
         mock_csrf.assert_called_once()

+    @patch("controllers.console.wraps.db")
+    @patch("controllers.console.workspace.account.current_account_with_tenant")
+    @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+    @patch("controllers.console.workspace.account.AccountService.send_change_email_email")
+    @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
+    @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
+    @patch("libs.login.check_csrf_token", return_value=None)
+    @patch("controllers.console.wraps.FeatureService.get_system_features")
+    def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
+        self,
+        mock_features,
+        mock_csrf,
+        mock_extract_ip,
+        mock_is_ip_limit,
+        mock_send_email,
+        mock_get_change_data,
+        mock_current_account,
+        mock_db,
+        app: Flask,
+    ):
+        """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
step.""" + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_account = _build_account("current@example.com", "acc1") + mock_current_account.return_value = (mock_account, None) + mock_get_change_data.return_value = { + "email": "current@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } + + with app.test_request_context( + "/account/change-email", + method="POST", + json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailSendEmailApi().post() + + mock_send_email.assert_not_called() + class TestChangeEmailValidity: @patch("controllers.console.wraps.db") @@ -115,14 +155,18 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("user@example.com", "acc2") mock_current_account.return_value = (mock_account, None) mock_is_rate_limit.return_value = False - mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"} + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD, + } mock_generate_token.return_value = (None, "new-token") with app.test_request_context( @@ -138,11 +182,166 @@ class TestChangeEmailValidity: mock_add_rate.assert_not_called() mock_revoke_token.assert_called_once_with("token-123") mock_generate_token.assert_called_once_with( - "user@example.com", code="1234", old_email="old@example.com", additional_data={} + "user@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, + }, ) mock_reset_rate.assert_called_once_with("user@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_upgrade_new_phase_token_to_new_verified( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app: Flask, + ): + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": 
"new@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + } + mock_generate_token.return_value = (None, "new-verified-token") + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "new@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + response = ChangeEmailCheckApi().post() + + assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"} + mock_generate_token.assert_called_once_with( + "new@example.com", + code="1234", + old_email="old@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + }, + ) + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_phase_is_unknown( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app: Flask, + ): + """A token whose phase marker is a string but not a known transition must be rejected.""" + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: "something_else", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + 
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_validity_when_token_has_no_phase( + self, + mock_features, + mock_csrf, + mock_is_rate_limit, + mock_get_data, + mock_add_rate, + mock_revoke_token, + mock_generate_token, + mock_reset_rate, + mock_current_account, + mock_db, + app: Flask, + ): + """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" + from controllers.console.auth.error import InvalidTokenError + + mock_features.return_value = SimpleNamespace(enable_change_email=True) + mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + mock_is_rate_limit.return_value = False + mock_get_data.return_value = { + "email": "user@example.com", + "code": "1234", + "old_email": "old@example.com", + } + + with app.test_request_context( + "/account/change-email/validity", + method="POST", + json={"email": "user@example.com", "code": "1234", "token": "token-123"}, + ): + _set_logged_in_user(_build_account("tester@example.com", "tester")) + with pytest.raises(InvalidTokenError): + ChangeEmailCheckApi().post() + + mock_revoke_token.assert_not_called() + mock_generate_token.assert_not_called() + class TestChangeEmailReset: @patch("controllers.console.wraps.db") @@ -167,15 +366,18 @@ class TestChangeEmailReset: mock_send_notify, mock_current_account, mock_db, - app, + app: Flask, ): - _mock_wraps_db(mock_db) mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True - mock_get_data.return_value = {"old_email": "OLD@example.com"} + mock_get_data.return_value = { + "email": "new@example.com", + "old_email": "OLD@example.com", + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED, + } mock_account_after_update = _build_account("new@example.com", "acc3-updated") mock_update_account.return_value = mock_account_after_update @@ -194,12 +396,158 @@ class TestChangeEmailReset: mock_send_notify.assert_called_once_with(email="new@example.com") mock_csrf.assert_called_once() + @patch("controllers.console.wraps.db") + @patch("controllers.console.workspace.account.current_account_with_tenant") + @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") + @patch("controllers.console.workspace.account.AccountService.update_account_email") + @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") + @patch("controllers.console.workspace.account.AccountService.get_change_email_data") + @patch("controllers.console.workspace.account.AccountService.check_email_unique") + @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") + @patch("libs.login.check_csrf_token", return_value=None) + @patch("controllers.console.wraps.FeatureService.get_system_features") + def test_should_reject_reset_when_token_phase_is_not_new_verified( + self, + mock_features, + mock_csrf, + mock_is_freeze, + mock_check_unique, + mock_get_data, + mock_revoke_token, + mock_update_account, + mock_send_notify, + mock_current_account, + mock_db, + app: Flask, + ): + """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" + 
+        from controllers.console.auth.error import InvalidTokenError
+
+        mock_features.return_value = SimpleNamespace(enable_change_email=True)
+        current_user = _build_account("old@example.com", "acc3")
+        mock_current_account.return_value = (current_user, None)
+        mock_is_freeze.return_value = False
+        mock_check_unique.return_value = True
+        # Simulate a token straight out of step #1 (phase=old_email) — exactly
+        # the replay used in the advisory PoC.
+        mock_get_data.return_value = {
+            "email": "old@example.com",
+            "old_email": "old@example.com",
+            AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+        }
+
+        with app.test_request_context(
+            "/account/change-email/reset",
+            method="POST",
+            json={"new_email": "attacker@example.com", "token": "token-from-step1"},
+        ):
+            _set_logged_in_user(_build_account("tester@example.com", "tester"))
+            with pytest.raises(InvalidTokenError):
+                ChangeEmailResetApi().post()
+
+        mock_revoke_token.assert_not_called()
+        mock_update_account.assert_not_called()
+        mock_send_notify.assert_not_called()
+
+    @patch("controllers.console.wraps.db")
+    @patch("controllers.console.workspace.account.current_account_with_tenant")
+    @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
+    @patch("controllers.console.workspace.account.AccountService.update_account_email")
+    @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
+    @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+    @patch("controllers.console.workspace.account.AccountService.check_email_unique")
+    @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
+    @patch("libs.login.check_csrf_token", return_value=None)
+    @patch("controllers.console.wraps.FeatureService.get_system_features")
+    def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
+        self,
+        mock_features,
+        mock_csrf,
+        mock_is_freeze,
+        mock_check_unique,
+        mock_get_data,
+        mock_revoke_token,
+        mock_update_account,
+        mock_send_notify,
+        mock_current_account,
+        mock_db,
+        app: Flask,
+    ):
+        """A verified token for address A must not be replayed to change to address B."""
+        from controllers.console.auth.error import InvalidTokenError
+
+        mock_features.return_value = SimpleNamespace(enable_change_email=True)
+        current_user = _build_account("old@example.com", "acc3")
+        mock_current_account.return_value = (current_user, None)
+        mock_is_freeze.return_value = False
+        mock_check_unique.return_value = True
+        mock_get_data.return_value = {
+            "email": "verified@example.com",
+            "old_email": "old@example.com",
+            AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+        }
+
+        with app.test_request_context(
+            "/account/change-email/reset",
+            method="POST",
+            json={"new_email": "attacker@example.com", "token": "token-verified"},
+        ):
+            _set_logged_in_user(_build_account("tester@example.com", "tester"))
+            with pytest.raises(InvalidTokenError):
+                ChangeEmailResetApi().post()
+
+        mock_revoke_token.assert_not_called()
+        mock_update_account.assert_not_called()
+        mock_send_notify.assert_not_called()
+
+
+class TestAccountServiceSendChangeEmailEmail:
+    """Service-level coverage for the phase-bound changes in `send_change_email_email`."""
+
+    def test_should_raise_value_error_for_invalid_phase(self):
+        with pytest.raises(ValueError, match="phase must be one of"):
+            AccountService.send_change_email_email(
+                email="user@example.com",
old_email="user@example.com", + phase="old_email_verified", + ) + + @patch("services.account_service.send_change_mail_task") + @patch("services.account_service.AccountService.change_email_rate_limiter") + @patch("services.account_service.AccountService.generate_change_email_token") + def test_should_stamp_phase_into_generated_token( + self, + mock_generate_token, + mock_rate_limiter, + mock_mail_task, + ): + mock_rate_limiter.is_rate_limited.return_value = False + mock_generate_token.return_value = ("123456", "the-token") + + returned = AccountService.send_change_email_email( + email="user@example.com", + old_email="user@example.com", + language="en-US", + phase=AccountService.CHANGE_EMAIL_PHASE_NEW, + ) + + assert returned == "the-token" + mock_generate_token.assert_called_once_with( + "user@example.com", + None, + old_email="user@example.com", + additional_data={ + AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW, + }, + ) + mock_mail_task.delay.assert_called_once() + mock_rate_limiter.increment_rate_limit.assert_called_once_with("user@example.com") + class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") - def test_should_normalize_feedback_email(self, mock_update, mock_db, app): - _mock_wraps_db(mock_db) + def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask): with app.test_request_context( "/account/delete/feedback", method="POST", @@ -215,8 +563,7 @@ class TestCheckEmailUnique: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): - _mock_wraps_db(mock_db) + def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask): mock_is_freeze.return_value = False mock_check_unique.return_value = True @@ -233,15 +580,20 @@ class TestCheckEmailUnique: def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): - session = MagicMock() + mock_session = MagicMock() first = MagicMock() first.scalar_one_or_none.return_value = None second = MagicMock() expected_account = MagicMock() second.scalar_one_or_none.return_value = expected_account - session.execute.side_effect = [first, second] + mock_session.execute.side_effect = [first, second] - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=session) + mock_factory = MagicMock() + mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) + + with patch("services.account_service.session_factory", mock_factory): + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") assert result is expected_account - assert session.execute.call_count == 2 + assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 239fec8430..412d6a6c52 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -1,5 +1,5 @@ from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from 

 import pytest
 from flask import Flask, g

@@ -16,10 +16,6 @@ def app():
     return flask_app


-def _mock_wraps_db(mock_db):
-    mock_db.session.query.return_value.first.return_value = MagicMock()
-
-
 def _build_feature_flags():
     placeholder_quota = SimpleNamespace(limit=0, size=0)
     workspace_members = SimpleNamespace(is_available=lambda count: True)
@@ -47,9 +43,8 @@ class TestMemberInviteEmailApi:
         mock_current_account,
         mock_invite_member,
         mock_get_features,
-        app,
+        app: Flask,
     ):
-        _mock_wraps_db(mock_db)
         mock_get_features.return_value = _build_feature_flags()
         mock_invite_member.return_value = "token-abc"

diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py
index f6e096a97b..aa4973851a 100644
--- a/api/tests/unit_tests/controllers/console/test_wraps.py
+++ b/api/tests/unit_tests/controllers/console/test_wraps.py
@@ -310,7 +310,6 @@ class TestSystemSetup:
     def test_should_allow_when_setup_complete(self, mock_db):
         """Test that requests are allowed when setup is complete"""
         # Arrange
-        mock_db.session.query.return_value.first.return_value = MagicMock()  # Setup exists

         @setup_required
         def admin_view():
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py
index 42be02cdaf..064726da05 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py
@@ -1,6 +1,8 @@
 from unittest.mock import MagicMock, PropertyMock, patch

 import pytest
+from flask import Flask
+from werkzeug.exceptions import NotFound

 from controllers.console import console_ns
 from controllers.console.auth.error import (
@@ -29,6 +31,7 @@ from controllers.console.workspace.error import (
     CurrentPasswordIncorrectError,
     InvalidAccountDeletionCodeError,
 )
+from models.enums import CreatorUserRole
 from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError

@@ -39,7 +42,7 @@ def unwrap(func):


 class TestAccountInitApi:
-    def test_init_success(self, app):
+    def test_init_success(self, app: Flask):
         api = AccountInitApi()
         method = unwrap(api.post)

@@ -62,7 +65,7 @@ class TestAccountInitApi:

         assert resp["result"] == "success"

-    def test_init_already_initialized(self, app):
+    def test_init_already_initialized(self, app: Flask):
         api = AccountInitApi()
         method = unwrap(api.post)

@@ -77,7 +80,7 @@ class TestAccountInitApi:


 class TestAccountProfileApi:
-    def test_get_profile_success(self, app):
+    def test_get_profile_success(self, app: Flask):
         api = AccountProfileApi()
         method = unwrap(api.get)

@@ -135,8 +138,133 @@ class TestAccountUpdateApis:

         assert result["id"] == "u1"


+class TestAccountAvatarApiGet:
+    """GET /account/avatar must not sign arbitrary upload_file IDs (IDOR)."""
+
+    def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask):
+        api = AccountAvatarApi()
+        method = unwrap(api.get)
+
+        user = MagicMock()
+        user.id = "acc-owner"
+        tenant_id = "tenant-1"
+        file_id = "550e8400-e29b-41d4-a716-446655440000"
+
+        upload_file = MagicMock()
+        upload_file.id = file_id
+        upload_file.tenant_id = tenant_id
+        upload_file.created_by = user.id
+        upload_file.created_by_role = CreatorUserRole.ACCOUNT
+
+        with (
+            app.test_request_context(f"/account/avatar?avatar={file_id}"),
+            patch(
+                "controllers.console.workspace.account.current_account_with_tenant",
+                return_value=(user, tenant_id),
+            ),
patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file), + patch( + "controllers.console.workspace.account.file_helpers.get_signed_file_url", + return_value="https://signed/example", + ) as sign_mock, + ): + result = method(api) + + assert result == {"avatar_url": "https://signed/example"} + sign_mock.assert_called_once_with(upload_file_id=file_id) + + def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask): + api = AccountAvatarApi() + method = unwrap(api.get) + + user = MagicMock() + user.id = "acc-a" + tenant_id = "tenant-1" + file_id = "550e8400-e29b-41d4-a716-446655440001" + + upload_file = MagicMock() + upload_file.id = file_id + upload_file.tenant_id = tenant_id + upload_file.created_by = "acc-b" + upload_file.created_by_role = CreatorUserRole.ACCOUNT + + with ( + app.test_request_context(f"/account/avatar?avatar={file_id}"), + patch( + "controllers.console.workspace.account.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file), + patch( + "controllers.console.workspace.account.file_helpers.get_signed_file_url", + return_value="https://signed/leak", + ) as sign_mock, + ): + with pytest.raises(NotFound): + method(api) + + sign_mock.assert_not_called() + + def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask): + api = AccountAvatarApi() + method = unwrap(api.get) + + user = MagicMock() + user.id = "acc-owner" + tenant_id = "tenant-1" + file_id = "550e8400-e29b-41d4-a716-446655440002" + + upload_file = MagicMock() + upload_file.id = file_id + upload_file.tenant_id = "tenant-other" + upload_file.created_by = user.id + upload_file.created_by_role = CreatorUserRole.ACCOUNT + + with ( + app.test_request_context(f"/account/avatar?avatar={file_id}"), + patch( + "controllers.console.workspace.account.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file), + patch( + "controllers.console.workspace.account.file_helpers.get_signed_file_url", + return_value="https://signed/leak", + ) as sign_mock, + ): + with pytest.raises(NotFound): + method(api) + + sign_mock.assert_not_called() + + def test_get_avatar_https_pass_through_without_signing(self, app: Flask): + api = AccountAvatarApi() + method = unwrap(api.get) + + user = MagicMock() + user.id = "acc-owner" + tenant_id = "tenant-1" + external = "https://cdn.example/avatar.png" + + with ( + app.test_request_context(f"/account/avatar?avatar={external}"), + patch( + "controllers.console.workspace.account.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.account.file_helpers.get_signed_file_url", + return_value="https://signed/should-not-use", + ) as sign_mock, + ): + result = method(api) + + assert result == {"avatar_url": external} + sign_mock.assert_not_called() + + class TestAccountPasswordApi: - def test_password_success(self, app): + def test_password_success(self, app: Flask): api = AccountPasswordApi() method = unwrap(api.post) @@ -165,7 +293,7 @@ class TestAccountPasswordApi: assert result["id"] == "u1" - def test_password_wrong_current(self, app): + def test_password_wrong_current(self, app: Flask): api = AccountPasswordApi() method = unwrap(api.post) @@ -190,7 +318,7 @@ class TestAccountPasswordApi: class TestAccountIntegrateApi: - def test_get_integrates(self, 
+    def test_get_integrates(self, app: Flask):
         api = AccountIntegrateApi()
         method = unwrap(api.get)

@@ -209,7 +337,7 @@ class TestAccountIntegrateApi:


 class TestAccountDeleteApi:
-    def test_delete_verify_success(self, app):
+    def test_delete_verify_success(self, app: Flask):
         api = AccountDeleteVerifyApi()
         method = unwrap(api.get)

@@ -231,7 +359,7 @@ class TestAccountDeleteApi:

         assert result["result"] == "success"

-    def test_delete_invalid_code(self, app):
+    def test_delete_invalid_code(self, app: Flask):
         api = AccountDeleteApi()
         method = unwrap(api.post)

@@ -252,7 +380,7 @@ class TestAccountDeleteApi:


 class TestChangeEmailApis:
-    def test_check_email_code_invalid(self, app):
+    def test_check_email_code_invalid(self, app: Flask):
         api = ChangeEmailCheckApi()
         method = unwrap(api.post)

@@ -278,7 +406,7 @@ class TestChangeEmailApis:
             with pytest.raises(EmailCodeError):
                 method(api)

-    def test_reset_email_already_used(self, app):
+    def test_reset_email_already_used(self, app: Flask):
         api = ChangeEmailResetApi()
         method = unwrap(api.post)

@@ -300,7 +428,7 @@ class TestChangeEmailApis:


 class TestCheckEmailUniqueApi:
-    def test_email_unique_success(self, app):
+    def test_email_unique_success(self, app: Flask):
         api = CheckEmailUnique()
         method = unwrap(api.post)

@@ -321,7 +449,7 @@ class TestCheckEmailUniqueApi:

         assert result["result"] == "success"

-    def test_email_in_freeze(self, app):
+    def test_email_in_freeze(self, app: Flask):
         api = CheckEmailUnique()
         method = unwrap(api.post)

diff --git a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py
index b4e03f681d..eb0ca15d2e 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch

 import pytest
+from flask import Flask

 from controllers.console.error import AccountNotFound
 from controllers.console.workspace.agent_providers import (
@@ -16,7 +17,7 @@ def unwrap(func):


 class TestAgentProviderListApi:
-    def test_get_success(self, app):
+    def test_get_success(self, app: Flask):
         api = AgentProviderListApi()
         method = unwrap(api.get)

@@ -39,7 +40,7 @@ class TestAgentProviderListApi:

         assert result == providers

-    def test_get_empty_list(self, app):
+    def test_get_empty_list(self, app: Flask):
         api = AgentProviderListApi()
         method = unwrap(api.get)

@@ -61,7 +62,7 @@ class TestAgentProviderListApi:

         assert result == []

-    def test_get_account_not_found(self, app):
+    def test_get_account_not_found(self, app: Flask):
         api = AgentProviderListApi()
         method = unwrap(api.get)

@@ -77,7 +78,7 @@ class TestAgentProviderListApi:


 class TestAgentProviderApi:
-    def test_get_success(self, app):
+    def test_get_success(self, app: Flask):
         api = AgentProviderApi()
         method = unwrap(api.get)

@@ -101,7 +102,7 @@ class TestAgentProviderApi:

         assert result == provider_data

-    def test_get_provider_not_found(self, app):
+    def test_get_provider_not_found(self, app: Flask):
         api = AgentProviderApi()
         method = unwrap(api.get)

@@ -124,7 +125,7 @@ class TestAgentProviderApi:

         assert result is None

-    def test_get_account_not_found(self, app):
+    def test_get_account_not_found(self, app: Flask):
         api = AgentProviderApi()
         method = unwrap(api.get)

diff --git a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py
index 51f76af172..ed7b2d606f 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py
@@ -1,15 +1,19 @@
 from unittest.mock import MagicMock, patch

 import pytest
+from flask import Flask

+from controllers.console import console_ns
 from controllers.console.workspace.endpoint import (
-    EndpointCreateApi,
-    EndpointDeleteApi,
+    DeprecatedEndpointCreateApi,
+    DeprecatedEndpointDeleteApi,
+    DeprecatedEndpointUpdateApi,
+    EndpointCollectionApi,
     EndpointDisableApi,
     EndpointEnableApi,
+    EndpointItemApi,
     EndpointListApi,
     EndpointListForSinglePluginApi,
-    EndpointUpdateApi,
 )
 from core.plugin.impl.exc import PluginPermissionDeniedError
@@ -35,9 +39,9 @@ def patch_current_account(user_and_tenant):


 @pytest.mark.usefixtures("patch_current_account")
-class TestEndpointCreateApi:
-    def test_create_success(self, app):
-        api = EndpointCreateApi()
+class TestEndpointCollectionApi:
+    def test_create_success(self, app: Flask):
+        api = EndpointCollectionApi()
         method = unwrap(api.post)

         payload = {
@@ -54,8 +58,8 @@

         assert result["success"] is True

-    def test_create_permission_denied(self, app):
-        api = EndpointCreateApi()
+    def test_create_permission_denied(self, app: Flask):
+        api = EndpointCollectionApi()
         method = unwrap(api.post)

         payload = {
@@ -74,8 +78,8 @@
             with pytest.raises(ValueError):
                 method(api)

-    def test_create_validation_error(self, app):
-        api = EndpointCreateApi()
+    def test_create_validation_error(self, app: Flask):
+        api = EndpointCollectionApi()
         method = unwrap(api.post)

         payload = {
@@ -91,9 +95,30 @@
             method(api)


+@pytest.mark.usefixtures("patch_current_account")
+class TestDeprecatedEndpointCreateApi:
+    def test_create_success(self, app: Flask):
+        api = DeprecatedEndpointCreateApi()
+        method = unwrap(api.post)
+
+        payload = {
+            "plugin_unique_identifier": "plugin-1",
+            "name": "endpoint",
+            "settings": {"a": 1},
+        }
+
+        with (
+            app.test_request_context("/", json=payload),
+            patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
+        ):
+            result = method(api)
+
+        assert result["success"] is True
+
+
 @pytest.mark.usefixtures("patch_current_account")
 class TestEndpointListApi:
-    def test_list_success(self, app):
+    def test_list_success(self, app: Flask):
         api = EndpointListApi()
         method = unwrap(api.get)

@@ -106,7 +131,7 @@
         assert "endpoints" in result
         assert len(result["endpoints"]) == 1

-    def test_list_invalid_query(self, app):
+    def test_list_invalid_query(self, app: Flask):
         api = EndpointListApi()
         method = unwrap(api.get)

@@ -119,7 +144,7 @@

 @pytest.mark.usefixtures("patch_current_account")
 class TestEndpointListForSinglePluginApi:
-    def test_list_for_plugin_success(self, app):
+    def test_list_for_plugin_success(self, app: Flask):
         api = EndpointListForSinglePluginApi()
         method = unwrap(api.get)

@@ -134,7 +159,7 @@

         assert "endpoints" in result

-    def test_list_for_plugin_missing_param(self, app):
+    def test_list_for_plugin_missing_param(self, app: Flask):
         api = EndpointListForSinglePluginApi()
         method = unwrap(api.get)

@@ -146,9 +171,96 @@


 @pytest.mark.usefixtures("patch_current_account")
-class TestEndpointDeleteApi:
-    def test_delete_success(self, app):
-        api = EndpointDeleteApi()
+class TestEndpointItemApi:
+    def test_delete_success(self, app: Flask):
+        api = EndpointItemApi()
+        method = unwrap(api.delete)
+
+        with (
+            app.test_request_context("/", method="DELETE"),
+            patch(
+                "controllers.console.workspace.endpoint.EndpointService.delete_endpoint",
+                return_value=True,
+            ) as mock_delete,
+        ):
+            result = method(api, "e1")
+
+        assert result["success"] is True
+        mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1")
+
+    def test_delete_service_failure(self, app: Flask):
+        api = EndpointItemApi()
+        method = unwrap(api.delete)
+
+        with (
+            app.test_request_context("/", method="DELETE"),
+            patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
+        ):
+            result = method(api, "e1")
+
+        assert result["success"] is False
+
+    def test_update_success(self, app: Flask):
+        api = EndpointItemApi()
+        method = unwrap(api.patch)
+
+        payload = {
+            "name": "new-name",
+            "settings": {"x": 1},
+        }
+
+        with (
+            app.test_request_context("/", method="PATCH", json=payload),
+            patch(
+                "controllers.console.workspace.endpoint.EndpointService.update_endpoint",
+                return_value=True,
+            ) as mock_update,
+        ):
+            result = method(api, "e1")
+
+        assert result["success"] is True
+        mock_update.assert_called_once_with(
+            tenant_id="t1",
+            user_id="u1",
+            endpoint_id="e1",
+            name="new-name",
+            settings={"x": 1},
+        )
+
+    def test_update_validation_error(self, app: Flask):
+        api = EndpointItemApi()
+        method = unwrap(api.patch)
+
+        payload = {"settings": {}}
+
+        with (
+            app.test_request_context("/", method="PATCH", json=payload),
+        ):
+            with pytest.raises(ValueError):
+                method(api, "e1")
+
+    def test_update_service_failure(self, app: Flask):
+        api = EndpointItemApi()
+        method = unwrap(api.patch)
+
+        payload = {
+            "name": "n",
+            "settings": {},
+        }
+
+        with (
+            app.test_request_context("/", method="PATCH", json=payload),
+            patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
+        ):
+            result = method(api, "e1")
+
+        assert result["success"] is False
+
+
+@pytest.mark.usefixtures("patch_current_account")
+class TestDeprecatedEndpointDeleteApi:
+    def test_delete_success(self, app: Flask):
+        api = DeprecatedEndpointDeleteApi()
         method = unwrap(api.post)

         payload = {"endpoint_id": "e1"}
@@ -161,8 +273,8 @@

         assert result["success"] is True

-    def test_delete_invalid_payload(self, app):
-        api = EndpointDeleteApi()
+    def test_delete_invalid_payload(self, app: Flask):
+        api = DeprecatedEndpointDeleteApi()
         method = unwrap(api.post)

         with (
@@ -171,8 +283,8 @@
             with pytest.raises(ValueError):
                 method(api)

-    def test_delete_service_failure(self, app):
-        api = EndpointDeleteApi()
+    def test_delete_service_failure(self, app: Flask):
+        api = DeprecatedEndpointDeleteApi()
         method = unwrap(api.post)

         payload = {"endpoint_id": "e1"}
@@ -187,9 +299,9 @@


 @pytest.mark.usefixtures("patch_current_account")
-class TestEndpointUpdateApi:
-    def test_update_success(self, app):
-        api = EndpointUpdateApi()
+class TestDeprecatedEndpointUpdateApi:
+    def test_update_success(self, app: Flask):
+        api = DeprecatedEndpointUpdateApi()
         method = unwrap(api.post)

         payload = {
@@ -206,8 +318,8 @@

         assert result["success"] is True

-    def test_update_validation_error(self, app):
-        api = EndpointUpdateApi()
+    def test_update_validation_error(self, app: Flask):
+        api = DeprecatedEndpointUpdateApi()
         method = unwrap(api.post)

         payload = {"endpoint_id": "e1", "settings": {}}
@@ -218,8 +330,8 @@ class TestEndpointUpdateApi:
             with pytest.raises(ValueError):
                 method(api)

-    def test_update_service_failure(self, app):
+    def test_update_service_failure(self, app: Flask):
         api = DeprecatedEndpointUpdateApi()
         method = unwrap(api.post)

         payload = {
@@ -237,9 +349,39 @@ class TestEndpointUpdateApi:

         assert result["success"] is False


+class TestEndpointRouteMetadata:
+    def test_legacy_write_routes_are_marked_deprecated(self):
+        assert DeprecatedEndpointCreateApi.post.__apidoc__["deprecated"] is True
+        assert DeprecatedEndpointDeleteApi.post.__apidoc__["deprecated"] is True
+        assert DeprecatedEndpointUpdateApi.post.__apidoc__["deprecated"] is True
+        assert EndpointCollectionApi.post.__apidoc__.get("deprecated") is not True
+        assert EndpointItemApi.delete.__apidoc__.get("deprecated") is not True
+        assert EndpointItemApi.patch.__apidoc__.get("deprecated") is not True
+
+    def test_canonical_and_legacy_write_routes_are_registered(self):
+        route_map = {
+            resource.__name__: urls
+            for resource, urls, _route_doc, _kwargs in console_ns.resources
+            if resource.__name__
+            in {
+                "EndpointCollectionApi",
+                "EndpointItemApi",
+                "DeprecatedEndpointCreateApi",
+                "DeprecatedEndpointDeleteApi",
+                "DeprecatedEndpointUpdateApi",
+            }
+        }
+
+        assert route_map["EndpointCollectionApi"] == ("/workspaces/current/endpoints",)
+        assert route_map["EndpointItemApi"] == ("/workspaces/current/endpoints/",)
+        assert route_map["DeprecatedEndpointCreateApi"] == ("/workspaces/current/endpoints/create",)
+        assert route_map["DeprecatedEndpointDeleteApi"] == ("/workspaces/current/endpoints/delete",)
+        assert route_map["DeprecatedEndpointUpdateApi"] == ("/workspaces/current/endpoints/update",)
+
+
 @pytest.mark.usefixtures("patch_current_account")
 class TestEndpointEnableApi:
-    def test_enable_success(self, app):
+    def test_enable_success(self, app: Flask):
         api = EndpointEnableApi()
         method = unwrap(api.post)

@@ -253,7 +395,7 @@

         assert result["success"] is True

-    def test_enable_invalid_payload(self, app):
+    def test_enable_invalid_payload(self, app: Flask):
         api = EndpointEnableApi()
         method = unwrap(api.post)

@@ -263,7 +405,7 @@
             with pytest.raises(ValueError):
                 method(api)

-    def test_enable_service_failure(self, app):
+    def test_enable_service_failure(self, app: Flask):
         api = EndpointEnableApi()
         method = unwrap(api.post)

@@ -280,7 +422,7 @@

 @pytest.mark.usefixtures("patch_current_account")
 class TestEndpointDisableApi:
-    def test_disable_success(self, app):
+    def test_disable_success(self, app: Flask):
         api = EndpointDisableApi()
         method = unwrap(api.post)

@@ -294,7 +436,7 @@

         assert result["success"] is True

-    def test_disable_invalid_payload(self, app):
+    def test_disable_invalid_payload(self, app: Flask):
         api = EndpointDisableApi()
         method = unwrap(api.post)

diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
index 9c42ee9529..b2f949c6e2 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py
@@ -11,9 +11,10 @@
 from unittest.mock import MagicMock

 import pytest
 from flask import Flask
 from flask.views import MethodView
+from werkzeug.exceptions import Forbidden
+
 from graphon.model_runtime.entities.model_entities import ModelType
 from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
-from werkzeug.exceptions import Forbidden

 if not hasattr(builtins, "MethodView"):
     builtins.MethodView = MethodView  # type: ignore[attr-defined]
diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py
index 718b57ba6b..0788ff603c 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_members.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch

 import pytest
+from flask import Flask
 from werkzeug.exceptions import HTTPException

 import services
@@ -34,7 +35,7 @@ def unwrap(func):


 class TestMemberListApi:
-    def test_get_success(self, app):
+    def test_get_success(self, app: Flask):
         api = MemberListApi()
         method = unwrap(api.get)

@@ -59,7 +60,7 @@ class TestMemberListApi:
         assert status == 200
         assert len(result["accounts"]) == 1

-    def test_get_no_tenant(self, app):
+    def test_get_no_tenant(self, app: Flask):
         api = MemberListApi()
         method = unwrap(api.get)

@@ -74,7 +75,7 @@ class TestMemberListApi:


 class TestMemberInviteEmailApi:
-    def test_invite_success(self, app):
+    def test_invite_success(self, app: Flask):
         api = MemberInviteEmailApi()
         method = unwrap(api.post)

@@ -101,7 +102,7 @@ class TestMemberInviteEmailApi:
         assert status == 201
         assert result["result"] == "success"

-    def test_invite_limit_exceeded(self, app):
+    def test_invite_limit_exceeded(self, app: Flask):
         api = MemberInviteEmailApi()
         method = unwrap(api.post)

@@ -123,7 +124,7 @@ class TestMemberInviteEmailApi:
             with pytest.raises(WorkspaceMembersLimitExceeded):
                 method(api)

-    def test_invite_already_member(self, app):
+    def test_invite_already_member(self, app: Flask):
         api = MemberInviteEmailApi()
         method = unwrap(api.post)

@@ -151,7 +152,7 @@ class TestMemberInviteEmailApi:

         assert result["invitation_results"][0]["status"] == "success"

-    def test_invite_invalid_role(self, app):
+    def test_invite_invalid_role(self, app: Flask):
         api = MemberInviteEmailApi()
         method = unwrap(api.post)

@@ -166,7 +167,7 @@ class TestMemberInviteEmailApi:
         assert status == 400
         assert result["code"] == "invalid-role"

-    def test_invite_generic_exception(self, app):
+    def test_invite_generic_exception(self, app: Flask):
         api = MemberInviteEmailApi()
         method = unwrap(api.post)

@@ -196,7 +197,7 @@ class TestMemberInviteEmailApi:


 class TestMemberCancelInviteApi:
-    def test_cancel_success(self, app):
+    def test_cancel_success(self, app: Flask):
         api = MemberCancelInviteApi()
         method = unwrap(api.delete)

@@ -216,7 +217,7 @@ class TestMemberCancelInviteApi:
         assert status == 200
         assert result["result"] == "success"

-    def test_cancel_not_found(self, app):
+    def test_cancel_not_found(self, app: Flask):
         api = MemberCancelInviteApi()
         method = unwrap(api.delete)

@@ -233,7 +234,7 @@ class TestMemberCancelInviteApi:
             with pytest.raises(HTTPException):
                 method(api, "x")

-    def test_cancel_cannot_operate_self(self, app):
+    def test_cancel_cannot_operate_self(self, app: Flask):
         api = MemberCancelInviteApi()
         method = unwrap(api.delete)

@@ -255,7 +256,7 @@ class TestMemberCancelInviteApi:

         assert status == 400

-    def test_cancel_no_permission(self, app):
+    def test_cancel_no_permission(self, app: Flask):
         api = MemberCancelInviteApi()
         method = unwrap(api.delete)

@@ -277,7 +278,7 @@ class TestMemberCancelInviteApi:

         assert status == 403

-    def test_cancel_member_not_in_tenant(self, app):
+    def test_cancel_member_not_in_tenant(self, app: Flask):
         api = MemberCancelInviteApi()
         method = unwrap(api.delete)

@@ -301,7 +302,7 @@ class TestMemberCancelInviteApi:


 class TestMemberUpdateRoleApi:
-    def test_update_success(self, app):
+    def test_update_success(self, app: Flask):
         api = MemberUpdateRoleApi()
         method = unwrap(api.put)

@@ -324,7 +325,7 @@ class TestMemberUpdateRoleApi:

         assert result["result"] == "success"

-    def test_update_invalid_role(self, app):
+    def test_update_invalid_role(self, app: Flask):
         api = MemberUpdateRoleApi()
         method = unwrap(api.put)

@@ -335,7 +336,7 @@ class TestMemberUpdateRoleApi:

         assert status == 400

-    def test_update_member_not_found(self, app):
+    def test_update_member_not_found(self, app: Flask):
         api = MemberUpdateRoleApi()
         method = unwrap(api.put)

@@ -354,7 +355,7 @@ class TestMemberUpdateRoleApi:


 class TestDatasetOperatorMemberListApi:
-    def test_get_success(self, app):
+    def test_get_success(self, app: Flask):
         api = DatasetOperatorMemberListApi()
         method = unwrap(api.get)

@@ -381,7 +382,7 @@ class TestDatasetOperatorMemberListApi:
         assert status == 200
         assert len(result["accounts"]) == 1

-    def test_get_no_tenant(self, app):
+    def test_get_no_tenant(self, app: Flask):
         api = DatasetOperatorMemberListApi()
         method = unwrap(api.get)

@@ -396,7 +397,7 @@ class TestDatasetOperatorMemberListApi:


 class TestSendOwnerTransferEmailApi:
-    def test_send_success(self, app):
+    def test_send_success(self, app: Flask):
         api = SendOwnerTransferEmailApi()
         method = unwrap(api.post)

@@ -419,7 +420,7 @@ class TestSendOwnerTransferEmailApi:

         assert result["result"] == "success"

-    def test_send_ip_limit(self, app):
+    def test_send_ip_limit(self, app: Flask):
         api = SendOwnerTransferEmailApi()
         method = unwrap(api.post)

@@ -433,7 +434,7 @@ class TestSendOwnerTransferEmailApi:
             with pytest.raises(EmailSendIpLimitError):
                 method(api)

-    def test_send_not_owner(self, app):
+    def test_send_not_owner(self, app: Flask):
         api = SendOwnerTransferEmailApi()
         method = unwrap(api.post)

@@ -452,7 +453,7 @@ class TestSendOwnerTransferEmailApi:


 class TestOwnerTransferCheckApi:
-    def test_check_invalid_code(self, app):
+    def test_check_invalid_code(self, app: Flask):
         api = OwnerTransferCheckApi()
         method = unwrap(api.post)

@@ -477,7 +478,7 @@ class TestOwnerTransferCheckApi:
             with pytest.raises(EmailCodeError):
                 method(api)

-    def test_rate_limited(self, app):
+    def test_rate_limited(self, app: Flask):
         api = OwnerTransferCheckApi()
         method = unwrap(api.post)

@@ -498,7 +499,7 @@ class TestOwnerTransferCheckApi:
             with pytest.raises(OwnerTransferLimitError):
                 method(api)

-    def test_invalid_token(self, app):
+    def test_invalid_token(self, app: Flask):
         api = OwnerTransferCheckApi()
         method = unwrap(api.post)

@@ -520,7 +521,7 @@ class TestOwnerTransferCheckApi:
             with pytest.raises(InvalidTokenError):
                 method(api)

-    def test_invalid_email(self, app):
+    def test_invalid_email(self, app: Flask):
         api = OwnerTransferCheckApi()
         method = unwrap(api.post)

@@ -547,7 +548,7 @@ class TestOwnerTransferCheckApi:


 class TestOwnerTransferApi:
-    def test_transfer_self(self, app):
+    def test_transfer_self(self, app: Flask):
         api = OwnerTransfer()
         method = unwrap(api.post)

@@ -564,7 +565,7 @@ class TestOwnerTransferApi:
             with pytest.raises(CannotTransferOwnerToSelfError):
                 method(api, "1")

-    def test_invalid_token(self, app):
+    def test_invalid_token(self, app: Flask):
         api = OwnerTransfer()
         method = unwrap(api.post)

@@ -582,7 +583,7 @@ class TestOwnerTransferApi:
             with pytest.raises(InvalidTokenError):
                 method(api, "2")

-    def test_member_not_in_tenant(self, app):
+    def test_member_not_in_tenant(self, app: Flask):
         api = OwnerTransfer()
         method = unwrap(api.post)

diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
index fb9eec98cb..e836a3cc55 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py
@@ -1,7 +1,7 @@
 from unittest.mock import MagicMock, patch

 import pytest
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
+from flask import Flask
 from pydantic_core import ValidationError
 from werkzeug.exceptions import Forbidden

@@ -14,6 +14,7 @@ from controllers.console.workspace.model_providers import (
     ModelProviderValidateApi,
     PreferredProviderTypeUpdateApi,
 )
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError

 VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
 INVALID_UUID = "123"
@@ -26,7 +27,7 @@ def unwrap(func):


 class TestModelProviderListApi:
-    def test_get_success(self, app):
+    def test_get_success(self, app: Flask):
         api = ModelProviderListApi()
         method = unwrap(api.get)

@@ -47,7 +48,7 @@ class TestModelProviderListApi:


 class TestModelProviderCredentialApi:
-    def test_get_success(self, app):
+    def test_get_success(self, app: Flask):
         api = ModelProviderCredentialApi()
         method = unwrap(api.get)

@@ -66,7 +67,7 @@ class TestModelProviderCredentialApi:

         assert "credentials" in result

-    def test_get_invalid_uuid(self, app):
+    def test_get_invalid_uuid(self, app: Flask):
         api = ModelProviderCredentialApi()
         method = unwrap(api.get)

@@ -80,7 +81,7 @@ class TestModelProviderCredentialApi:
             with pytest.raises(ValidationError):
                 method(api, provider="openai")

-    def test_post_create_success(self, app):
+    def test_post_create_success(self, app: Flask):
         api = ModelProviderCredentialApi()
         method = unwrap(api.post)

@@ -102,7 +103,7 @@ class TestModelProviderCredentialApi:
         assert result["result"] == "success"
         assert status == 201

-    def test_post_create_validation_error(self, app):
+    def test_post_create_validation_error(self, app: Flask):
         api = ModelProviderCredentialApi()
         method = unwrap(api.post)

@@ -122,7 +123,7 @@ class TestModelProviderCredentialApi:
             with pytest.raises(ValueError):
                 method(api, provider="openai")

-    def test_put_update_success(self, app):
+    def test_put_update_success(self, app: Flask):
         api = ModelProviderCredentialApi()
         method = unwrap(api.put)

@@ -143,7 +144,7 @@ class TestModelProviderCredentialApi:

         assert result["result"] == "success"

-    def test_put_invalid_uuid(self, app):
+    def test_put_invalid_uuid(self, app: Flask):
         api = ModelProviderCredentialApi()
         method = unwrap(api.put)

@@ -159,7 +160,7 @@ class TestModelProviderCredentialApi:
             with pytest.raises(ValidationError):
                 method(api, provider="openai")

-    def test_delete_success(self, app):
+    def test_delete_success(self, app: Flask):
         api = ModelProviderCredentialApi()
         method = unwrap(api.delete)

@@ -183,7 +184,7 @@ class TestModelProviderCredentialApi:


 class TestModelProviderCredentialSwitchApi:
-    def test_switch_success(self, app):
+    def test_switch_success(self, app: Flask):
         api = ModelProviderCredentialSwitchApi()
         method = unwrap(api.post)

@@ -204,7 +205,7 @@ class TestModelProviderCredentialSwitchApi:

         assert result["result"] == "success"

-    def test_switch_invalid_uuid(self, app):
+    def test_switch_invalid_uuid(self, app: Flask):
         api = ModelProviderCredentialSwitchApi()
         method = unwrap(api.post)

@@ -222,7 +223,7 @@ class TestModelProviderCredentialSwitchApi:


 class TestModelProviderValidateApi:
-    def test_validate_success(self, app):
+    def test_validate_success(self, app: Flask):
         api = ModelProviderValidateApi()
         method = unwrap(api.post)

@@ -243,7 +244,7 @@ class TestModelProviderValidateApi:

         assert result["result"] == "success"

-    def test_validate_failure(self, app):
+    def test_validate_failure(self, app: Flask):
         api = ModelProviderValidateApi()
         method = unwrap(api.post)

@@ -266,7 +267,7 @@ class TestModelProviderValidateApi:


 class TestModelProviderIconApi:
-    def test_icon_success(self, app):
+    def test_icon_success(self, app: Flask):
         api = ModelProviderIconApi()

         with (
@@ -280,7 +281,7 @@ class TestModelProviderIconApi:

         assert response.mimetype == "image/png"

-    def test_icon_not_found(self, app):
+    def test_icon_not_found(self, app: Flask):
         api = ModelProviderIconApi()

         with (
@@ -295,7 +296,7 @@ class TestModelProviderIconApi:


 class TestPreferredProviderTypeUpdateApi:
-    def test_update_success(self, app):
+    def test_update_success(self, app: Flask):
         api = PreferredProviderTypeUpdateApi()
         method = unwrap(api.post)

@@ -316,7 +317,7 @@ class TestPreferredProviderTypeUpdateApi:

         assert result["result"] == "success"

-    def test_invalid_enum(self, app):
+    def test_invalid_enum(self, app: Flask):
         api = PreferredProviderTypeUpdateApi()
         method = unwrap(api.post)

@@ -334,7 +335,7 @@ class TestPreferredProviderTypeUpdateApi:


 class TestModelProviderPaymentCheckoutUrlApi:
-    def test_checkout_success(self, app):
+    def test_checkout_success(self, app: Flask):
         api = ModelProviderPaymentCheckoutUrlApi()
         method = unwrap(api.get)

@@ -359,7 +360,7 @@ class TestModelProviderPaymentCheckoutUrlApi:

         assert "url" in result

-    def test_invalid_provider(self, app):
+    def test_invalid_provider(self, app: Flask):
         api = ModelProviderPaymentCheckoutUrlApi()
         method = unwrap(api.get)

@@ -367,7 +368,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
             with pytest.raises(ValueError):
                 method(api, provider="openai")

-    def test_permission_denied(self, app):
+    def test_permission_denied(self, app: Flask):
         api = ModelProviderPaymentCheckoutUrlApi()
         method = unwrap(api.get)

diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py
index c829327bc7..4246e3c04c 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_models.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py
@@ -2,8 +2,6 @@
 from unittest.mock import MagicMock, patch

 import pytest
 from flask import Flask
-from graphon.model_runtime.entities.model_entities import ModelType
-from graphon.model_runtime.errors.validate import CredentialsValidateFailedError

 from controllers.console.workspace.models import (
     DefaultModelApi,
@@ -16,6 +14,8 @@ from controllers.console.workspace.models import (
     ModelProviderModelParameterRuleApi,
     ModelProviderModelValidateApi,
 )
+from graphon.model_runtime.entities.model_entities import ModelType
+from graphon.model_runtime.errors.validate import CredentialsValidateFailedError


 def unwrap(func):
@@ -72,7 +72,7 @@ class TestDefaultModelApi:

         assert result["result"] == "success"

-    def test_get_returns_empty_when_no_default(self, app):
+    def test_get_returns_empty_when_no_default(self, app: Flask):
         api = DefaultModelApi()
         method = unwrap(api.get)

@@ -154,7 +154,7 @@ class TestModelProviderModelApi:

         assert status == 204

-    def test_get_models_returns_empty(self, app):
+    def test_get_models_returns_empty(self, app: Flask):
         api = ModelProviderModelApi()
         method = unwrap(api.get)

@@ -224,7 +224,7 @@ class TestModelProviderModelCredentialApi:

         assert status == 201

-    def test_get_empty_credentials(self, app):
+    def test_get_empty_credentials(self, app: Flask):
         api = ModelProviderModelCredentialApi()
         method = unwrap(api.get)

@@ -242,7 +242,7 @@ class TestModelProviderModelCredentialApi:

         assert result["credentials"] == {}

-    def test_delete_success(self, app):
+    def test_delete_success(self, app: Flask):
         api = ModelProviderModelCredentialApi()
         method = unwrap(api.delete)

@@ -416,7 +416,7 @@ class TestParameterAndAvailableModels:

         assert "data" in result

-    def test_empty_rules(self, app):
+    def test_empty_rules(self, app: Flask):
         api = ModelProviderModelParameterRuleApi()
         method = unwrap(api.get)

@@ -431,7 +431,7 @@ class TestParameterAndAvailableModels:

         assert result["data"] == []

-    def test_no_models(self, app):
+    def test_no_models(self, app: Flask):
         api = ModelProviderAvailableModelApi()
         method = unwrap(api.get)

diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py
index ce5fd1c466..d01bf7d668 100644
--- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py
+++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py
@@ -2,6 +2,7 @@
 import io
 from unittest.mock import MagicMock, patch

 import pytest
+from flask import Flask
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import Forbidden

@@ -61,7 +62,7 @@ def tenant():


 class TestPluginListLatestVersionsApi:
-    def test_success(self, app):
+    def test_success(self, app: Flask):
         api = PluginListLatestVersionsApi()
         method = unwrap(api.post)

@@ -77,7 +78,7 @@

         assert "versions" in result

-    def test_daemon_error(self, app):
+    def test_daemon_error(self, app: Flask):
         api = PluginListLatestVersionsApi()
         method = unwrap(api.post)

@@ -95,7 +96,7 @@


 class TestPluginDebuggingKeyApi:
-    def test_debugging_key_success(self, app):
+    def test_debugging_key_success(self, app: Flask):
         api = PluginDebuggingKeyApi()
         method = unwrap(api.get)

@@ -108,7 +109,7 @@

         assert result["key"] == "k"

-    def test_debugging_key_error(self, app):
+    def test_debugging_key_error(self, app: Flask):
         api = PluginDebuggingKeyApi()
         method = unwrap(api.get)

@@ -125,7 +126,7 @@


 class TestPluginListApi:
-    def test_plugin_list(self, app):
+    def test_plugin_list(self, app: Flask):
         api = PluginListApi()
         method = unwrap(api.get)

@@ -142,7 +143,7 @@


 class TestPluginIconApi:
-    def test_plugin_icon(self, app):
+    def test_plugin_icon(self, app: Flask):
         api = PluginIconApi()
         method = unwrap(api.get)

@@ -156,7 +157,7 @@


 class TestPluginAssetApi:
-    def test_plugin_asset(self, app):
+    def test_plugin_asset(self, app: Flask):
         api = PluginAssetApi()
         method = unwrap(api.get)

@@ -171,7 +172,7 @@


 class TestPluginUploadFromPkgApi:
-    def test_upload_pkg_success(self, app):
+    def test_upload_pkg_success(self, app: Flask):
         api = PluginUploadFromPkgApi()
         method = unwrap(api.post)

@@ -188,7 +189,7 @@

         assert result["ok"] is True

-    def test_upload_pkg_too_large(self, app):
+    def test_upload_pkg_too_large(self, app: Flask):
         api = PluginUploadFromPkgApi()
         method = unwrap(api.post)

@@ -210,7 +211,7 @@ class TestPluginInstallFromPkgApi:
TestPluginInstallFromPkgApi: - def test_install_from_pkg(self, app): + def test_install_from_pkg(self, app: Flask): api = PluginInstallFromPkgApi() method = unwrap(api.post) @@ -229,7 +230,7 @@ class TestPluginInstallFromPkgApi: class TestPluginUninstallApi: - def test_uninstall(self, app): + def test_uninstall(self, app: Flask): api = PluginUninstallApi() method = unwrap(api.post) @@ -246,7 +247,7 @@ class TestPluginUninstallApi: class TestPluginChangePermissionApi: - def test_change_permission_forbidden(self, app): + def test_change_permission_forbidden(self, app: Flask): api = PluginChangePermissionApi() method = unwrap(api.post) @@ -264,7 +265,7 @@ class TestPluginChangePermissionApi: with pytest.raises(Forbidden): method(api) - def test_change_permission_success(self, app): + def test_change_permission_success(self, app: Flask): api = PluginChangePermissionApi() method = unwrap(api.post) @@ -286,7 +287,7 @@ class TestPluginChangePermissionApi: class TestPluginFetchPermissionApi: - def test_fetch_permission_default(self, app): + def test_fetch_permission_default(self, app: Flask): api = PluginFetchPermissionApi() method = unwrap(api.get) @@ -319,7 +320,7 @@ class TestPluginFetchDynamicSelectOptionsApi: class TestPluginReadmeApi: - def test_fetch_readme(self, app): + def test_fetch_readme(self, app: Flask): api = PluginReadmeApi() method = unwrap(api.get) @@ -334,7 +335,7 @@ class TestPluginReadmeApi: class TestPluginListInstallationsFromIdsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginListInstallationsFromIdsApi() method = unwrap(api.post) @@ -352,7 +353,7 @@ class TestPluginListInstallationsFromIdsApi: assert "plugins" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginListInstallationsFromIdsApi() method = unwrap(api.post) @@ -371,7 +372,7 @@ class TestPluginListInstallationsFromIdsApi: class TestPluginUploadFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUploadFromGithubApi() method = unwrap(api.post) @@ -388,7 +389,7 @@ class TestPluginUploadFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUploadFromGithubApi() method = unwrap(api.post) @@ -407,7 +408,7 @@ class TestPluginUploadFromGithubApi: class TestPluginUploadFromBundleApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUploadFromBundleApi() method = unwrap(api.post) @@ -430,7 +431,7 @@ class TestPluginUploadFromBundleApi: assert result["ok"] is True - def test_too_large(self, app): + def test_too_large(self, app: Flask): api = PluginUploadFromBundleApi() method = unwrap(api.post) @@ -458,7 +459,7 @@ class TestPluginUploadFromBundleApi: class TestPluginInstallFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginInstallFromGithubApi() method = unwrap(api.post) @@ -478,7 +479,7 @@ class TestPluginInstallFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginInstallFromGithubApi() method = unwrap(api.post) @@ -502,7 +503,7 @@ class TestPluginInstallFromGithubApi: class TestPluginInstallFromMarketplaceApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginInstallFromMarketplaceApi() method = unwrap(api.post) @@ -520,7 +521,7 @@ class TestPluginInstallFromMarketplaceApi: assert result["ok"] is True - def test_daemon_error(self, 
app): + def test_daemon_error(self, app: Flask): api = PluginInstallFromMarketplaceApi() method = unwrap(api.post) @@ -539,7 +540,7 @@ class TestPluginInstallFromMarketplaceApi: class TestPluginFetchMarketplacePkgApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchMarketplacePkgApi() method = unwrap(api.get) @@ -552,7 +553,7 @@ class TestPluginFetchMarketplacePkgApi: assert "manifest" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchMarketplacePkgApi() method = unwrap(api.get) @@ -569,7 +570,7 @@ class TestPluginFetchMarketplacePkgApi: class TestPluginFetchManifestApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchManifestApi() method = unwrap(api.get) @@ -585,7 +586,7 @@ class TestPluginFetchManifestApi: assert "manifest" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchManifestApi() method = unwrap(api.get) @@ -602,7 +603,7 @@ class TestPluginFetchManifestApi: class TestPluginFetchInstallTasksApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchInstallTasksApi() method = unwrap(api.get) @@ -615,7 +616,7 @@ class TestPluginFetchInstallTasksApi: assert "tasks" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchInstallTasksApi() method = unwrap(api.get) @@ -632,7 +633,7 @@ class TestPluginFetchInstallTasksApi: class TestPluginFetchInstallTaskApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchInstallTaskApi() method = unwrap(api.get) @@ -645,7 +646,7 @@ class TestPluginFetchInstallTaskApi: assert "task" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchInstallTaskApi() method = unwrap(api.get) @@ -662,7 +663,7 @@ class TestPluginFetchInstallTaskApi: class TestPluginDeleteInstallTaskApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteInstallTaskApi() method = unwrap(api.post) @@ -675,7 +676,7 @@ class TestPluginDeleteInstallTaskApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteInstallTaskApi() method = unwrap(api.post) @@ -692,7 +693,7 @@ class TestPluginDeleteInstallTaskApi: class TestPluginDeleteAllInstallTaskItemsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteAllInstallTaskItemsApi() method = unwrap(api.post) @@ -707,7 +708,7 @@ class TestPluginDeleteAllInstallTaskItemsApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteAllInstallTaskItemsApi() method = unwrap(api.post) @@ -724,7 +725,7 @@ class TestPluginDeleteAllInstallTaskItemsApi: class TestPluginDeleteInstallTaskItemApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteInstallTaskItemApi() method = unwrap(api.post) @@ -737,7 +738,7 @@ class TestPluginDeleteInstallTaskItemApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteInstallTaskItemApi() method = unwrap(api.post) @@ -754,7 +755,7 @@ class TestPluginDeleteInstallTaskItemApi: class TestPluginUpgradeFromMarketplaceApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUpgradeFromMarketplaceApi() method = 
unwrap(api.post) @@ -775,7 +776,7 @@ class TestPluginUpgradeFromMarketplaceApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUpgradeFromMarketplaceApi() method = unwrap(api.post) @@ -797,7 +798,7 @@ class TestPluginUpgradeFromMarketplaceApi: class TestPluginUpgradeFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUpgradeFromGithubApi() method = unwrap(api.post) @@ -821,7 +822,7 @@ class TestPluginUpgradeFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUpgradeFromGithubApi() method = unwrap(api.post) @@ -846,7 +847,7 @@ class TestPluginUpgradeFromGithubApi: class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchDynamicSelectOptionsWithCredentialsApi() method = unwrap(api.post) @@ -873,7 +874,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: assert result["options"] == [1] - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchDynamicSelectOptionsWithCredentialsApi() method = unwrap(api.post) @@ -901,7 +902,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: class TestPluginChangePreferencesApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginChangePreferencesApi() method = unwrap(api.post) @@ -931,7 +932,7 @@ class TestPluginChangePreferencesApi: assert result["success"] is True - def test_permission_fail(self, app): + def test_permission_fail(self, app: Flask): api = PluginChangePreferencesApi() method = unwrap(api.post) @@ -962,7 +963,7 @@ class TestPluginChangePreferencesApi: class TestPluginFetchPreferencesApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchPreferencesApi() method = unwrap(api.get) @@ -996,7 +997,7 @@ class TestPluginFetchPreferencesApi: class TestPluginAutoUpgradeExcludePluginApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginAutoUpgradeExcludePluginApi() method = unwrap(api.post) @@ -1011,7 +1012,7 @@ class TestPluginAutoUpgradeExcludePluginApi: assert result["success"] is True - def test_fail(self, app): + def test_fail(self, app: Flask): api = PluginAutoUpgradeExcludePluginApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 44feacf2ad..1422f29849 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -22,7 +22,7 @@ _WRAPS_MODULE: ModuleType | None = None @contextmanager def _mock_db(): - mock_session = SimpleNamespace(query=lambda *args, **kwargs: SimpleNamespace(first=lambda: True)) + mock_session = SimpleNamespace(scalar=lambda *args, **kwargs: True) with patch("extensions.ext_database.db.session", mock_session): yield diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index b2d13dbbdf..a52518c2d2 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -2,6 +2,7 @@ from io import BytesIO from unittest.mock import MagicMock, patch import 
pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Unauthorized @@ -18,6 +19,7 @@ from controllers.console.workspace.workspace import ( CustomConfigWorkspaceApi, SwitchWorkspaceApi, TenantApi, + TenantInfoResponse, TenantListApi, WebappLogoWorkspaceApi, WorkspaceInfoApi, @@ -36,7 +38,7 @@ def unwrap(func): class TestTenantListApi: - def test_get_success_saas_path(self, app): + def test_get_success_saas_path(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -84,7 +86,7 @@ class TestTenantListApi: get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) get_features_mock.assert_not_called() - def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app): + def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app: Flask): """Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used. billing.enabled is mocked False to prove the endpoint does not gate on it for this path @@ -139,7 +141,7 @@ class TestTenantListApi: get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) get_features_mock.assert_called_once_with("t2") - def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app): + def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app: Flask): """Test fallback to FeatureService when bulk billing returns empty result. BillingService.get_plan_bulk catches exceptions internally and returns empty dict, @@ -196,7 +198,7 @@ class TestTenantListApi: assert get_features_mock.call_count == 2 logger_warning_mock.assert_called_once() - def test_get_billing_disabled_community_path(self, app): + def test_get_billing_disabled_community_path(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -235,7 +237,7 @@ class TestTenantListApi: assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX get_features_mock.assert_called_once_with("t1") - def test_get_enterprise_only_skips_feature_service(self, app): + def test_get_enterprise_only_skips_feature_service(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -275,7 +277,7 @@ class TestTenantListApi: assert result["workspaces"][1]["current"] is True get_features_mock.assert_not_called() - def test_get_enterprise_only_with_empty_tenants(self, app): + def test_get_enterprise_only_with_empty_tenants(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -301,7 +303,7 @@ class TestTenantListApi: class TestWorkspaceListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = WorkspaceListApi() method = unwrap(api.get) @@ -323,7 +325,7 @@ class TestWorkspaceListApi: assert result["total"] == 1 assert result["has_more"] is False - def test_get_has_next_true(self, app): + def test_get_has_next_true(self, app: Flask): api = WorkspaceListApi() method = unwrap(api.get) @@ -354,7 +356,7 @@ class TestWorkspaceListApi: class TestTenantApi: - def test_post_active_tenant(self, app): + def test_post_active_tenant(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -374,7 +376,7 @@ class TestTenantApi: assert status == 200 assert result["id"] == "t1" - def test_post_archived_with_switch(self, app): + def test_post_archived_with_switch(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -396,7 +398,7 @@ class TestTenantApi: assert result["id"] == "new" - def test_post_archived_no_tenant(self, app): + def test_post_archived_no_tenant(self, app: Flask): api = 
TenantApi() method = unwrap(api.post) @@ -410,7 +412,7 @@ class TestTenantApi: with pytest.raises(Unauthorized): method(api) - def test_post_info_path(self, app): + def test_post_info_path(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -435,8 +437,25 @@ class TestTenantApi: assert status == 200 +class TestTenantInfoResponse: + def test_tenant_info_response_normalizes_enum_and_datetime(self): + created_at = naive_utc_now() + payload = TenantInfoResponse.model_validate( + { + "id": "t1", + "status": TenantStatus.NORMAL, + "plan": CloudPlan.TEAM, + "created_at": created_at, + } + ).model_dump(mode="json") + + assert payload["status"] == "normal" + assert payload["plan"] == "team" + assert payload["created_at"] == int(created_at.timestamp()) + + class TestSwitchWorkspaceApi: - def test_switch_success(self, app): + def test_switch_success(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -459,7 +478,7 @@ class TestSwitchWorkspaceApi: assert result["result"] == "success" - def test_switch_not_linked(self, app): + def test_switch_not_linked(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -475,7 +494,7 @@ class TestSwitchWorkspaceApi: with pytest.raises(AccountNotLinkTenantError): method(api) - def test_switch_tenant_not_found(self, app): + def test_switch_tenant_not_found(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -497,7 +516,7 @@ class TestSwitchWorkspaceApi: class TestCustomConfigWorkspaceApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CustomConfigWorkspaceApi() method = unwrap(api.post) @@ -520,7 +539,7 @@ class TestCustomConfigWorkspaceApi: assert result["result"] == "success" - def test_logo_fallback(self, app): + def test_logo_fallback(self, app: Flask): api = CustomConfigWorkspaceApi() method = unwrap(api.post) @@ -551,7 +570,7 @@ class TestCustomConfigWorkspaceApi: class TestWebappLogoWorkspaceApi: - def test_no_file(self, app): + def test_no_file(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -564,7 +583,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(NoFileUploadedError): method(api) - def test_too_many_files(self, app): + def test_too_many_files(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -583,7 +602,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(TooManyFilesError): method(api) - def test_invalid_extension(self, app): + def test_invalid_extension(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -598,7 +617,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(UnsupportedFileTypeError): method(api) - def test_upload_success(self, app): + def test_upload_success(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -630,7 +649,7 @@ class TestWebappLogoWorkspaceApi: assert status == 201 assert result["id"] == "file1" - def test_filename_missing(self, app): + def test_filename_missing(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -654,7 +673,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(FilenameNotExistsError): method(api) - def test_file_too_large(self, app): + def test_file_too_large(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -683,7 +702,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(FileTooLargeError): method(api) - def test_service_unsupported_file(self, app): + def test_service_unsupported_file(self, app: 
Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -714,7 +733,7 @@ class TestWebappLogoWorkspaceApi: class TestWorkspaceInfoApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = WorkspaceInfoApi() method = unwrap(api.post) @@ -738,7 +757,7 @@ class TestWorkspaceInfoApi: assert result["result"] == "success" - def test_no_current_tenant(self, app): + def test_no_current_tenant(self, app: Flask): api = WorkspaceInfoApi() method = unwrap(api.post) @@ -756,7 +775,7 @@ class TestWorkspaceInfoApi: class TestWorkspacePermissionApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = WorkspacePermissionApi() method = unwrap(api.get) @@ -781,7 +800,7 @@ class TestWorkspacePermissionApi: assert status == 200 assert result["workspace_id"] == "t1" - def test_no_current_tenant(self, app): + def test_no_current_tenant(self, app: Flask): api = WorkspacePermissionApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 974d8f7bc6..71381e6a2b 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -18,6 +18,7 @@ from controllers.inner_api.app.dsl import ( InnerAppDSLImportPayload, _get_active_account, ) +from models.account import AccountStatus from services.app_dsl_service import ImportStatus @@ -63,7 +64,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_active_account(self, mock_db): mock_account = MagicMock() - mock_account.status = "active" + mock_account.status = AccountStatus.ACTIVE mock_db.session.scalar.return_value = mock_account result = _get_active_account("user@example.com") @@ -74,7 +75,7 @@ class TestGetActiveAccount: @patch("controllers.inner_api.app.dsl.db") def test_returns_none_for_inactive_account(self, mock_db): mock_account = MagicMock() - mock_account.status = "banned" + mock_account.status = AccountStatus.BANNED mock_db.session.scalar.return_value = mock_account result = _get_active_account("banned@example.com") @@ -102,16 +103,16 @@ class TestEnterpriseAppDSLImport: @pytest.fixture def _mock_import_deps(self): - """Patch db, sessionmaker, and AppDslService for import handler tests.""" - mock_session_ctx = MagicMock() - mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock()) - mock_session_ctx.__exit__ = MagicMock(return_value=False) - mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx))) + """Patch db, Session, and AppDslService for import handler tests.""" + mock_session = MagicMock() + mock_session.__enter__ = MagicMock(return_value=mock_session) + mock_session.__exit__ = MagicMock(return_value=False) with ( patch("controllers.inner_api.app.dsl.db"), - patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker), + patch("controllers.inner_api.app.dsl.Session", return_value=mock_session), patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, ): + self._mock_session = mock_session self._mock_dsl = MagicMock() mock_dsl_cls.return_value = self._mock_dsl yield @@ -147,6 +148,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 200 assert body["status"] == "completed" mock_account.set_tenant_id.assert_called_once_with("ws-123") + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") 
@patch("controllers.inner_api.app.dsl._get_active_account") @@ -162,6 +165,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 202 assert body["status"] == "pending" + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -177,6 +182,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 400 assert body["status"] == "failed" + self._mock_session.rollback.assert_called_once_with() + self._mock_session.commit.assert_not_called() @patch("controllers.inner_api.app.dsl._get_active_account") def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py index 957d7fbd9b..d1b09c3a58 100644 --- a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -2,6 +2,7 @@ Unit tests for inner_api plugin decorators """ +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -40,17 +41,22 @@ class TestTenantUserPayload: class TestGetUser: """Test get_user function""" + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_return_existing_user_by_id(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_return_existing_user_by_id( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test returning existing user when found by ID""" # Arrange mock_user = MagicMock() mock_user.id = "user123" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -58,13 +64,45 @@ class TestGetUser: # Assert assert result == mock_user - mock_session.get.assert_called_once() + mock_session.scalar.assert_called_once() + @patch("controllers.inner_api.plugin.wraps.select") + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.sessionmaker") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_not_resolve_non_anonymous_users_across_tenants( + self, + mock_db, + mock_sessionmaker, + mock_enduser_class, + mock_select, + app: Flask, + ): + """Test that explicit user IDs remain scoped to the current tenant.""" + # Arrange + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + mock_new_user = MagicMock() + mock_new_user.tenant_id = "tenant-current" + mock_enduser_class.return_value = mock_new_user + + # Act + with app.app_context(): + result = get_user("tenant-current", "foreign-user-id") + + # Assert + assert result == mock_new_user + mock_session.get.assert_not_called() + mock_session.scalar.assert_called_once() + mock_session.add.assert_called_once_with(mock_new_user) + + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") 
@patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") def test_should_return_existing_anonymous_user_by_session_id( - self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask ): """Test returning existing anonymous user by session_id""" # Arrange @@ -72,8 +110,9 @@ class TestGetUser: mock_user.session_id = "anonymous_session" mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - # non-anonymous path uses session.get(); anonymous uses session.scalar() - mock_session.get.return_value = mock_user + mock_session.scalar.return_value = mock_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -82,17 +121,22 @@ class TestGetUser: # Assert assert result == mock_user + @patch("controllers.inner_api.plugin.wraps.select") @patch("controllers.inner_api.plugin.wraps.EndUser") @patch("controllers.inner_api.plugin.wraps.sessionmaker") @patch("controllers.inner_api.plugin.wraps.db") - def test_should_create_new_user_when_not_found(self, mock_db, mock_sessionmaker, mock_enduser_class, app: Flask): + def test_should_create_new_user_when_not_found( + self, mock_db, mock_sessionmaker, mock_enduser_class, mock_select, app: Flask + ): """Test creating new user when not found in database""" # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.return_value = None + mock_session.scalar.return_value = None mock_new_user = MagicMock() mock_enduser_class.return_value = mock_new_user + mock_query = MagicMock() + mock_select.return_value.where.return_value.limit.return_value = mock_query # Act with app.app_context(): @@ -133,7 +177,7 @@ class TestGetUser: # Arrange mock_session = MagicMock() mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.get.side_effect = Exception("Database error") + mock_session.scalar.side_effect = Exception("Database error") # Act & Assert with app.app_context(): @@ -232,11 +276,11 @@ class TestGetUserTenant: class PluginTestPayload: """Simple test payload class""" - def __init__(self, data: dict): + def __init__(self, data: dict[str, Any]): self.value = data.get("value") @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): return cls(data) @@ -277,7 +321,7 @@ class TestPluginData: # Arrange class InvalidPayload: @classmethod - def model_validate(cls, data: dict): + def model_validate(cls, data: dict[str, Any]): raise Exception("Validation failed") @plugin_data(payload_type=InvalidPayload) diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 56a8f94963..7d2193adc6 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -20,6 +20,7 @@ from controllers.inner_api.workspace.workspace import ( WorkspaceCreatePayload, WorkspaceOwnerlessPayload, ) +from models.account import TenantStatus class TestWorkspaceCreatePayload: @@ -98,7 +99,7 @@ class TestEnterpriseWorkspace: mock_tenant.id = "tenant-id" mock_tenant.name = "My Workspace" mock_tenant.plan = "sandbox" - mock_tenant.status = "normal" + mock_tenant.status = 
TenantStatus.NORMAL mock_tenant.created_at = now mock_tenant.updated_at = now mock_tenant_svc.create_tenant.return_value = mock_tenant @@ -162,7 +163,7 @@ class TestEnterpriseWorkspaceNoOwnerEmail: mock_tenant.name = "My Workspace" mock_tenant.encrypt_public_key = "pub-key" mock_tenant.plan = "sandbox" - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_tenant.custom_config = None mock_tenant.created_at = now mock_tenant.updated_at = now diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index 1507bf7a5f..ae0edcf382 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -10,8 +10,9 @@ from flask import Flask from controllers.service_api.app.app import AppInfoApi, AppMetaApi, AppParameterApi from controllers.service_api.app.error import AppUnavailableError +from models.account import TenantStatus from models.model import App, AppMode -from tests.unit_tests.conftest import setup_mock_tenant_account_query +from tests.unit_tests.conftest import setup_mock_tenant_owner_execute_result class TestAppParameterApi: @@ -40,7 +41,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_for_chat_app( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving parameters for a chat app.""" # Arrange @@ -62,7 +63,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL # Mock DB queries for app and tenant mock_db.session.get.side_effect = [ @@ -73,7 +74,7 @@ class TestAppParameterApi: # Mock tenant owner info for login mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -90,7 +91,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_for_workflow_app( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving parameters for a workflow app.""" # Arrange @@ -110,7 +111,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -119,7 +120,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -135,7 +136,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") 
@patch("controllers.service_api.wraps.db") def test_get_parameters_raises_error_when_chat_config_missing( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test that AppUnavailableError is raised when chat app has no config.""" # Arrange @@ -151,7 +152,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -160,7 +161,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act & Assert with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -173,7 +174,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_raises_error_when_workflow_missing( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test that AppUnavailableError is raised when workflow app has no workflow.""" # Arrange @@ -190,7 +191,7 @@ class TestAppParameterApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -199,7 +200,7 @@ class TestAppParameterApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act & Assert with app.test_request_context("/parameters", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -233,7 +234,14 @@ class TestAppMetaApi: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.app.app.AppService") def test_get_app_meta( - self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, + mock_app_service, + mock_db, + mock_validate_token, + mock_current_app, + mock_user_logged_in, + app: Flask, + mock_app_model, ): """Test retrieving app metadata via AppService.""" # Arrange @@ -253,7 +261,7 @@ class TestAppMetaApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -262,7 +270,7 @@ class TestAppMetaApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/meta", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -309,7 +317,7 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_app_info( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, 
mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving basic app information.""" mock_current_app.login_manager = Mock() @@ -321,7 +329,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app_model, @@ -330,7 +338,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -378,7 +386,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -387,7 +395,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -401,7 +409,9 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.current_app") @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") - def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app): + def test_get_app_info_with_no_tags( + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask + ): """Test retrieving app info when app has no tags.""" # Arrange mock_current_app.login_manager = Mock() @@ -424,7 +434,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -433,7 +443,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -452,7 +462,7 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_app_info_returns_correct_mode( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, app_mode ): """Test that all app modes are correctly returned.""" # Arrange @@ -476,7 +486,7 @@ class TestAppInfoApi: mock_validate_token.return_value = mock_api_token mock_tenant = Mock() - mock_tenant.status = "normal" + mock_tenant.status = TenantStatus.NORMAL mock_db.session.get.side_effect = [ mock_app, @@ -485,7 +495,7 @@ class TestAppInfoApi: mock_account = Mock() mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) # Act with app.test_request_context("/info", method="GET", headers={"Authorization": "Bearer test_token"}): diff 
--git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 5a8cb4619f..4741481ef6 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,7 +13,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -30,6 +30,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( @@ -95,30 +96,6 @@ class TestTextToAudioPayload: assert payload.streaming is True -# --------------------------------------------------------------------------- -# AudioService Interface Tests -# --------------------------------------------------------------------------- - - -class TestAudioServiceInterface: - """Test AudioService method interfaces exist.""" - - def test_transcript_asr_method_exists(self): - """Test that AudioService.transcript_asr exists.""" - assert hasattr(AudioService, "transcript_asr") - assert callable(AudioService.transcript_asr) - - def test_transcript_tts_method_exists(self): - """Test that AudioService.transcript_tts exists.""" - assert hasattr(AudioService, "transcript_tts") - assert callable(AudioService.transcript_tts) - - -# --------------------------------------------------------------------------- -# Audio Service Tests -# --------------------------------------------------------------------------- - - class TestAudioServiceInterface: """Test suite for AudioService interface methods.""" @@ -214,7 +191,7 @@ class TestAudioServiceMockedBehavior: class TestAudioApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"}) api = AudioApi() handler = _unwrap(api.post) @@ -240,7 +217,7 @@ class TestAudioApi: (InvokeError("invoke"), CompletionRequestError), ], ) - def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: + def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc)) api = AudioApi() handler = _unwrap(api.post) @@ -251,7 +228,7 @@ class TestAudioApi: with pytest.raises(expected): handler(api, app_model=app_model, end_user=end_user) - def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_unhandled_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")) ) @@ -266,7 +243,7 @@ class TestAudioApi: class TestTextApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) api = 
TextApi() @@ -283,7 +260,7 @@ class TestTextApi: assert response == {"audio": "ok"} - def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()) ) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 57681d8f5b..259741937f 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,7 +16,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeError +from flask import Flask from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -35,6 +35,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService @@ -295,7 +296,7 @@ class TestCompletionControllerLogic: @patch("controllers.service_api.app.completion.service_api_ns") @patch("controllers.service_api.app.completion.AppGenerateService") - def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app): + def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask): """Test CompletionApi.post success path.""" from controllers.service_api.app.completion import CompletionApi @@ -320,7 +321,7 @@ class TestCompletionControllerLogic: mock_generate_service.generate.assert_called_once() @patch("controllers.service_api.app.completion.service_api_ns") - def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app): + def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask): """Test CompletionApi.post with wrong app mode.""" from controllers.service_api.app.completion import CompletionApi @@ -334,7 +335,7 @@ class TestCompletionControllerLogic: @patch("controllers.service_api.app.completion.service_api_ns") @patch("controllers.service_api.app.completion.AppGenerateService") - def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app): + def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask): """Test ChatApi.post success path.""" from controllers.service_api.app.completion import ChatApi @@ -355,7 +356,7 @@ class TestCompletionControllerLogic: assert response == {"text": "compacted"} @patch("controllers.service_api.app.completion.service_api_ns") - def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app): + def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask): """Test ChatApi.post with wrong app mode.""" from controllers.service_api.app.completion import ChatApi @@ -368,7 +369,7 @@ class TestCompletionControllerLogic: ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user) @patch("controllers.service_api.app.completion.AppTaskService") - def test_completion_stop_api_success(self, mock_task_service, app): + def test_completion_stop_api_success(self, mock_task_service, app: Flask): """Test CompletionStopApi.post success.""" 
         from controllers.service_api.app.completion import CompletionStopApi
@@ -385,7 +386,7 @@ class TestCompletionControllerLogic:
         mock_task_service.stop_task.assert_called_once()

     @patch("controllers.service_api.app.completion.AppTaskService")
-    def test_chat_stop_api_success(self, mock_task_service, app):
+    def test_chat_stop_api_success(self, mock_task_service, app: Flask):
         """Test ChatStopApi.post success."""
         from controllers.service_api.app.completion import ChatStopApi
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
index dbd06677d8..6dc8f54d42 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py
@@ -15,10 +15,12 @@ Focus on:

 import sys
 import uuid
+from datetime import UTC, datetime
 from types import SimpleNamespace
 from unittest.mock import Mock, patch

 import pytest
+from flask import Flask
 from werkzeug.exceptions import BadRequest, NotFound

 import services
@@ -29,11 +31,16 @@ from controllers.service_api.app.conversation import (
     ConversationRenameApi,
     ConversationRenamePayload,
     ConversationVariableDetailApi,
+    ConversationVariableInfiniteScrollPaginationResponse,
+    ConversationVariableResponse,
     ConversationVariablesApi,
     ConversationVariablesQuery,
     ConversationVariableUpdatePayload,
 )
 from controllers.service_api.app.error import NotChatAppError
+from fields._value_type_serializer import serialize_value_type
+from graphon.variables import StringSegment
+from graphon.variables.types import SegmentType
 from models.model import App, AppMode, EndUser
 from services.conversation_service import ConversationService
 from services.errors.conversation import (
@@ -261,6 +268,72 @@ class TestConversationVariableUpdatePayload:
         assert payload.value == nested


+class TestConversationVariableResponseModels:
+    def test_variable_response_normalizes_value_type_and_timestamps(self):
+        created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+        response = ConversationVariableResponse.model_validate(
+            {
+                "id": "550e8400-e29b-41d4-a716-446655440000",
+                "name": "foo",
+                "value_type": SegmentType.INTEGER,
+                "value": 1,
+                "description": "desc",
+                "created_at": created_at,
+                "updated_at": created_at,
+            }
+        )
+        assert response.value_type == "number"
+        assert response.value == "1"
+        assert response.created_at == int(created_at.timestamp())
+        assert response.updated_at == int(created_at.timestamp())
+
+    def test_variable_response_normalizes_string_value_type_alias(self):
+        response = ConversationVariableResponse.model_validate(
+            {
+                "id": "550e8400-e29b-41d4-a716-446655440000",
+                "name": "foo",
+                "value_type": SegmentType.INTEGER.value,
+            }
+        )
+
+        assert response.value_type == "number"
+
+    def test_variable_response_normalizes_callable_exposed_type(self):
+        response = ConversationVariableResponse.model_validate(
+            {
+                "id": "550e8400-e29b-41d4-a716-446655440000",
+                "name": "foo",
+                "value_type": SimpleNamespace(exposed_type=lambda: SegmentType.STRING.exposed_type()),
+            }
+        )
+
+        assert response.value_type == "string"
+
+    def test_serialize_value_type_supports_segments_and_mappings(self):
+        assert serialize_value_type(StringSegment(value="hello")) == "string"
+        assert serialize_value_type({"value_type": SegmentType.INTEGER}) == "number"
+
+    def test_variable_pagination_response(self):
+        response = ConversationVariableInfiniteScrollPaginationResponse.model_validate(
+            {
+                "limit": 1,
+                "has_more": False,
+                "data": [
+                    {
+                        "id": "550e8400-e29b-41d4-a716-446655440000",
+                        "name": "foo",
+                        "value_type": "string",
+                        "value": "bar",
+                    }
+                ],
+            }
+        )
+        assert response.limit == 1
+        assert response.has_more is False
+        assert len(response.data) == 1
+        assert response.data[0].name == "foo"
+
+
 class TestConversationAppModeValidation:
     """Test app mode validation for conversation endpoints."""
@@ -432,7 +505,7 @@ class TestConversationApiController:
         with pytest.raises(NotChatAppError):
             handler(api, app_model=app_model, end_user=end_user)

-    def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_list_last_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         class _BeginStub:
             def __enter__(self):
                 return SimpleNamespace()
@@ -480,7 +553,7 @@ class TestConversationDetailApiController:
         with pytest.raises(NotChatAppError):
             handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")

-    def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             ConversationService,
             "delete",
@@ -498,7 +571,7 @@ class TestConversationDetailApiController:

 class TestConversationRenameApiController:
-    def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             ConversationService,
             "rename",
@@ -530,7 +603,7 @@ class TestConversationVariablesApiController:
         with pytest.raises(NotChatAppError):
             handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")

-    def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             ConversationService,
             "get_conversational_variable",
@@ -549,9 +622,47 @@ class TestConversationVariablesApiController:
         with pytest.raises(NotFound):
             handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")

+    def test_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
+        created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+        monkeypatch.setattr(
+            ConversationService,
+            "get_conversational_variable",
+            lambda *_args, **_kwargs: SimpleNamespace(
+                limit=1,
+                has_more=False,
+                data=[
+                    {
+                        "id": "550e8400-e29b-41d4-a716-446655440000",
+                        "name": "foo",
+                        "value_type": SegmentType.INTEGER,
+                        "value": 1,
+                        "created_at": created_at,
+                        "updated_at": created_at,
+                    }
+                ],
+            ),
+        )
+
+        api = ConversationVariablesApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(mode=AppMode.CHAT.value)
+        end_user = SimpleNamespace()
+
+        with app.test_request_context(
+            "/conversations/1/variables?limit=20",
+            method="GET",
+        ):
+            result = handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
+
+        assert result["limit"] == 1
+        assert result["has_more"] is False
+        assert result["data"][0]["value_type"] == "number"
+        assert result["data"][0]["value"] == "1"
+        assert result["data"][0]["created_at"] == int(created_at.timestamp())
+

 class TestConversationVariableDetailApiController:
-    def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_update_type_mismatch(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             ConversationService,
             "update_conversation_variable",
@@ -577,7 +688,7 @@ class TestConversationVariableDetailApiController:
             variable_id="00000000-0000-0000-0000-000000000002",
         )

-    def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_update_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             ConversationService,
             "update_conversation_variable",
@@ -602,3 +713,41 @@ class TestConversationVariableDetailApiController:
             c_id="00000000-0000-0000-0000-000000000001",
             variable_id="00000000-0000-0000-0000-000000000002",
         )
+
+    def test_update_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
+        created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
+        monkeypatch.setattr(
+            ConversationService,
+            "update_conversation_variable",
+            lambda *_args, **_kwargs: {
+                "id": "550e8400-e29b-41d4-a716-446655440000",
+                "name": "foo",
+                "value_type": SegmentType.INTEGER,
+                "value": 1,
+                "created_at": created_at,
+                "updated_at": created_at,
+            },
+        )
+
+        api = ConversationVariableDetailApi()
+        handler = _unwrap(api.put)
+        app_model = SimpleNamespace(mode=AppMode.CHAT.value)
+        end_user = SimpleNamespace()
+
+        with app.test_request_context(
+            "/conversations/1/variables/2",
+            method="PUT",
+            json={"value": 1},
+        ):
+            result = handler(
+                api,
+                app_model=app_model,
+                end_user=end_user,
+                c_id="00000000-0000-0000-0000-000000000001",
+                variable_id="00000000-0000-0000-0000-000000000002",
+            )
+
+        assert result["id"] == "550e8400-e29b-41d4-a716-446655440000"
+        assert result["value_type"] == "number"
+        assert result["value"] == "1"
+        assert result["created_at"] == int(created_at.timestamp())
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file.py b/api/tests/unit_tests/controllers/service_api/app/test_file.py
index 7060bd79df..2615c3edac 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_file.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_file.py
@@ -16,6 +16,7 @@ import uuid
 from unittest.mock import Mock, patch

 import pytest
+from flask import Flask

 from controllers.common.errors import (
     FilenameNotExistsError,
@@ -282,7 +283,7 @@ class TestFileApiPost:
         assert status == 201
         mock_file_svc_cls.return_value.upload_file.assert_called_once()

-    def test_upload_no_file(self, app, mock_app_model, mock_end_user):
+    def test_upload_no_file(self, app: Flask, mock_app_model, mock_end_user):
         """Test NoFileUploadedError when no file in request."""
         from controllers.service_api.app.file import FileApi
@@ -296,7 +297,7 @@ class TestFileApiPost:
         with pytest.raises(NoFileUploadedError):
             _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)

-    def test_upload_too_many_files(self, app, mock_app_model, mock_end_user):
+    def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user):
         """Test TooManyFilesError when multiple files uploaded."""
         from io import BytesIO
@@ -317,7 +318,7 @@ class TestFileApiPost:
         with pytest.raises(TooManyFilesError):
             _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)

-    def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user):
+    def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user):
         """Test UnsupportedFileTypeError when file has no mimetype."""
         from io import BytesIO
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py
new file mode 100644
index 0000000000..510d4a9470
--- /dev/null
+++ b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py
@@ -0,0 +1,708 @@
+"""Dedicated tests for HITL behavior exposed through the Service API."""
+
+from __future__ import annotations
+
+import json
+import sys
+from collections.abc import Sequence
+from dataclasses import dataclass
+from datetime import UTC, datetime
+from types import SimpleNamespace
+from unittest.mock import ANY, MagicMock, Mock
+
+import pytest
+from flask import Flask
+
+import services.app_generate_service as ags_module
+from controllers.service_api.app.workflow_events import WorkflowEventsApi
+from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
+from core.app.apps.common import workflow_response_converter
+from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
+from core.app.entities.queue_entities import QueueWorkflowPausedEvent
+from core.app.entities.task_entities import (
+    AdvancedChatPausedBlockingResponse,
+    HumanInputRequiredResponse,
+    WorkflowAppPausedBlockingResponse,
+    WorkflowPauseStreamResponse,
+)
+from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
+from core.workflow.human_input_policy import HumanInputSurface
+from core.workflow.system_variables import build_system_variables
+from graphon.entities import WorkflowStartReason
+from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
+from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
+from graphon.nodes.human_input.entities import FormInput, UserAction
+from graphon.nodes.human_input.enums import FormInputType
+from graphon.runtime import GraphRuntimeState, VariablePool
+from models.account import Account
+from models.enums import CreatorUserRole
+from models.model import AppMode
+from models.workflow import WorkflowRun
+from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
+from repositories.entities.workflow_pause import WorkflowPauseEntity
+from services.app_generate_service import AppGenerateService
+from services.workflow_event_snapshot_service import _build_snapshot_events
+from tests.unit_tests.controllers.service_api.conftest import _unwrap
+
+
+class _DummyRateLimit:
+    @staticmethod
+    def gen_request_key() -> str:
+        return "dummy-request-id"
+
+    def __init__(self, client_id: str, max_active_requests: int) -> None:
+        self.client_id = client_id
+        self.max_active_requests = max_active_requests
+
+    def enter(self, request_id: str | None = None) -> str:
+        return request_id or "dummy-request-id"
+
+    def exit(self, request_id: str) -> None:
+        return None
+
+    def generate(self, generator, request_id: str):
+        return generator
+
+
+def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
+    workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
+    repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
+    monkeypatch.setattr(
+        workflow_events_module.DifyAPIRepositoryFactory,
+        "create_api_workflow_run_repository",
+        lambda *_args, **_kwargs: repo,
+    )
+    monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
+    return workflow_events_module
+
+
+def _build_service_api_pause_converter() -> WorkflowResponseConverter:
+    application_generate_entity = SimpleNamespace(
+        inputs={},
+        files=[],
+        invoke_from=InvokeFrom.SERVICE_API,
+        app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
+    )
+    system_variables = build_system_variables(
+        user_id="user",
+        app_id="app-id",
+        workflow_id="workflow-id",
+        workflow_execution_id="run-id",
+    )
+    user = MagicMock(spec=Account)
+    user.id = "account-id"
+    user.name = "Tester"
+    user.email = "tester@example.com"
+    return WorkflowResponseConverter(
+        application_generate_entity=application_generate_entity,
+        user=user,
+        system_variables=system_variables,
+    )
+
+
+def _build_advanced_chat_paused_blocking_response() -> AdvancedChatPausedBlockingResponse:
+    data = AdvancedChatPausedBlockingResponse.Data(
+        id="msg-1",
+        mode="chat",
+        conversation_id="c1",
+        message_id="m1",
+        workflow_run_id="run-1",
+        answer="partial",
+        metadata={"usage": {"total_tokens": 1}},
+        created_at=1,
+        paused_nodes=["node-1"],
+        reasons=[
+            {
+                "type": PauseReasonType.HUMAN_INPUT_REQUIRED,
+                "form_id": "form-1",
+                "expiration_time": 100,
+            }
+        ],
+        status=WorkflowExecutionStatus.PAUSED,
+        elapsed_time=0.1,
+        total_tokens=0,
+        total_steps=0,
+    )
+    return AdvancedChatPausedBlockingResponse(task_id="t1", data=data)
+
+
+def _build_workflow_paused_blocking_response() -> WorkflowAppPausedBlockingResponse:
+    return WorkflowAppPausedBlockingResponse(
+        task_id="t1",
+        workflow_run_id="r1",
+        data=WorkflowAppPausedBlockingResponse.Data(
+            id="r1",
+            workflow_id="wf-1",
+            status=WorkflowExecutionStatus.PAUSED,
+            outputs={},
+            error=None,
+            elapsed_time=0.5,
+            total_tokens=0,
+            total_steps=2,
+            created_at=1,
+            finished_at=None,
+            paused_nodes=["node-1"],
+            reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100}],
+        ),
+    )
+
+
+@dataclass(frozen=True)
+class _FakePauseEntity(WorkflowPauseEntity):
+    pause_id: str
+    workflow_run_id: str
+    paused_at_value: datetime
+    pause_reasons: Sequence[HumanInputRequired]
+
+    @property
+    def id(self) -> str:
+        return self.pause_id
+
+    @property
+    def workflow_execution_id(self) -> str:
+        return self.workflow_run_id
+
+    def get_state(self) -> bytes:
+        raise AssertionError("state is not required for snapshot tests")
+
+    @property
+    def resumed_at(self) -> datetime | None:
+        return None
+
+    @property
+    def paused_at(self) -> datetime:
+        return self.paused_at_value
+
+    def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
+        return self.pause_reasons
+
+
+def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
+    return WorkflowRun(
+        id="run-1",
+        tenant_id="tenant-1",
+        app_id="app-1",
+        workflow_id="workflow-1",
+        type="workflow",
+        triggered_from="app-run",
+        version="v1",
+        graph=None,
+        inputs=json.dumps({"input": "value"}),
+        status=status,
+        outputs=json.dumps({}),
+        error=None,
+        elapsed_time=0.0,
+        total_tokens=0,
+        total_steps=0,
+        created_by_role=CreatorUserRole.END_USER,
+        created_by="user-1",
+        created_at=datetime(2024, 1, 1, tzinfo=UTC),
+    )
+
+
+def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
+    created_at = datetime(2024, 1, 1, tzinfo=UTC)
+    finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
+    return WorkflowNodeExecutionSnapshot(
+        execution_id="exec-1",
+        node_id="node-1",
+        node_type="human-input",
+        title="Human Input",
+        index=1,
+        status=status.value,
+        elapsed_time=0.5,
+        created_at=created_at,
+        finished_at=finished_at,
+        iteration_id=None,
+        loop_id=None,
+    )
+
+
+def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
+    app_config = WorkflowUIBasedAppConfig(
+        tenant_id="tenant-1",
+        app_id="app-1",
+        app_mode=AppMode.WORKFLOW,
+        workflow_id="workflow-1",
+    )
+    generate_entity = WorkflowAppGenerateEntity(
+        task_id=task_id,
+        app_config=app_config,
+        inputs={},
+        files=[],
+        user_id="user-1",
+        stream=True,
+        invoke_from=InvokeFrom.EXPLORE,
+        call_depth=0,
+        workflow_execution_id="run-1",
+    )
+    runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
+    runtime_state.register_paused_node("node-1")
+    runtime_state.outputs = {"result": "value"}
+    wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
+    return WorkflowResumptionContext(
+        generate_entity=wrapper,
+        serialized_graph_runtime_state=runtime_state.dumps(),
+    )
+
+
+class TestHitlServiceApi:
+    # Service API event-stream continuation
+    def test_workflow_events_continue_on_pause_keeps_stream_open(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+        workflow_run = SimpleNamespace(
+            id="run-1",
+            app_id="app-1",
+            created_by_role=CreatorUserRole.END_USER,
+            created_by="end-user-1",
+            finished_at=None,
+        )
+        workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
+        msg_generator = Mock()
+        msg_generator.retrieve_events.return_value = ["raw-event"]
+        workflow_generator = Mock()
+        workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
+        monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
+        monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
+
+        api = WorkflowEventsApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+        end_user = SimpleNamespace(id="end-user-1")
+
+        with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
+            response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+        assert response.get_data(as_text=True) == "data: streamed\n\n"
+        msg_generator.retrieve_events.assert_called_once_with(
+            AppMode.WORKFLOW,
+            "run-1",
+            terminal_events=[],
+        )
+        workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
+
+    def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
+        self, app: Flask, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        workflow_run = SimpleNamespace(
+            id="run-1",
+            app_id="app-1",
+            created_by_role=CreatorUserRole.END_USER,
+            created_by="end-user-1",
+            finished_at=None,
+        )
+        workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
+        msg_generator = Mock()
+        workflow_generator = Mock()
+        workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
+        snapshot_builder = Mock(return_value=["snapshot-events"])
+        monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
+        monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
+        monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
+
+        api = WorkflowEventsApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+        end_user = SimpleNamespace(id="end-user-1")
+
+        with app.test_request_context(
+            "/workflow/run-1/events?user=u1&include_state_snapshot=true&continue_on_pause=true",
+            method="GET",
+        ):
+            response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+        assert response.get_data(as_text=True) == "data: snapshot\n\n"
msg_generator.retrieve_events.assert_not_called() + snapshot_builder.assert_called_once_with( + app_mode=AppMode.WORKFLOW, + workflow_run=workflow_run, + tenant_id="tenant-1", + app_id="app-1", + session_maker=ANY, + human_input_surface=HumanInputSurface.SERVICE_API, + close_on_pause=False, + ) + workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"]) + + def test_advanced_chat_blocking_injects_pause_state_config(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + monkeypatch.setattr(ags_module, "RateLimit", _DummyRateLimit) + + workflow = MagicMock() + workflow.created_by = "owner-id" + monkeypatch.setattr(AppGenerateService, "_get_workflow", lambda *args, **kwargs: workflow) + monkeypatch.setattr(ags_module.session_factory, "get_session_maker", lambda: "session-maker") + + generator_instance = MagicMock() + generator_instance.generate.return_value = {"result": "advanced-blocking"} + generator_instance.convert_to_event_stream.side_effect = lambda payload: payload + monkeypatch.setattr(ags_module, "AdvancedChatAppGenerator", lambda: generator_instance) + + app_model = MagicMock() + app_model.mode = AppMode.ADVANCED_CHAT + app_model.id = "app-id" + app_model.tenant_id = "tenant-id" + app_model.max_active_requests = 0 + app_model.is_agent = False + + user = MagicMock() + user.id = "user-id" + + result = AppGenerateService.generate( + app_model=app_model, + user=user, + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + assert result == {"result": "advanced-blocking"} + call_kwargs = generator_instance.generate.call_args.kwargs + assert call_kwargs["streaming"] is False + assert call_kwargs["pause_state_config"] is not None + assert call_kwargs["pause_state_config"].session_factory == "session-maker" + assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id" + + # Blocking payload contract + def test_advanced_chat_blocking_pause_payload_contract(self) -> None: + from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter + + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response( + _build_advanced_chat_paused_blocking_response() + ) + + assert response["event"] == "workflow_paused" + assert response["workflow_run_id"] == "run-1" + assert response["answer"] == "partial" + assert response["data"]["reasons"][0]["type"] == PauseReasonType.HUMAN_INPUT_REQUIRED + assert response["data"]["reasons"][0]["expiration_time"] == 100 + assert "human_input_forms" not in response["data"] + + def test_workflow_blocking_pause_payload_contract(self) -> None: + from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter + + response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response( + _build_workflow_paused_blocking_response() + ) + + assert response["workflow_run_id"] == "r1" + assert response["data"]["status"] == WorkflowExecutionStatus.PAUSED + assert response["data"]["paused_nodes"] == ["node-1"] + assert response["data"]["reasons"] == [ + {"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 100} + ] + assert "human_input_forms" not in response["data"] + + def test_advanced_chat_blocking_pipeline_pause_payload_contract(self) -> None: + from core.app.app_config.entities import AppAdditionalFeatures + from core.app.apps.advanced_chat.generate_task_pipeline import 
AdvancedChatAppGenerateTaskPipeline + from models.enums import MessageStatus + from models.model import EndUser + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}), + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + conversation=SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace( + id="message-id", + query="hello", + created_at=datetime.utcnow(), + status=MessageStatus.NORMAL, + answer="", + ), + user=EndUser(tenant_id="tenant", type="session", name="tester", session_id="session"), + stream=False, + dialogue_count=1, + draft_var_saver_factory=lambda **kwargs: None, + ) + pipeline._task_state.answer = "partial answer" + pipeline._workflow_run_id = "run-id" + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run-id", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Approval", + form_content="Need approval", + inputs=[], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + form_token="token-1", + resolved_default_values={}, + expiration_time=123, + ), + ) + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run-id", + paused_nodes=["node-1"], + outputs={}, + reasons=[ + { + "type": PauseReasonType.HUMAN_INPUT_REQUIRED, + "form_id": "form-1", + "node_id": "node-1", + "expiration_time": 123, + }, + ], + status="paused", + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, AdvancedChatPausedBlockingResponse) + assert response.data.answer == "partial answer" + assert response.data.workflow_run_id == "run-id" + assert response.data.reasons[0]["form_id"] == "form-1" + assert response.data.reasons[0]["expiration_time"] == 123 + + def test_workflow_blocking_pipeline_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None: + from core.app.apps.workflow import generate_task_pipeline as workflow_pipeline_module + from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(id="workflow-id", tenant_id="tenant", 
features_dict={}), + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + user=SimpleNamespace(id="user", session_id="session"), + stream=False, + draft_var_saver_factory=lambda **kwargs: None, + ) + monkeypatch.setattr(workflow_pipeline_module.time, "time", lambda: 1700000000) + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="content", + expiration_time=1, + ), + ) + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run", + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + paused_nodes=["node-1"], + reasons=[{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}], + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, WorkflowAppPausedBlockingResponse) + assert response.data.status == WorkflowExecutionStatus.PAUSED + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [{"TYPE": "human_input_required", "form_id": "form-1", "expiration_time": 1}] + + def test_service_api_pause_event_serializes_hitl_reason(self, monkeypatch: pytest.MonkeyPatch) -> None: + converter = _build_service_api_pause_converter() + converter.workflow_start_to_stream_response( + task_id="task", + workflow_run_id="run-id", + workflow_id="workflow-id", + reason=WorkflowStartReason.INITIAL, + ) + + expiration_time = datetime(2024, 1, 1, tzinfo=UTC) + + class _FakeSession: + def execute(self, _stmt): + return [("form-1", expiration_time, '{"display_in_ui": true}')] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession()) + monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + workflow_response_converter, + "load_form_tokens_by_form_id", + lambda form_ids, session=None, surface=None: {"form-1": "token"}, + ) + + reason = HumanInputRequired( + form_id="form-1", + form_content="Rendered", + inputs=[ + FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), + ], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + node_id="node-id", + node_title="Human Step", + form_token="token", + ) + queue_event = QueueWorkflowPausedEvent( + reasons=[reason], + outputs={"answer": "value"}, + paused_nodes=["node-id"], + ) + + runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0) + responses = converter.workflow_pause_to_stream_response( + event=queue_event, + task_id="task", + graph_runtime_state=runtime_state, + ) + + assert isinstance(responses[-1], WorkflowPauseStreamResponse) + pause_resp = responses[-1] + assert pause_resp.workflow_run_id == "run-id" + assert pause_resp.data.paused_nodes == ["node-id"] + assert pause_resp.data.outputs == {} + assert pause_resp.data.reasons[0]["TYPE"] == "human_input_required" + assert pause_resp.data.reasons[0]["form_id"] == "form-1" + assert pause_resp.data.reasons[0]["form_token"] == "token" + assert pause_resp.data.reasons[0]["expiration_time"] == int(expiration_time.timestamp()) + + assert isinstance(responses[0], HumanInputRequiredResponse) + hi_resp = responses[0] + assert 
hi_resp.data.form_id == "form-1" + assert hi_resp.data.node_id == "node-id" + assert hi_resp.data.node_title == "Human Step" + assert hi_resp.data.inputs[0].output_variable_name == "field" + assert hi_resp.data.actions[0].id == "approve" + assert hi_resp.data.display_in_ui is True + assert hi_resp.data.form_token == "token" + assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) + + # Snapshot payload contract + def test_snapshot_events_include_pause_payload_contract(self, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED) + snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED) + resumption_context = _build_resumption_context("task-ctx") + monkeypatch.setattr( + "services.workflow_event_snapshot_service.load_form_tokens_by_form_id", + lambda form_ids, session=None, surface=None: {"form-1": "wtok"}, + ) + + class _SessionContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return False + + def session_maker() -> _SessionContext: + return _SessionContext( + SimpleNamespace( + execute=lambda _stmt: [("form-1", datetime(2024, 1, 1, tzinfo=UTC), '{"display_in_ui": true}')], + ) + ) + + pause_entity = _FakePauseEntity( + pause_id="pause-1", + workflow_run_id="run-1", + paused_at_value=datetime(2024, 1, 1, tzinfo=UTC), + pause_reasons=[ + HumanInputRequired( + form_id="form-1", + form_content="content", + node_id="node-1", + node_title="Human Input", + form_token="wtok", + ) + ], + ) + + events = _build_snapshot_events( + workflow_run=workflow_run, + node_snapshots=[snapshot], + task_id="task-ctx", + message_context=None, + pause_entity=pause_entity, + resumption_context=resumption_context, + session_maker=session_maker, + ) + + assert [event["event"] for event in events] == [ + "workflow_started", + "node_started", + "node_finished", + "human_input_required", + "workflow_paused", + ] + assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value + assert events[3]["data"]["form_token"] == "wtok" + assert events[3]["data"]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) + pause_data = events[-1]["data"] + assert pause_data["paused_nodes"] == ["node-1"] + assert pause_data["outputs"] == {"result": "value"} + assert pause_data["reasons"][0]["TYPE"] == "human_input_required" + assert pause_data["reasons"][0]["form_token"] == "wtok" + assert pause_data["reasons"][0]["expiration_time"] == int(datetime(2024, 1, 1, tzinfo=UTC).timestamp()) + assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value + assert pause_data["created_at"] == int(workflow_run.created_at.timestamp()) + assert pause_data["elapsed_time"] == workflow_run.elapsed_time + assert pause_data["total_tokens"] == workflow_run.total_tokens + assert pause_data["total_steps"] == workflow_run.total_steps diff --git a/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py b/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py new file mode 100644 index 0000000000..531f722ceb --- /dev/null +++ b/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py @@ -0,0 +1,184 @@ +"""Unit tests for Service API human input form endpoints.""" + +from __future__ import annotations + +import json +import sys +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from werkzeug.exceptions import 
NotFound + +from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi +from models.human_input import RecipientType +from tests.unit_tests.controllers.service_api.conftest import _unwrap + + +class TestWorkflowHumanInputFormApi: + def test_get_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + definition = SimpleNamespace( + model_dump=lambda: { + "rendered_content": "Rendered form content", + "inputs": [{"output_variable_name": "name"}], + "default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}}, + "user_actions": [{"id": "approve", "title": "Approve"}], + } + ) + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + get_definition=lambda: definition, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + response = handler(api, app_model=app_model, form_token="token-1") + + payload = json.loads(response.get_data(as_text=True)) + assert payload == { + "form_content": "Rendered form content", + "inputs": [{"output_variable_name": "name"}], + "resolved_default_values": {"name": "Alice", "age": "30", "meta": '{"k": "v"}'}, + "user_actions": [{"id": "approve", "title": "Approve"}], + "expiration_time": int(form.expiration_time.timestamp()), + } + service_mock.get_form_by_token.assert_called_once_with("token-1") + service_mock.ensure_form_active.assert_called_once_with(form) + + def test_get_form_not_in_app(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace( + app_id="another-app", + tenant_id="tenant-1", + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, form_token="token-1") + + @pytest.mark.parametrize( + "recipient_type", + [ + RecipientType.CONSOLE, + RecipientType.BACKSTAGE, + RecipientType.EMAIL_MEMBER, + RecipientType.EMAIL_EXTERNAL, + ], + ) + def test_get_rejects_non_service_api_recipient_types( + self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType + ) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=recipient_type, + expiration_time=datetime(2099, 1, 1, tzinfo=UTC), + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", 
SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + + with app.test_request_context("/form/human_input/token-1", method="GET"): + with pytest.raises(NotFound): + handler(api, app_model=app_model, form_token="token-1") + + service_mock.ensure_form_active.assert_not_called() + + def test_post_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=RecipientType.STANDALONE_WEB_APP, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context( + "/form/human_input/token-1", + method="POST", + json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"}, + ): + response, status = handler(api, app_model=app_model, end_user=end_user, form_token="token-1") + + assert response == {} + assert status == 200 + service_mock.submit_form_by_token.assert_called_once_with( + recipient_type=RecipientType.STANDALONE_WEB_APP, + form_token="token-1", + selected_action_id="approve", + form_data={"name": "Alice"}, + submission_end_user_id="end-user-1", + ) + + @pytest.mark.parametrize( + "recipient_type", + [ + RecipientType.CONSOLE, + RecipientType.BACKSTAGE, + RecipientType.EMAIL_MEMBER, + RecipientType.EMAIL_EXTERNAL, + ], + ) + def test_post_rejects_non_service_api_recipient_types( + self, app, monkeypatch: pytest.MonkeyPatch, recipient_type: RecipientType + ) -> None: + form = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + recipient_type=recipient_type, + ) + service_mock = Mock() + service_mock.get_form_by_token.return_value = form + workflow_module = sys.modules["controllers.service_api.app.human_input_form"] + monkeypatch.setattr(workflow_module, "HumanInputService", lambda _engine: service_mock) + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + + api = WorkflowHumanInputFormApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + end_user = SimpleNamespace(id="end-user-1") + + with app.test_request_context( + "/form/human_input/token-1", + method="POST", + json={"inputs": {"name": "Alice"}, "action": "approve", "user": "external-1"}, + ): + with pytest.raises(NotFound): + handler(api, app_model=app_model, end_user=end_user, form_token="token-1") + + service_mock.submit_form_by_token.assert_not_called() diff --git a/api/tests/unit_tests/controllers/service_api/app/test_message.py b/api/tests/unit_tests/controllers/service_api/app/test_message.py index c2b8aed1ae..2bc9771862 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_message.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_message.py @@ -19,6 +19,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from controllers.service_api.app.error import NotChatAppError @@ -390,7 +391,7 @@ class TestMessageListApi: 
             with pytest.raises(NotChatAppError):
                 handler(api, app_model=app_model, end_user=end_user)

-    def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_conversation_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             MessageService,
             "pagination_by_first_id",
@@ -409,7 +410,7 @@ class TestMessageListApi:
             with pytest.raises(NotFound):
                 handler(api, app_model=app_model, end_user=end_user)

-    def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_first_message_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             MessageService,
             "pagination_by_first_id",
@@ -430,7 +431,7 @@ class TestMessageListApi:


 class TestMessageFeedbackApi:
-    def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             MessageService,
             "create_feedback",
@@ -452,7 +453,7 @@ class TestMessageFeedbackApi:


 class TestAppGetFeedbacksApi:
-    def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])

         api = AppGetFeedbacksApi()
@@ -476,7 +477,7 @@ class TestMessageSuggestedApi:
             with pytest.raises(NotChatAppError):
                 handler(api, app_model=app_model, end_user=end_user, message_id="m1")

-    def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             MessageService,
             "get_suggested_questions_after_answer",
@@ -492,7 +493,7 @@ class TestMessageSuggestedApi:
             with pytest.raises(NotFound):
                 handler(api, app_model=app_model, end_user=end_user, message_id="m1")

-    def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_disabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             MessageService,
             "get_suggested_questions_after_answer",
@@ -508,7 +509,7 @@ class TestMessageSuggestedApi:
             with pytest.raises(BadRequest):
                 handler(api, app_model=app_model, end_user=end_user, message_id="m1")

-    def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_internal_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             MessageService,
             "get_suggested_questions_after_answer",
@@ -524,7 +525,7 @@ class TestMessageSuggestedApi:
             with pytest.raises(InternalServerError):
                 handler(api, app_model=app_model, end_user=end_user, message_id="m1")

-    def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             MessageService,
             "get_suggested_questions_after_answer",
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
index cfa21bf2dd..7115ea1e12 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py
@@ -15,11 +15,12 @@ Focus on:

 import sys
 import uuid
+from datetime import UTC, datetime
 from types import SimpleNamespace
 from unittest.mock import Mock, patch

 import pytest
-from graphon.enums import WorkflowExecutionStatus
+from flask import Flask
 from werkzeug.exceptions import BadRequest, NotFound

 from controllers.service_api.app.error import NotWorkflowAppError
@@ -36,6 +37,7 @@ from controllers.service_api.app.workflow import (
     WorkflowTaskStopApi,
 )
 from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
+from graphon.enums import WorkflowExecutionStatus
 from models.model import App, AppMode
 from services.app_generate_service import AppGenerateService
 from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError
@@ -43,6 +45,22 @@ from services.errors.llm import InvokeRateLimitError
 from services.workflow_app_service import WorkflowAppService


+def _make_mock_workflow_run(run_id: str = "run-1"):
+    run = Mock()
+    run.id = run_id
+    run.workflow_id = "wf-1"
+    run.status = WorkflowExecutionStatus.SUCCEEDED
+    run.inputs = {"input": "value"}
+    run.outputs_dict = {"output": "value"}
+    run.error = None
+    run.total_steps = 1
+    run.total_tokens = 10
+    run.created_at = datetime(2026, 1, 1, tzinfo=UTC)
+    run.finished_at = datetime(2026, 1, 1, tzinfo=UTC)
+    run.elapsed_time = 0.1
+    return run
+
+
 class TestWorkflowRunPayload:
     """Test suite for WorkflowRunPayload Pydantic model."""

@@ -349,7 +367,7 @@ class TestWorkflowRunRepository:


 class TestWorkflowRunDetailApi:
-    def test_not_workflow_app(self, app) -> None:
+    def test_not_workflow_app(self, app: Flask) -> None:
         api = WorkflowRunDetailApi()
         handler = _unwrap(api.get)
         app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@@ -359,7 +377,7 @@ class TestWorkflowRunDetailApi:
             handler(api, app_model=app_model, workflow_run_id="run")

     def test_success(self, monkeypatch: pytest.MonkeyPatch) -> None:
-        run = SimpleNamespace(id="run")
+        run = _make_mock_workflow_run(run_id="run")
         repo = SimpleNamespace(get_workflow_run_by_id=lambda **_kwargs: run)
         workflow_module = sys.modules["controllers.service_api.app.workflow"]
         monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
@@ -373,11 +391,14 @@ class TestWorkflowRunDetailApi:
         handler = _unwrap(api.get)
         app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")

-        assert handler(api, app_model=app_model, workflow_run_id="run") == run
+        result = handler(api, app_model=app_model, workflow_run_id="run")
+        assert result["id"] == "run"
+        assert result["workflow_id"] == "wf-1"
+        assert result["status"] == "succeeded"


 class TestWorkflowRunApi:
-    def test_not_workflow_app(self, app) -> None:
+    def test_not_workflow_app(self, app: Flask) -> None:
         api = WorkflowRunApi()
         handler = _unwrap(api.post)
         app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@@ -387,7 +408,7 @@ class TestWorkflowRunApi:
         with pytest.raises(NotWorkflowAppError):
             handler(api, app_model=app_model, end_user=end_user)

-    def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_rate_limit(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             AppGenerateService,
             "generate",
@@ -405,7 +426,7 @@ class TestWorkflowRunApi:


 class TestWorkflowRunByIdApi:
-    def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             AppGenerateService,
             "generate",
@@ -421,7 +442,7 @@ class TestWorkflowRunByIdApi:
         with pytest.raises(NotFound):
             handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")

-    def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_draft_workflow(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.setattr(
             AppGenerateService,
             "generate",
@@ -439,7 +460,7 @@ class TestWorkflowRunByIdApi:


 class TestWorkflowTaskStopApi:
-    def test_wrong_mode(self, app) -> None:
+    def test_wrong_mode(self, app: Flask) -> None:
         api = WorkflowTaskStopApi()
         handler = _unwrap(api.post)
         app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@@ -449,7 +470,7 @@ class TestWorkflowTaskStopApi:
         with pytest.raises(NotWorkflowAppError):
             handler(api, app_model=app_model, end_user=end_user, task_id="t1")

-    def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         stop_mock = Mock()
         send_mock = Mock()
         monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock)
@@ -469,7 +490,7 @@ class TestWorkflowTaskStopApi:


 class TestWorkflowAppLogApi:
-    def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
+    def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
         class _BeginStub:
             def __enter__(self):
                 return SimpleNamespace()
@@ -490,7 +511,7 @@ class TestWorkflowAppLogApi:
         monkeypatch.setattr(
             WorkflowAppService,
             "get_paginate_workflow_app_logs",
-            lambda *_args, **_kwargs: {"items": [], "total": 0},
+            lambda *_args, **_kwargs: {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []},
         )

         api = WorkflowAppLogApi()
@@ -500,7 +521,7 @@ class TestWorkflowAppLogApi:
         with app.test_request_context("/workflows/logs", method="GET"):
             response = handler(api, app_model=app_model)

-        assert response == {"items": [], "total": 0}
+        assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}


 # =============================================================================
@@ -527,9 +548,8 @@ def mock_workflow_app():
 class TestWorkflowRunDetailApiGet:
     """Test suite for WorkflowRunDetailApi.get() endpoint.

-    ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``)
-    and ``@service_api_ns.marshal_with``. We call the unwrapped method
-    directly; ``marshal_with`` is a no-op when calling directly.
+    ``get`` is wrapped by ``@validate_app_token`` (preserves ``__wrapped__``),
+    and we call the unwrapped method directly in tests.
     """

     @patch("controllers.service_api.app.workflow.DifyAPIRepositoryFactory")
@@ -538,13 +558,11 @@ class TestWorkflowRunDetailApiGet:
         self,
         mock_db,
         mock_repo_factory,
-        app,
+        app: Flask,
         mock_workflow_app,
     ):
         """Test successful workflow run detail retrieval."""
-        mock_run = Mock()
-        mock_run.id = "run-1"
-        mock_run.status = "succeeded"
+        mock_run = _make_mock_workflow_run(run_id="run-1")
         mock_repo = Mock()
         mock_repo.get_workflow_run_by_id.return_value = mock_run
         mock_repo_factory.create_api_workflow_run_repository.return_value = mock_repo
@@ -558,10 +576,11 @@ class TestWorkflowRunDetailApiGet:
         api = WorkflowRunDetailApi()
         result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)

-        assert result == mock_run
+        assert result["id"] == mock_run.id
+        assert result["status"] == "succeeded"

     @patch("controllers.service_api.app.workflow.db")
-    def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
+    def test_get_workflow_run_wrong_app_mode(self, mock_db, app: Flask):
         """Test NotWorkflowAppError when app mode is not workflow or advanced_chat."""
         from controllers.service_api.app.workflow import WorkflowRunDetailApi

@@ -586,7 +605,7 @@ class TestWorkflowTaskStopApiPost:
         self,
         mock_queue_mgr,
         mock_graph_mgr,
-        app,
+        app: Flask,
         mock_workflow_app,
     ):
         """Test successful workflow task stop."""
@@ -606,7 +625,7 @@ class TestWorkflowTaskStopApiPost:
         mock_graph_mgr.assert_called_once()
         mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1")

-    def test_stop_workflow_task_wrong_app_mode(self, app):
+    def test_stop_workflow_task_wrong_app_mode(self, app: Flask):
         """Test NotWorkflowAppError when app mode is not workflow."""
         from controllers.service_api.app.workflow import WorkflowTaskStopApi

@@ -622,8 +641,7 @@ class TestWorkflowTaskStopApiPost:
 class TestWorkflowAppLogApiGet:
     """Test suite for WorkflowAppLogApi.get() endpoint.

-    ``get`` is wrapped by ``@validate_app_token`` and
-    ``@service_api_ns.marshal_with``.
+    ``get`` is wrapped by ``@validate_app_token``.
     """

     @patch("controllers.service_api.app.workflow.WorkflowAppService")
@@ -632,11 +650,15 @@ class TestWorkflowAppLogApiGet:
         self,
         mock_db,
         mock_wf_svc_cls,
-        app,
+        app: Flask,
         mock_workflow_app,
     ):
         """Test successful workflow log retrieval."""
         mock_pagination = Mock()
+        mock_pagination.page = 1
+        mock_pagination.limit = 20
+        mock_pagination.total = 0
+        mock_pagination.has_more = False
         mock_pagination.data = []
         mock_svc_instance = Mock()
         mock_svc_instance.get_paginate_workflow_app_logs.return_value = mock_pagination
@@ -661,4 +683,4 @@ class TestWorkflowAppLogApiGet:
         api = WorkflowAppLogApi()
         result = _unwrap(api.get)(api, app_model=mock_workflow_app)

-        assert result == mock_pagination
+        assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py
new file mode 100644
index 0000000000..b3edc2ecd8
--- /dev/null
+++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py
@@ -0,0 +1,167 @@
+"""Unit tests for Service API workflow event stream endpoints."""
+
+from __future__ import annotations
+
+import json
+import sys
+from datetime import UTC, datetime
+from types import SimpleNamespace
+from unittest.mock import Mock
+
+import pytest
+from flask import Flask
+from werkzeug.exceptions import NotFound
+
+from controllers.service_api.app.error import NotWorkflowAppError
+from controllers.service_api.app.workflow_events import WorkflowEventsApi
+from models.enums import CreatorUserRole
+from models.model import AppMode
+from tests.unit_tests.controllers.service_api.conftest import _unwrap
+
+
+def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
+    workflow_events_module = sys.modules["controllers.service_api.app.workflow_events"]
+    repo = SimpleNamespace(get_workflow_run_by_id_and_tenant_id=lambda **_kwargs: workflow_run)
+    monkeypatch.setattr(
+        workflow_events_module.DifyAPIRepositoryFactory,
+        "create_api_workflow_run_repository",
+        lambda *_args, **_kwargs: repo,
+    )
+    monkeypatch.setattr(workflow_events_module, "db", SimpleNamespace(engine=object()))
+    return workflow_events_module
+
+
+class TestWorkflowEventsApi:
+    def test_wrong_app_mode(self, app) -> None:
+        api = WorkflowEventsApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(mode=AppMode.CHAT.value)
+        end_user = SimpleNamespace(id="end-user-1")
+
+        with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
+            with pytest.raises(NotWorkflowAppError):
+                handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+    def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
+        _mock_repo_for_run(monkeypatch, workflow_run=None)
+        api = WorkflowEventsApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+        end_user = SimpleNamespace(id="end-user-1")
+
+        with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
+            with pytest.raises(NotFound):
+                handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+    def test_workflow_run_permission_denied(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
+        workflow_run = SimpleNamespace(
+            id="run-1",
+            app_id="app-1",
+            created_by_role=CreatorUserRole.ACCOUNT,
+            created_by="another-user",
+            finished_at=None,
+        )
+        _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
+        api = WorkflowEventsApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+        end_user = SimpleNamespace(id="end-user-1")
+
+        with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
+            with pytest.raises(NotFound):
+                handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+    def test_finished_run_returns_sse(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
+        workflow_run = SimpleNamespace(
+            id="run-1",
+            app_id="app-1",
+            created_by_role=CreatorUserRole.END_USER,
+            created_by="end-user-1",
+            finished_at=datetime(2099, 1, 1, tzinfo=UTC),
+        )
+        workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
+        monkeypatch.setattr(
+            workflow_events_module.WorkflowResponseConverter,
+            "workflow_run_result_to_finish_response",
+            lambda **_kwargs: SimpleNamespace(
+                model_dump=lambda mode="json": {"task_id": "run-1", "status": "succeeded"},
+                event=SimpleNamespace(value="workflow_finished"),
+            ),
+        )
+
+        api = WorkflowEventsApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+        end_user = SimpleNamespace(id="end-user-1")
+
+        with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
+            response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+        assert response.mimetype == "text/event-stream"
+        body = response.get_data(as_text=True).strip()
+        assert body.startswith("data: ")
+        payload = json.loads(body[len("data: ") :])
+        assert payload["task_id"] == "run-1"
+        assert payload["event"] == "workflow_finished"
+
+    def test_running_run_streams_events(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
+        workflow_run = SimpleNamespace(
+            id="run-1",
+            app_id="app-1",
+            created_by_role=CreatorUserRole.END_USER,
+            created_by="end-user-1",
+            finished_at=None,
+        )
+        workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
+        msg_generator = Mock()
+        msg_generator.retrieve_events.return_value = ["raw-event"]
+        workflow_generator = Mock()
+        workflow_generator.convert_to_event_stream.return_value = iter(["data: streamed\n\n"])
+        monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
+        monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
+
+        api = WorkflowEventsApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+        end_user = SimpleNamespace(id="end-user-1")
+
+        with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
+            response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+        assert response.get_data(as_text=True) == "data: streamed\n\n"
+        msg_generator.retrieve_events.assert_called_once_with(
+            AppMode.WORKFLOW,
+            "run-1",
+            terminal_events=None,
+        )
+        workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
+
+    def test_running_run_with_snapshot(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
+        workflow_run = SimpleNamespace(
+            id="run-1",
+            app_id="app-1",
+            created_by_role=CreatorUserRole.END_USER,
+            created_by="end-user-1",
+            finished_at=None,
+        )
+        workflow_events_module = _mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
+        msg_generator = Mock()
+        workflow_generator = Mock()
+        workflow_generator.convert_to_event_stream.return_value = iter(["data: snapshot\n\n"])
+        snapshot_builder = Mock(return_value=["snapshot-events"])
+        monkeypatch.setattr(workflow_events_module, "MessageGenerator", lambda: msg_generator)
+        monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
+        monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
+
+        api = WorkflowEventsApi()
+        handler = _unwrap(api.get)
+        app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
+        end_user = SimpleNamespace(id="end-user-1")
+
+        with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):
+            response = handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
+
+        assert response.get_data(as_text=True) == "data: snapshot\n\n"
+        msg_generator.retrieve_events.assert_not_called()
+        snapshot_builder.assert_called_once()
+        workflow_generator.convert_to_event_stream.assert_called_once_with(["snapshot-events"])
diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
index 4b8e3a738c..eda270258d 100644
--- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
+++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py
@@ -1,8 +1,7 @@
 from types import SimpleNamespace

-from graphon.enums import WorkflowExecutionStatus
-
 from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
+from graphon.enums import WorkflowExecutionStatus


 def test_workflow_run_status_field_with_enum() -> None:
diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py
index eddba5a517..8c89812cb4 100644
--- a/api/tests/unit_tests/controllers/service_api/conftest.py
+++ b/api/tests/unit_tests/controllers/service_api/conftest.py
@@ -15,7 +15,10 @@ from flask import Flask
 from core.rag.index_processor.constant.index_type import IndexStructureType
 from models.account import TenantStatus
 from models.model import App, AppMode, EndUser
-from tests.unit_tests.conftest import setup_mock_tenant_account_query
+from tests.unit_tests.conftest import (
+    setup_mock_dataset_owner_execute_result,
+    setup_mock_tenant_owner_execute_result,
+)


 @pytest.fixture
@@ -123,9 +126,7 @@ class AuthenticationMocker:
         mock_db.session.get.side_effect = [mock_app, mock_tenant]

         if mock_account:
-            mock_ta = Mock()
-            mock_ta.account_id = mock_account.id
-            setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta)
+            setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account)

     @staticmethod
     def setup_dataset_auth(mock_db, mock_tenant, mock_account):
@@ -133,8 +134,7 @@ class AuthenticationMocker:
         mock_ta = Mock()
         mock_ta.account_id = mock_account.id

-        mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta)
-
+        setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta)
         mock_db.session.get.return_value = mock_account
diff --git a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py
index f33c482d04..362af883ed 100644
--- a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py
+++ b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py
@@ -23,6 +23,7 @@ from datetime import UTC, datetime
 from unittest.mock import Mock, patch

 import pytest
+from flask import Flask
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import Forbidden, NotFound

@@ -373,7 +374,7 @@ class TestDatasourcePluginsApiGet:

     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
-    def test_get_plugins_success(self, mock_svc_cls, mock_db, app):
+    def test_get_plugins_success(self, mock_svc_cls, mock_db, app: Flask):
         """Test successful retrieval of datasource plugins."""
         tenant_id = str(uuid.uuid4())
         dataset_id = str(uuid.uuid4())
@@ -396,7 +397,7 @@ class TestDatasourcePluginsApiGet:
         )

     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
-    def test_get_plugins_not_found(self, mock_db, app):
+    def test_get_plugins_not_found(self, mock_db, app: Flask):
         """Test NotFound when dataset check fails."""
         mock_db.session.scalar.return_value = None

@@ -407,7 +408,7 @@ class TestDatasourcePluginsApiGet:

     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
-    def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app):
+    def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app: Flask):
         """Test empty plugin list."""
         mock_db.session.scalar.return_value = Mock()
         mock_svc_instance = Mock()
@@ -439,7 +440,7 @@ class TestDatasourceNodeRunApiPost:
     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
-    def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app):
+    def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app: Flask):
         """Test successful datasource node run."""
         tenant_id = str(uuid.uuid4())
         dataset_id = str(uuid.uuid4())
@@ -473,7 +474,7 @@ class TestDatasourceNodeRunApiPost:
         mock_svc_instance.run_datasource_workflow_node.assert_called_once()

     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
-    def test_post_not_found(self, mock_db, app):
+    def test_post_not_found(self, mock_db, app: Flask):
         """Test NotFound when dataset check fails."""
         mock_db.session.scalar.return_value = None

@@ -488,7 +489,7 @@ class TestDatasourceNodeRunApiPost:
     )
     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
-    def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app):
+    def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app: Flask):
         """Test AssertionError when current_user is not an Account instance."""
         mock_db.session.scalar.return_value = Mock()
         mock_ns.payload = {
@@ -549,7 +550,7 @@ class TestPipelineRunApiPost:
         mock_gen_svc.generate.assert_called_once()

     @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
-    def test_post_not_found(self, mock_db, app):
+    def test_post_not_found(self, mock_db, app: Flask):
         """Test NotFound when dataset check fails."""
         mock_db.session.scalar.return_value = None

@@ -561,7 +562,7 @@ class TestPipelineRunApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user", new="not_account") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns") - def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app): + def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app: Flask): """Test Forbidden when current_user is not an Account.""" mock_db.session.scalar.return_value = Mock() mock_ns.payload = { @@ -585,7 +586,7 @@ class TestFileUploadApiPost: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.FileService") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app): + def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app: Flask): """Test successful file upload.""" mock_current_user.__bool__ = Mock(return_value=True) @@ -621,7 +622,7 @@ class TestFileUploadApiPost: assert response["name"] == "doc.pdf" assert response["extension"] == "pdf" - def test_upload_no_file(self, app): + def test_upload_no_file(self, app: Flask): """Test error when no file is uploaded.""" with app.test_request_context( "/datasets/pipeline/file-upload", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index e9c3e6d376..fe8fc02548 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -18,6 +18,7 @@ import uuid from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.service_api.dataset.segment import ( @@ -782,7 +783,7 @@ class TestSegmentApiGet: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -893,7 +894,7 @@ class TestSegmentApiPost: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -946,7 +947,7 @@ class TestSegmentApiPost: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -989,7 +990,7 @@ class TestSegmentApiPost: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1041,7 +1042,7 @@ class TestDatasetSegmentApiDelete: mock_doc_svc, mock_dataset_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1086,7 +1087,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1128,7 +1129,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_doc_svc, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1162,7 +1163,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1232,7 +1233,7 @@ class TestDatasetSegmentApiUpdate: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1282,7 +1283,7 @@ class TestDatasetSegmentApiUpdate: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): 
@@ -1322,7 +1323,7 @@ class TestDatasetSegmentApiUpdate: mock_dataset_svc, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1374,7 +1375,7 @@ class TestDatasetSegmentApiGetSingle: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1421,7 +1422,7 @@ class TestDatasetSegmentApiGetSingle: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1460,7 +1461,7 @@ class TestDatasetSegmentApiGetSingle: self, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1491,7 +1492,7 @@ class TestDatasetSegmentApiGetSingle: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1526,7 +1527,7 @@ class TestDatasetSegmentApiGetSingle: mock_dataset_svc, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1570,7 +1571,7 @@ class TestChildChunkApiGet: mock_doc_svc, mock_seg_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1609,7 +1610,7 @@ class TestChildChunkApiGet: self, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1638,7 +1639,7 @@ class TestChildChunkApiGet: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1670,7 +1671,7 @@ class TestChildChunkApiGet: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1729,7 +1730,7 @@ class TestChildChunkApiPost: mock_doc_svc, mock_seg_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1771,7 +1772,7 @@ class TestChildChunkApiPost: mock_feature_svc, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1809,7 +1810,7 @@ class TestChildChunkApiPost: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1863,7 +1864,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1913,7 +1914,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1954,7 +1955,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1994,7 +1995,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 12d5e7345d..230c51161f 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -22,6 +22,9 @@ import pytest from werkzeug.exceptions import Forbidden, NotFound from controllers.service_api.dataset.document import ( + DeprecatedDocumentAddByTextApi, + DeprecatedDocumentUpdateByFileApi, + DeprecatedDocumentUpdateByTextApi, DocumentAddByFileApi, DocumentAddByTextApi, DocumentApi, @@ -30,7 +33,6 @@ from controllers.service_api.dataset.document import ( DocumentListQuery, DocumentTextCreatePayload, DocumentTextUpdate, - DocumentUpdateByFileApi, DocumentUpdateByTextApi, InvalidMetadataError, ) @@ -699,8 +701,8 @@ class TestDocumentApiDelete: ``delete`` is wrapped by ``@cloud_edition_billing_rate_limit_check`` 
which internally calls ``validate_and_get_api_token``. To bypass the decorator we call the original function via ``__wrapped__`` (preserved by - ``functools.wraps``). ``delete`` queries the dataset via - ``db.session.query(Dataset)`` directly, so we patch ``db`` at the + ``functools.wraps``). ``delete`` loads the dataset via + ``db.session.scalar(select(Dataset)...)``, so we patch ``db`` at the controller module. """ @@ -1005,7 +1007,7 @@ class TestDocumentAddByTextApi: # Act with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={ "name": "Test Document", @@ -1037,7 +1039,7 @@ class TestDocumentAddByTextApi: # Act & Assert with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={"name": "Test Document", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1066,7 +1068,7 @@ class TestDocumentAddByTextApi: # Act & Assert with app.test_request_context( - f"/datasets/{mock_dataset.id}/document/create_by_text", + f"/datasets/{mock_dataset.id}/document/create-by-text", method="POST", json={"name": "Test Document", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1093,9 +1095,28 @@ class TestArchivedDocumentImmutableError: assert error.code == 403 +class TestDocumentRouteDeprecation: + """Test that legacy document routes stay marked deprecated.""" + + def test_create_by_text_legacy_alias_is_deprecated(self): + """Ensure only the legacy create-by-text alias is marked deprecated.""" + assert DeprecatedDocumentAddByTextApi.post.__apidoc__["deprecated"] is True + assert DocumentAddByTextApi.post.__apidoc__.get("deprecated") is not True + + def test_update_by_text_legacy_alias_is_deprecated(self): + """Ensure only the legacy update-by-text alias is marked deprecated.""" + assert DeprecatedDocumentUpdateByTextApi.post.__apidoc__["deprecated"] is True + assert DocumentUpdateByTextApi.post.__apidoc__.get("deprecated") is not True + + def test_update_by_file_legacy_aliases_are_deprecated(self): + """Ensure only the legacy file-update aliases are marked deprecated.""" + assert DeprecatedDocumentUpdateByFileApi.post.__apidoc__["deprecated"] is True + assert DocumentApi.patch.__apidoc__.get("deprecated") is not True + + # ============================================================================= # Endpoint tests for DocumentUpdateByTextApi, DocumentAddByFileApi, -# DocumentUpdateByFileApi. +# and the canonical/deprecated document file update routes. 
# # These controllers use ``@cloud_edition_billing_resource_check`` (does NOT # preserve ``__wrapped__``) and ``@cloud_edition_billing_rate_limit_check`` @@ -1162,7 +1183,7 @@ class TestDocumentUpdateByTextApiPost: doc_id = str(uuid.uuid4()) with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text", + f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text", method="POST", json={"name": "Updated Doc", "text": "New content"}, headers={"Authorization": "Bearer test_token"}, @@ -1195,7 +1216,7 @@ class TestDocumentUpdateByTextApiPost: doc_id = str(uuid.uuid4()) with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_text", + f"/datasets/{mock_dataset.id}/documents/{doc_id}/update-by-text", method="POST", json={"name": "Doc", "text": "Content"}, headers={"Authorization": "Bearer test_token"}, @@ -1343,13 +1364,52 @@ class TestDocumentAddByFileApiPost: api.post(tenant_id=mock_tenant.id, dataset_id=mock_dataset.id) -class TestDocumentUpdateByFileApiPost: - """Test suite for DocumentUpdateByFileApi.post() endpoint. +class TestDocumentUpdateByFileApiPatch: + """Test suite for the canonical document file update endpoint. - ``post`` is wrapped by ``@cloud_edition_billing_resource_check`` and + ``patch`` is wrapped by ``@cloud_edition_billing_resource_check`` and ``@cloud_edition_billing_rate_limit_check``. """ + @pytest.mark.parametrize("route_name", ["update_by_file", "update-by-file"]) + @patch("controllers.service_api.dataset.document._update_document_by_file") + @patch("controllers.service_api.wraps.FeatureService") + @patch("controllers.service_api.wraps.validate_and_get_api_token") + def test_update_by_file_deprecated_aliases_delegate_to_shared_handler( + self, + mock_validate_token, + mock_feature_svc, + mock_update_document_by_file, + route_name, + app, + mock_tenant, + mock_dataset, + ): + """Test legacy POST aliases still dispatch while marked deprecated.""" + _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) + mock_update_document_by_file.return_value = ({"document": {"id": "doc-1"}, "batch": "batch-1"}, 200) + + doc_id = str(uuid.uuid4()) + with app.test_request_context( + f"/datasets/{mock_dataset.id}/documents/{doc_id}/{route_name}", + method="POST", + headers={"Authorization": "Bearer test_token"}, + ): + api = DeprecatedDocumentUpdateByFileApi() + response, status = api.post( + tenant_id=mock_tenant.id, + dataset_id=mock_dataset.id, + document_id=doc_id, + ) + + assert status == 200 + assert response["batch"] == "batch-1" + mock_update_document_by_file.assert_called_once_with( + tenant_id=mock_tenant.id, + dataset_id=mock_dataset.id, + document_id=doc_id, + ) + @patch("controllers.service_api.dataset.document.db") @patch("controllers.service_api.wraps.FeatureService") @patch("controllers.service_api.wraps.validate_and_get_api_token") @@ -1371,15 +1431,15 @@ class TestDocumentUpdateByFileApiPost: doc_id = str(uuid.uuid4()) data = {"file": (BytesIO(b"content"), "test.pdf", "application/pdf")} with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file", - method="POST", + f"/datasets/{mock_dataset.id}/documents/{doc_id}", + method="PATCH", content_type="multipart/form-data", data=data, headers={"Authorization": "Bearer test_token"}, ): - api = DocumentUpdateByFileApi() + api = DocumentApi() with pytest.raises(ValueError, match="Dataset does not exist"): - api.post( + api.patch( tenant_id=mock_tenant.id, dataset_id=mock_dataset.id, 
document_id=doc_id, @@ -1407,15 +1467,15 @@ class TestDocumentUpdateByFileApiPost: doc_id = str(uuid.uuid4()) data = {"file": (BytesIO(b"content"), "test.pdf", "application/pdf")} with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file", - method="POST", + f"/datasets/{mock_dataset.id}/documents/{doc_id}", + method="PATCH", content_type="multipart/form-data", data=data, headers={"Authorization": "Bearer test_token"}, ): - api = DocumentUpdateByFileApi() + api = DocumentApi() with pytest.raises(ValueError, match="External datasets"): - api.post( + api.patch( tenant_id=mock_tenant.id, dataset_id=mock_dataset.id, document_id=doc_id, @@ -1466,14 +1526,14 @@ class TestDocumentUpdateByFileApiPost: doc_id = str(uuid.uuid4()) data = {"file": (BytesIO(b"file content"), "test.pdf", "application/pdf")} with app.test_request_context( - f"/datasets/{mock_dataset.id}/documents/{doc_id}/update_by_file", - method="POST", + f"/datasets/{mock_dataset.id}/documents/{doc_id}", + method="PATCH", content_type="multipart/form-data", data=data, headers={"Authorization": "Bearer test_token"}, ): - api = DocumentUpdateByFileApi() - response, status = api.post( + api = DocumentApi() + response, status = api.patch( tenant_id=mock_tenant.id, dataset_id=mock_dataset.id, document_id=doc_id, diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index 95c2f5cf92..a26cdf6563 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -171,6 +171,113 @@ class TestHitTestingApiPost: assert passed_retrieval_model["search_method"] == "semantic_search" assert passed_retrieval_model["top_k"] == 10 + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") + @patch("controllers.console.datasets.hit_testing_base.marshal") + @patch("controllers.console.datasets.hit_testing_base.HitTestingService") + @patch("controllers.console.datasets.hit_testing_base.DatasetService") + @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) + def test_post_preserves_retrieval_model_metadata_filtering_conditions( + self, + mock_current_user, + mock_dataset_svc, + mock_hit_svc, + mock_marshal, + mock_ns, + app, + ): + """Service API retrieval payload should not drop metadata filters.""" + dataset_id = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + mock_dataset = Mock() + mock_dataset.id = dataset_id + + mock_dataset_svc.get_dataset.return_value = mock_dataset + mock_dataset_svc.check_dataset_permission.return_value = None + mock_hit_svc.retrieve.return_value = {"query": "filtered query", "records": []} + mock_hit_svc.hit_testing_args_check.return_value = None + mock_marshal.return_value = [] + + metadata_filtering_conditions = { + "logical_operator": "and", + "conditions": [ + { + "name": "category", + "comparison_operator": "is", + "value": "finance", + } + ], + } + mock_ns.payload = { + "query": "filtered query", + "retrieval_model": { + "search_method": "semantic_search", + "reranking_enable": False, + "score_threshold_enabled": False, + "top_k": 4, + "metadata_filtering_conditions": metadata_filtering_conditions, + }, + } + + with app.test_request_context(): + api = HitTestingApi() + HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) + + passed_retrieval_model = 
mock_hit_svc.retrieve.call_args.kwargs.get("retrieval_model") + assert passed_retrieval_model is not None + assert passed_retrieval_model["metadata_filtering_conditions"] == metadata_filtering_conditions + + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") + @patch("controllers.console.datasets.hit_testing_base.marshal") + @patch("controllers.console.datasets.hit_testing_base.HitTestingService") + @patch("controllers.console.datasets.hit_testing_base.DatasetService") + @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) + def test_post_normalizes_legacy_query_and_nullable_list_fields( + self, + mock_current_user, + mock_dataset_svc, + mock_hit_svc, + mock_marshal, + mock_ns, + app, + ): + """Test service API normalizes legacy query shape and nullable list fields.""" + dataset_id = str(uuid.uuid4()) + tenant_id = str(uuid.uuid4()) + + mock_dataset = Mock() + mock_dataset.id = dataset_id + + mock_dataset_svc.get_dataset.return_value = mock_dataset + mock_dataset_svc.check_dataset_permission.return_value = None + + mock_hit_svc.retrieve.return_value = {"query": {"content": "legacy query"}, "records": ["placeholder"]} + mock_hit_svc.hit_testing_args_check.return_value = None + mock_marshal.return_value = [ + { + "segment": {"id": "segment-1", "keywords": None}, + "child_chunks": None, + "files": None, + "score": 0.9, + } + ] + + mock_ns.payload = {"query": "legacy query"} + + with app.test_request_context(): + api = HitTestingApi() + response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) + + assert response["query"] == "legacy query" + assert response["records"] == [ + { + "segment": {"id": "segment-1", "keywords": []}, + "child_chunks": [], + "files": [], + "score": 0.9, + } + ] + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.DatasetService") @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py index b93a1cf14b..b7e24f9201 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py @@ -19,6 +19,7 @@ import uuid from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.service_api.dataset.metadata import ( @@ -76,7 +77,7 @@ class TestDatasetMetadataCreatePost: mock_dataset_svc, mock_meta_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -106,7 +107,7 @@ class TestDatasetMetadataCreatePost: def test_create_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -136,7 +137,7 @@ class TestDatasetMetadataCreateGet: self, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -160,7 +161,7 @@ class TestDatasetMetadataCreateGet: def test_get_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -201,7 +202,7 @@ class TestDatasetMetadataServiceApiPatch: mock_dataset_svc, mock_meta_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -232,7 +233,7 @@ class TestDatasetMetadataServiceApiPatch: def test_update_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, 
mock_tenant, mock_dataset, ): @@ -273,7 +274,7 @@ class TestDatasetMetadataServiceApiDelete: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -302,7 +303,7 @@ class TestDatasetMetadataServiceApiDelete: def test_delete_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -336,7 +337,7 @@ class TestDatasetMetadataBuiltInFieldGet: def test_get_built_in_fields_success( self, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -382,7 +383,7 @@ class TestDatasetMetadataBuiltInFieldAction: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -414,7 +415,7 @@ class TestDatasetMetadataBuiltInFieldAction: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -441,7 +442,7 @@ class TestDatasetMetadataBuiltInFieldAction: def test_action_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -485,7 +486,7 @@ class TestDocumentMetadataEditPost: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -513,7 +514,7 @@ class TestDocumentMetadataEditPost: def test_update_documents_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): diff --git a/api/tests/unit_tests/controllers/service_api/test_index.py b/api/tests/unit_tests/controllers/service_api/test_index.py index c560a3c698..8441118181 100644 --- a/api/tests/unit_tests/controllers/service_api/test_index.py +++ b/api/tests/unit_tests/controllers/service_api/test_index.py @@ -5,6 +5,7 @@ Unit tests for Service API Index endpoint from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.service_api.index import IndexApi @@ -13,7 +14,7 @@ class TestIndexApi: """Test suite for IndexApi resource.""" @patch("controllers.service_api.index.dify_config", autospec=True) - def test_get_returns_api_info(self, mock_config, app): + def test_get_returns_api_info(self, mock_config, app: Flask): """Test that GET returns API metadata with correct structure.""" # Arrange mock_config.project.version = "1.0.0-test" @@ -32,7 +33,7 @@ class TestIndexApi: assert response["api_version"] == "v1" assert response["server_version"] == "1.0.0-test" - def test_get_response_has_required_fields(self, app): + def test_get_response_has_required_fields(self, app: Flask): """Test that response contains all required fields.""" # Arrange mock_config = MagicMock() diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py deleted file mode 100644 index c0b40d070a..0000000000 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -Unit tests for Service API Site controller -""" - -import uuid -from unittest.mock import Mock, patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.service_api.app.site import AppSiteApi -from models.account import TenantStatus -from models.model import App, Site -from tests.unit_tests.conftest import setup_mock_tenant_account_query - - -class TestAppSiteApi: - """Test suite for AppSiteApi""" - - @pytest.fixture - def mock_app_model(self): - """Create a mock App model with tenant.""" - app = Mock(spec=App) - app.id = str(uuid.uuid4()) - app.tenant_id = str(uuid.uuid4()) - app.status = "normal" - app.enable_api 
= True - - mock_tenant = Mock() - mock_tenant.id = app.tenant_id - mock_tenant.status = TenantStatus.NORMAL - app.tenant = mock_tenant - - return app - - @pytest.fixture - def mock_site(self): - """Create a mock Site model.""" - site = Mock(spec=Site) - site.id = str(uuid.uuid4()) - site.app_id = str(uuid.uuid4()) - site.title = "Test Site" - site.icon = "icon-url" - site.icon_background = "#ffffff" - site.description = "Site description" - site.copyright = "Copyright 2024" - site.privacy_policy = "Privacy policy text" - site.custom_disclaimer = "Custom disclaimer" - site.default_language = "en-US" - site.prompt_public = True - site.show_workflow_steps = True - site.use_icon_as_answer_icon = False - site.chat_color_theme = "light" - site.chat_color_theme_inverted = False - site.icon_type = "image" - site.created_at = "2024-01-01T00:00:00" - site.updated_at = "2024-01-01T00:00:00" - return site - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_success( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test successful retrieval of site configuration.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - # Mock wraps.db for authentication - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site.db for site query - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - response = api.get() - - # Assert - assert response["title"] == "Test Site" - assert response["icon"] == "icon-url" - assert response["description"] == "Site description" - mock_db.session.scalar.assert_called_once() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_not_found( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - ): - """Test that Forbidden is raised when site is not found.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, 
mock_tenant, mock_account) - - # Mock site query to return None - mock_db.session.scalar.return_value = None - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_tenant_archived( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test that Forbidden is raised when tenant is archived.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query - mock_db.session.scalar.return_value = mock_site - - # Set tenant status to archived AFTER authentication - mock_app_model.tenant.status = TenantStatus.ARCHIVE - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_queries_by_app_id( - self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model - ): - """Test that site is queried using the app model's id.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - mock_site = Mock(spec=Site) - mock_site.id = str(uuid.uuid4()) - mock_site.app_id = mock_app_model.id - mock_site.title = "Test Site" - mock_site.icon = "icon-url" - mock_site.icon_background = "#ffffff" - mock_site.description = "Site description" - mock_site.copyright = "Copyright 2024" - mock_site.privacy_policy = "Privacy policy text" - mock_site.custom_disclaimer = "Custom disclaimer" - mock_site.default_language = "en-US" - mock_site.prompt_public = True - mock_site.show_workflow_steps = True - mock_site.use_icon_as_answer_icon = False - mock_site.chat_color_theme = "light" - mock_site.chat_color_theme_inverted = False - mock_site.icon_type = "image" - mock_site.created_at = "2024-01-01T00:00:00" - 
mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - api.get() - - # Assert - # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index a2008e024b..30d7b92913 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -24,8 +24,8 @@ from enums.cloud_plan import CloudPlan from models.account import TenantStatus from models.model import ApiToken from tests.unit_tests.conftest import ( - setup_mock_dataset_tenant_query, - setup_mock_tenant_account_query, + setup_mock_dataset_owner_execute_result, + setup_mock_tenant_owner_execute_result, ) @@ -39,7 +39,7 @@ class TestValidateAndGetApiToken: app.config["TESTING"] = True return app - def test_missing_authorization_header(self, app): + def test_missing_authorization_header(self, app: Flask): """Test that Unauthorized is raised when Authorization header is missing.""" # Arrange with app.test_request_context("/", method="GET"): @@ -50,7 +50,7 @@ class TestValidateAndGetApiToken: validate_and_get_api_token("app") assert "Authorization header must be provided" in str(exc_info.value) - def test_invalid_auth_scheme(self, app): + def test_invalid_auth_scheme(self, app: Flask): """Test that Unauthorized is raised when auth scheme is not Bearer.""" # Arrange with app.test_request_context("/", method="GET", headers={"Authorization": "Basic token123"}): @@ -62,7 +62,7 @@ class TestValidateAndGetApiToken: @patch("controllers.service_api.wraps.record_token_usage") @patch("controllers.service_api.wraps.ApiTokenCache") @patch("controllers.service_api.wraps.fetch_token_with_single_flight") - def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app): + def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask): """Test that valid token returns the ApiToken object.""" # Arrange mock_api_token = Mock(spec=ApiToken) @@ -84,7 +84,7 @@ class TestValidateAndGetApiToken: @patch("controllers.service_api.wraps.record_token_usage") @patch("controllers.service_api.wraps.ApiTokenCache") @patch("controllers.service_api.wraps.fetch_token_with_single_flight") - def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app): + def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask): """Test that invalid token raises Unauthorized.""" # Arrange from werkzeug.exceptions import Unauthorized @@ -141,14 +141,11 @@ class TestValidateAppToken: mock_account = Mock() mock_account.id = str(uuid.uuid4()) - mock_ta = Mock() - mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant via session.get() mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant owner query (execute(select(...)).one_or_none()) - setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) + # Mock the tenant owner execute result (execute(select(...)).one_or_none()) + setup_mock_tenant_owner_execute_result(mock_db, mock_tenant, mock_account) @validate_app_token def 
protected_view(app_model): @@ -164,7 +161,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app no longer exists.""" # Arrange mock_api_token = Mock() @@ -185,7 +182,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app status is abnormal.""" # Arrange mock_api_token = Mock() @@ -208,7 +205,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app API is disabled.""" # Arrange mock_api_token = Mock() @@ -243,7 +240,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app): + def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app: Flask): """Test that request is allowed when under resource limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -267,7 +264,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app): + def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app: Flask): """Test that Forbidden is raised when at resource limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -290,7 +287,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app): + def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app: Flask): """Test that request is allowed when billing is disabled.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -323,7 +320,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app): + def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask): """Test that add_segment is rejected in SANDBOX plan.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -345,7 +342,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") 
@patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app): + def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask): """Test that non-add_segment operations are allowed in SANDBOX.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -379,7 +376,7 @@ class TestCloudEditionBillingRateLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit") - def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app): + def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app: Flask): """Test that request is allowed when within rate limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -409,7 +406,7 @@ class TestCloudEditionBillingRateLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit") @patch("controllers.service_api.wraps.db") - def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app): + def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app: Flask): """Test that Forbidden is raised when over rate limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -448,7 +445,7 @@ class TestValidateDatasetToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.current_app") - def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app): + def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app: Flask): """Test that valid dataset token allows access.""" # Arrange # Use standard Mock for login_manager @@ -471,7 +468,7 @@ class TestValidateDatasetToken: mock_account.current_tenant = mock_tenant # Mock the tenant account join query (execute(select(...)).one_or_none()) - setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) + setup_mock_dataset_owner_execute_result(mock_db, mock_tenant, mock_ta) # Mock the account lookup via session.get() mock_db.session.get.return_value = mock_account @@ -490,7 +487,7 @@ class TestValidateDatasetToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app): + def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app: Flask): """Test that NotFound is raised when dataset doesn't exist.""" # Arrange mock_api_token = Mock() diff --git a/api/tests/unit_tests/controllers/web/conftest.py b/api/tests/unit_tests/controllers/web/conftest.py index 274d78c9cf..b7f3244c6c 100644 --- a/api/tests/unit_tests/controllers/web/conftest.py +++ b/api/tests/unit_tests/controllers/web/conftest.py @@ -22,18 +22,16 @@ class FakeSession: def __init__(self, mapping: dict[str, Any] | None = None): self._mapping: dict[str, Any] = mapping or {} - self._model_name: str | None = None - def query(self, model: type) -> FakeSession: - self._model_name = model.__name__ - return self + def get(self, model: type, _ident: object) -> Any: + return self._mapping.get(model.__name__) - def 
where(self, *_args: object, **_kwargs: object) -> FakeSession: - return self - - def first(self) -> Any: - assert self._model_name is not None - return self._mapping.get(self._model_name) + def scalar(self, stmt: Any) -> Any: + try: + model = stmt.column_descriptions[0]["entity"] + except (AttributeError, IndexError, KeyError, TypeError): + return None + return self._mapping.get(model.__name__) class FakeDB: diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index cbfc8fa613..a6ca441801 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.audio import AudioApi, TextApi from controllers.web.error import ( @@ -22,6 +21,7 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index 49039d03fe..4f8d848637 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -7,7 +7,6 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from controllers.web.error import ( @@ -19,6 +18,7 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index a1dbc80b20..5f2dc19aab 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -36,18 +36,6 @@ class _FakeSession: def __init__(self, mapping: dict[str, Any]): self._mapping = mapping - self._model_name: str | None = None - - def query(self, model): - self._model_name = model.__name__ - return self - - def where(self, *args, **kwargs): - return self - - def first(self): - assert self._model_name is not None - return self._mapping.get(self._model_name) def get(self, model, ident): return self._mapping.get(model.__name__) diff --git a/api/tests/unit_tests/controllers/web/test_message_endpoints.py b/api/tests/unit_tests/controllers/web/test_message_endpoints.py index 89ab93d8d4..da88b109a8 100644 --- a/api/tests/unit_tests/controllers/web/test_message_endpoints.py +++ b/api/tests/unit_tests/controllers/web/test_message_endpoints.py @@ -129,12 +129,6 @@ class TestMessageSuggestedQuestionApi: with pytest.raises(NotChatAppError): MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) - def test_wrong_mode_raises(self, app: Flask) -> None: - msg_id = uuid4() - with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): - with 
pytest.raises(NotChatAppError): - MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) - @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer") def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None: msg_id = uuid4() diff --git a/api/tests/unit_tests/controllers/web/test_pydantic_models.py b/api/tests/unit_tests/controllers/web/test_pydantic_models.py index dcf8133712..bceb65b89f 100644 --- a/api/tests/unit_tests/controllers/web/test_pydantic_models.py +++ b/api/tests/unit_tests/controllers/web/test_pydantic_models.py @@ -198,7 +198,7 @@ class TestMessageListQuery: assert q.limit == 20 def test_invalid_conversation_id(self) -> None: - with pytest.raises(ValidationError, match="not a valid uuid"): + with pytest.raises(ValidationError, match="must be a valid UUID"): MessageListQuery(conversation_id="bad") def test_limit_bounds(self) -> None: @@ -216,7 +216,7 @@ class TestMessageListQuery: def test_invalid_first_id(self) -> None: cid = str(uuid4()) - with pytest.raises(ValidationError, match="not a valid uuid"): + with pytest.raises(ValidationError, match="must be a valid UUID"): MessageListQuery(conversation_id=cid, first_id="invalid") diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index 0661c02578..13b953c04d 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -4,9 +4,12 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from jwt import InvalidTokenError +from werkzeug.exceptions import Unauthorized import services.errors.account from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi +from services.entities.auth_entities import LoginFailureReason def encode_code(code: str) -> str: @@ -31,7 +34,6 @@ def _patch_wraps(): patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), patch("controllers.web.login.dify_config", web_dify), ): - mock_db.session.query.return_value.first.return_value = MagicMock() yield @@ -115,13 +117,18 @@ class TestLoginApi: def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None: from controllers.console.error import AccountBannedError - with app.test_request_context( - "/web/login", - method="POST", - json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, - ): - with pytest.raises(AccountBannedError): - LoginApi().post() + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AccountBannedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED @patch( "controllers.web.login.WebAppAuthService.authenticate", @@ -130,13 +137,87 @@ class TestLoginApi: def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None: from controllers.console.auth.error import AuthenticationFailedError - with app.test_request_context( - "/web/login", - method="POST", - json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, - ): - with 
pytest.raises(AuthenticationFailedError): - LoginApi().post() + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AuthenticationFailedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountNotFoundError(), + ) + def test_login_account_not_found(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.auth.error import AuthenticationFailedError + + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "missing@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AuthenticationFailedError): + LoginApi().post() + + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "missing@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_NOT_FOUND + + @patch("controllers.web.login.WebAppAuthService.get_email_code_login_data", return_value=None) + def test_email_code_login_logs_invalid_token(self, mock_get_token_data: MagicMock, app: Flask) -> None: + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "user@example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(InvalidTokenError): + EmailCodeLoginApi().post() + + mock_get_token_data.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_EMAIL_CODE_TOKEN + + @patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token") + @patch( + "controllers.web.login.WebAppAuthService.get_user_through_email", + side_effect=Unauthorized("Account is banned."), + ) + @patch( + "controllers.web.login.WebAppAuthService.get_email_code_login_data", + return_value={"email": "User@Example.com", "code": "123456"}, + ) + def test_email_code_login_logs_banned_account( + self, + mock_get_token_data: MagicMock, + mock_get_user: MagicMock, + mock_revoke_token: MagicMock, + app: Flask, + ) -> None: + from controllers.console.error import AccountBannedError + + with patch("controllers.web.login.logger.warning") as mock_log_warning: + with app.test_request_context( + "/web/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(AccountBannedError): + EmailCodeLoginApi().post() + + mock_get_token_data.assert_called_once_with("token-123") + mock_revoke_token.assert_called_once_with("token-123") + assert mock_log_warning.call_count == 1 + assert mock_log_warning.call_args.args[1] == "user@example.com" + assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED class TestLoginStatusApi: diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index 
bc7aea0ef9..cde8820e00 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -2,11 +2,11 @@ import json from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMUsage from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError +from graphon.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index 97206019b9..ea8cc8aa86 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -1,9 +1,9 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index defc8b4b64..2f5873d865 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -1,6 +1,8 @@ import json import pytest + +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -8,8 +10,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner - # ----------------------------- # Fixtures # ----------------------------- diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index a44a0650eb..17ab5babcb 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -3,6 +3,11 @@ from typing import Any from unittest.mock import MagicMock import pytest + +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, @@ -11,11 +16,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.agent.errors import AgentMaxIterationError -from core.agent.fc_agent_runner import FunctionCallAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueMessageFileEvent - # ============================== # Dummy Helper Classes # ============================== diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py index 5ee66da94a..186b4a501d 100644 --- 
a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.entities.model_entities import ModelStatus @@ -12,6 +10,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey class TestModelConfigConverter: diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py index e2f3c16335..d9fe7004ff 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -1,9 +1,9 @@ import pytest -from graphon.variables.input_entities import VariableEntityType from core.app.app_config.easy_ui_based_app.variables.manager import ( BasicVariablesConfigManager, ) +from graphon.variables.input_entities import VariableEntityType class TestBasicVariablesConfigManagerConvert: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 8bde9c1f97..11b53dd0f9 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,8 +1,7 @@ +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager - def test_convert_with_vision(): config = { diff --git a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py index dd00c3defc..0a0ffe657c 100644 --- a/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py +++ b/api/tests/unit_tests/core/app/app_config/features/test_additional_feature_managers.py @@ -77,6 +77,38 @@ class TestAdditionalFeatureManagers: SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( {"suggested_questions_after_answer": {"enabled": "yes"}} ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": True, "prompt": 123}} + ) + with pytest.raises(ValueError, match="must be less than or equal to 1000 characters"): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": True, "prompt": "a" * 1001}} + ) + with pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": True, "model": "bad"}} + ) + with 
pytest.raises(ValueError): + SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + {"suggested_questions_after_answer": {"enabled": True, "model": {"provider": "openai"}}} + ) + + validated_config, _ = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( + { + "suggested_questions_after_answer": { + "enabled": True, + "prompt": "custom prompt", + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "completion_params": {"max_tokens": 1024}, + }, + } + } + ) + assert validated_config["suggested_questions_after_answer"]["prompt"] == "custom prompt" + assert validated_config["suggested_questions_after_answer"]["model"]["name"] == "gpt-4o-mini" assert ( SuggestedQuestionsAfterAnswerConfigManager.convert({"suggested_questions_after_answer": {"enabled": True}}) diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py index 000f83cd5a..f2bc3076da 100644 --- a/api/tests/unit_tests/core/app/app_config/test_entities.py +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -1,10 +1,10 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.app_config.entities import ( DatasetRetrieveConfigEntity, PromptTemplateEntity, ) +from graphon.variables.input_entities import VariableEntity, VariableEntityType class TestAppConfigEntities: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 1fb0dc6cf1..370f7abb8b 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 -from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from factories import variable_factory +from graphon.variables import SegmentType from models import ConversationVariable, Workflow MINIMAL_GRAPH = { @@ -154,7 +154,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock GraphRuntimeState to accept the variable pool @@ -301,7 +300,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock ConversationVariable.from_variable to return mock objects @@ -453,7 +451,6 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session 
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_session_class.return_value.__enter__.return_value = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock() # App exists mock_db.engine = MagicMock() # Mock GraphRuntimeState to accept the variable pool diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index e9fdeefee4..6debeb4fdd 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,9 +1,10 @@ from collections.abc import Generator -from graphon.enums import WorkflowNodeExecutionStatus +import pytest from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, ErrorStreamResponse, @@ -12,6 +13,8 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: @@ -29,6 +32,37 @@ class TestAdvancedChatGenerateResponseConverter: response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) assert "usage" not in response["metadata"] + def test_blocking_full_response_derives_pause_data_from_model_dump(self, monkeypatch: pytest.MonkeyPatch): + data = AdvancedChatPausedBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + workflow_run_id="run-1", + answer="partial", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + paused_nodes=["node-1"], + reasons=[{"type": PauseReasonType.HUMAN_INPUT_REQUIRED, "form_id": "form-1"}], + status=WorkflowExecutionStatus.PAUSED, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ) + original_model_dump = type(data).model_dump + + def _model_dump_with_future_field(self, *args, **kwargs): + payload = original_model_dump(self, *args, **kwargs) + payload["future_field"] = "future-value" + return payload + + monkeypatch.setattr(type(data), "model_dump", _model_dump_with_future_field) + blocking = AdvancedChatPausedBlockingResponse(task_id="t1", data=data) + + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert response["data"]["future_field"] == "future-value" + def test_stream_simple_response_includes_node_events(self): node_start = NodeStartStreamResponse( task_id="t1", diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index a6d8598955..99a386cd45 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -6,8 +6,6 @@ from types import SimpleNamespace from unittest import mock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.enums import WorkflowExecutionStatus from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities 
import InvokeFrom @@ -19,6 +17,8 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import AppMode, EndUser diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 82b2e51019..64bcfa9a18 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -4,8 +4,6 @@ from contextlib import contextmanager from types import SimpleNamespace import pytest -from graphon.enums import BuiltinNodeTypes -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.advanced_chat.generate_task_pipeline import ( @@ -41,14 +39,20 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( + AdvancedChatPausedBlockingResponse, AnnotationReply, AnnotationReplyAccount, + HumanInputRequiredResponse, MessageAudioStreamResponse, MessageEndStreamResponse, PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables +from graphon.entities.pause_reason import PauseReasonType +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import UserAction +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser @@ -123,6 +127,57 @@ class TestAdvancedChatGenerateTaskPipeline: assert response.data.answer == "done" assert response.data.metadata == {"k": "v"} + def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self): + pipeline = _make_pipeline() + pipeline._task_state.answer = "partial answer" + pipeline._workflow_run_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + start_at=0.0, + total_tokens=7, + node_run_steps=3, + ) + + def _gen(): + yield HumanInputRequiredResponse( + task_id="task", + workflow_run_id="run-id", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Approval", + form_content="Need approval", + inputs=[], + actions=[UserAction(id="approve", title="Approve")], + display_in_ui=True, + form_token="token-1", + resolved_default_values={}, + expiration_time=123, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert isinstance(response, AdvancedChatPausedBlockingResponse) + assert response.data.workflow_run_id == "run-id" + assert response.data.status == "paused" + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [ + { + "type": PauseReasonType.HUMAN_INPUT_REQUIRED, + "form_id": "form-1", + "node_id": "node-1", + "node_title": "Approval", + "form_content": "Need approval", + "inputs": [], + "actions": [{"id": "approve", "title": "Approve", "button_style": "default"}], + 
"display_in_ui": True, + "form_token": "token-1", + "resolved_default_values": {}, + "expiration_time": 123, + } + ] + def test_handle_text_chunk_event_updates_state(self): pipeline = _make_pipeline() pipeline._message_cycle_manager = SimpleNamespace( diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 7dc4358150..80f7f94b1a 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -1,12 +1,12 @@ import contextlib import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 08250bc3b6..4567b35480 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -1,10 +1,10 @@ import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 68bcffb0e8..8f3c41701b 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -2,7 +2,6 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_runner import ChatAppRunner @@ -10,6 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index f255d2c7df..b3ea1a464f 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from graphon.file import FileTransferMethod, FileType -from graphon.model_runtime.entities.message_entities import 
ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index 4a94a2b4f1..201923e0e4 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,11 +1,11 @@ from types import SimpleNamespace import pytest -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.workflow.system_variables import build_system_variables from core.workflow.variable_pool_initializer import add_variables_to_pool +from graphon.runtime import GraphRuntimeState, VariablePool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 328cd12f12..dd6cd0e919 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,10 +1,9 @@ from collections.abc import Mapping, Sequence +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from graphon.variables.segments import ArrayFileSegment, FileSegment -from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter - class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Test class for WorkflowResponseConverter._fetch_files_from_variable_value method""" @@ -12,8 +11,8 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: def create_test_file(self, file_id: str = "test_file_1") -> File: """Create a test File object""" return File( - id=file_id, - type=FileType.DOCUMENT, + file_id=file_id, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", filename=f"{file_id}.txt", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index bc11bf4174..1bef6f69cd 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,13 +1,12 @@ from datetime import UTC, datetime from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent from core.workflow.system_variables import build_system_variables 
+from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index c9e146ff12..936ac37e55 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,11 +1,10 @@ from types import SimpleNamespace -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 0fde7565d2..b3c0eb74fa 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,8 +10,6 @@ from typing import Any from unittest.mock import Mock import pytest -from graphon.entities import WorkflowStartReason -from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -27,6 +25,8 @@ from core.app.entities.queue_entities import ( QueueNodeSucceededEvent, ) from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 619d66085a..aa2085177e 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index 96af9fbdee..f2e35f9900 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -3,13 +3,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from 
graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index 6cdcab29ab..cfe797aa76 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus - from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -12,6 +10,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py index 9a2dc38f74..c36edf48fc 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -375,7 +375,7 @@ def test_generate_success_returns_converted(generator, mocker): workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={}) session = MagicMock() - session.query.return_value.where.return_value.first.return_value = workflow + session.get.return_value = workflow mocker.patch.object(module.db, "session", session) queue_manager = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 4fe82efcb3..9db83f5531 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -1,5 +1,4 @@ import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult import core.app.apps.pipeline.pipeline_queue_manager as module from core.app.apps.base_app_queue_manager import PublishFrom @@ -14,6 +13,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, ) +from graphon.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index c8ae288e6f..603062a51c 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -22,11 +22,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from 
graphon.graph_events import GraphRunFailedEvent import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from graphon.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: @@ -132,11 +132,8 @@ def test_run_pipeline_not_found(mocker): app_generate_entity.single_iteration_run = None app_generate_entity.single_loop_run = None - query = MagicMock() - query.where.return_value.first.return_value = None - session = MagicMock() - session.query.return_value = query + session.get.side_effect = [None, None] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( @@ -157,11 +154,9 @@ def test_run_workflow_not_initialized(mocker): app_generate_entity = _build_app_generate_entity() pipeline = MagicMock(id="pipe") - query_pipeline = MagicMock() - query_pipeline.where.return_value.first.return_value = pipeline session = MagicMock() - session.query.return_value = query_pipeline + session.get.side_effect = [None, pipeline] mocker.patch.object(module.db, "session", session) runner = PipelineRunner( diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py new file mode 100644 index 0000000000..560652f8cb --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generate_response_converter.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from collections.abc import Generator + +from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.task_entities import ( + AppStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from graphon.enums import WorkflowExecutionStatus + + +class _DummyConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]): + blocking_full_calls: list[WorkflowAppBlockingResponse] = [] + blocking_simple_calls: list[WorkflowAppBlockingResponse] = [] + stream_full_calls: list[Generator[AppStreamResponse, None, None]] = [] + stream_simple_calls: list[Generator[AppStreamResponse, None, None]] = [] + + @classmethod + def reset(cls) -> None: + cls.blocking_full_calls = [] + cls.blocking_simple_calls = [] + cls.stream_full_calls = [] + cls.stream_simple_calls = [] + + @classmethod + def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: + cls.blocking_full_calls.append(blocking_response) + return {"kind": "blocking-full", "task_id": blocking_response.task_id} + + @classmethod + def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]: + cls.blocking_simple_calls.append(blocking_response) + return {"kind": "blocking-simple", "task_id": blocking_response.task_id} + + @classmethod + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + cls.stream_full_calls.append(stream_response) + yield {"kind": "stream-full"} + + @classmethod + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[dict | str, None, None]: + cls.stream_simple_calls.append(stream_response) + yield {"kind": "stream-simple"} + + +def _build_blocking_response() -> 
WorkflowAppBlockingResponse: + return WorkflowAppBlockingResponse( + task_id="task-1", + workflow_run_id="run-1", + data=WorkflowAppBlockingResponse.Data( + id="run-1", + workflow_id="workflow-1", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=0.1, + total_tokens=0, + total_steps=1, + created_at=1, + finished_at=2, + ), + ) + + +def _build_stream_response() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse( + workflow_run_id="run-1", + stream_response=PingStreamResponse(task_id="task-1"), + ) + + +def test_convert_routes_blocking_response_by_invoke_from() -> None: + _DummyConverter.reset() + blocking_response = _build_blocking_response() + + full_result = _DummyConverter.convert(blocking_response, InvokeFrom.SERVICE_API) + simple_result = _DummyConverter.convert(blocking_response, InvokeFrom.WEB_APP) + + assert full_result == {"kind": "blocking-full", "task_id": "task-1"} + assert simple_result == {"kind": "blocking-simple", "task_id": "task-1"} + assert _DummyConverter.blocking_full_calls == [blocking_response] + assert _DummyConverter.blocking_simple_calls == [blocking_response] + + +def test_convert_routes_stream_response_by_invoke_from() -> None: + _DummyConverter.reset() + + full_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.SERVICE_API)) + simple_result = list(_DummyConverter.convert(_build_stream_response(), InvokeFrom.WEB_APP)) + + assert full_result == [{"kind": "stream-full"}] + assert simple_result == [{"kind": "stream-simple"}] + assert len(_DummyConverter.stream_full_calls) == 1 + assert len(_DummyConverter.stream_simple_calls) == 1 diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 6167be3bbd..b0f8b423e1 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,7 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.apps.base_app_generator import BaseAppGenerator +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -476,9 +476,8 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): - from graphon.enums import BuiltinNodeTypes - from core.app.entities.app_invoke_entities import InvokeFrom + from graphon.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index 1dee7fdab6..17de39ca99 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -4,15 +4,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ 
-23,6 +14,15 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_message_generator.py b/api/tests/unit_tests/core/app/apps/test_message_generator.py index 25377e633e..90c9abf35c 100644 --- a/api/tests/unit_tests/core/app/apps/test_message_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_message_generator.py @@ -1,6 +1,7 @@ from unittest.mock import Mock, patch from core.app.apps.message_generator import MessageGenerator +from core.app.entities.task_entities import StreamEvent from models.model import AppMode @@ -23,7 +24,21 @@ class TestMessageGenerator: "core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}]) ) as mock_stream, ): - events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2)) + events = list( + MessageGenerator.retrieve_events( + AppMode.WORKFLOW, + "run-1", + idle_timeout=1, + ping_interval=2, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) + ) assert events == [{"event": "ping"}] - mock_stream.assert_called_once() + mock_stream.assert_called_once_with( + topic="topic", + idle_timeout=1, + ping_interval=2, + on_subscribe=None, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index a126bc85f7..6104b8d6ca 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -4,9 +4,14 @@ from types import ModuleType, SimpleNamespace from typing import Any import graphon.nodes.human_input.entities # noqa: F401 +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow import node_factory as node_factory_module +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.entities.base_node_data import BaseNodeData, RetryConfig -from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.entities.pause_reason import SchedulingPause from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus from graphon.graph import Graph @@ -25,12 +30,6 @@ from graphon.nodes.base.node import Node from graphon.nodes.end.entities import EndNodeData from graphon.nodes.start.entities import StartNodeData from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from 
core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -56,8 +55,21 @@ class _StubToolNode(Node[_StubToolNodeData]): def version(cls) -> str: return "1" - def init_node_data(self, data): - self._node_data = _StubToolNodeData.model_validate(data) + def __init__( + self, + node_id: str, + config: _StubToolNodeData, + *, + graph_init_params, + graph_runtime_state, + **_kwargs: Any, + ) -> None: + super().__init__( + node_id=node_id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) def _get_error_strategy(self): return self._node_data.error_strategy @@ -90,21 +102,14 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): - original_create_node = DifyNodeFactory.create_node + original_resolve_node_class = node_factory_module.resolve_workflow_node_class - def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) - node_data = typed_node_config["data"] - if node_data.type == BuiltinNodeTypes.TOOL: - return _StubToolNode( - id=str(typed_node_config["id"]), - config=typed_node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return original_create_node(self, typed_node_config) + def _patched_resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + if node_type == BuiltinNodeTypes.TOOL: + return _StubToolNode + return original_resolve_node_class(node_type=node_type, node_version=node_version) - mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) + mocker.patch.object(node_factory_module, "resolve_workflow_node_class", side_effect=_patched_resolve_node_class) def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py index a7714c56ce..58f0e47a4b 100644 --- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py +++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py @@ -88,6 +88,10 @@ def test_normalize_terminal_events_defaults(): } +def test_normalize_terminal_events_empty_values(): + assert _normalize_terminal_events([]) == set() + + def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): topic = FakeTopic() times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0] @@ -106,3 +110,21 @@ def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): assert next(generator) == StreamEvent.PING.value # next receive yields None -> ping interval triggers assert next(generator) == StreamEvent.PING.value + + +def test_stream_topic_events_can_continue_past_pause(): + topic = FakeTopic() + topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_PAUSED.value}).encode()) + topic.publish(json.dumps({"event": StreamEvent.WORKFLOW_FINISHED.value}).encode()) + + generator = stream_topic_events( + topic=topic, + idle_timeout=1.0, + terminal_events=[StreamEvent.WORKFLOW_FINISHED.value], + ) + + assert next(generator) == StreamEvent.PING.value + assert next(generator)["event"] == StreamEvent.WORKFLOW_PAUSED.value + assert next(generator)["event"] == StreamEvent.WORKFLOW_FINISHED.value + with pytest.raises(StopIteration): 
+ next(generator) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index de5bca161c..58c7bfa4bc 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -4,6 +4,23 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.system_variables import default_system_variables from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -24,23 +41,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.variables import StringVariable -from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.app.entities.queue_entities import ( - QueueAgentLogEvent, - QueueIterationCompletedEvent, - QueueLoopCompletedEvent, - QueueNodeExceptionEvent, - QueueNodeFailedEvent, - QueueNodeRetryEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowPausedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.system_variables import default_system_variables - class TestWorkflowBasedAppRunner: def test_resolve_user_from(self): diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index aa789d9ff3..10fb2271f4 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 9e30faecf2..620a153204 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,14 +4,14 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest -from graphon.entities.graph_config import NodeConfigDictAdapter -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import 
WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.workflow.system_variables import default_system_variables +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 8a717e1dcc..a3ab379b66 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,11 +3,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.entities import WorkflowStartReason -from graphon.entities.pause_reason import HumanInputRequired -from graphon.graph_events import GraphRunPausedEvent -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -16,6 +11,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.account import Account from models.human_input import RecipientType diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index b768e813bd..7dd7ffd727 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,7 +1,5 @@ from collections.abc import Generator -from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus - from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -11,6 +9,7 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 29df903aa8..1f6e7e12ef 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,15 +2,14 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock -from graphon.entities import WorkflowStartReason -from graphon.runtime import GraphRuntimeState - from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from 
core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent from core.workflow.system_variables import build_system_variables +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState from models.account import Account from models.model import AppMode from tests.workflow_test_utils import build_test_variable_pool diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index d91bb85aee..0bcc1029b0 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -5,8 +5,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -38,15 +36,18 @@ from core.app.entities.queue_entities import ( ) from core.app.entities.task_entities import ( ErrorStreamResponse, + HumanInputRequiredResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, PingStreamResponse, + WorkflowAppPausedBlockingResponse, WorkflowFinishStreamResponse, - WorkflowPauseStreamResponse, WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser @@ -91,27 +92,50 @@ def _make_pipeline(): class TestWorkflowGenerateTaskPipeline: - def test_to_blocking_response_handles_pause(self): + def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self): pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + start_at=0.0, + total_tokens=5, + node_run_steps=2, + ) def _gen(): - yield WorkflowPauseStreamResponse( + yield HumanInputRequiredResponse( task_id="task", - workflow_run_id="run", - data=WorkflowPauseStreamResponse.Data( - workflow_run_id="run", - status=WorkflowExecutionStatus.PAUSED, - outputs={}, - created_at=1, - elapsed_time=0.1, - total_tokens=0, - total_steps=0, + workflow_run_id="run-id", + data=HumanInputRequiredResponse.Data( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="content", + expiration_time=1, ), ) response = pipeline._to_blocking_response(_gen()) + assert isinstance(response, WorkflowAppPausedBlockingResponse) + assert response.workflow_run_id == "run-id" assert response.data.status == WorkflowExecutionStatus.PAUSED + assert response.data.created_at == 0 + assert response.data.paused_nodes == ["node-1"] + assert response.data.reasons == [ + { + "type": "human_input_required", + "form_id": "form-1", + "node_id": "node-1", + "node_title": 
"Human Input", + "form_content": "content", + "inputs": [], + "actions": [], + "display_in_ui": False, + "form_token": None, + "resolved_default_values": {}, + "expiration_time": 1, + } + ] def test_to_blocking_response_handles_finish(self): pipeline = _make_pipeline() diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py index 014a0cba72..7c79780641 100644 --- a/api/tests/unit_tests/core/app/entities/test_task_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -1,11 +1,10 @@ -from graphon.enums import WorkflowNodeExecutionStatus - from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, StreamEvent, ) +from graphon.enums import WorkflowNodeExecutionStatus class TestTaskEntities: diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index a78c1b428f..ba55e8f695 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,6 +1,9 @@ from collections.abc import Sequence from unittest.mock import Mock +from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.command_channels import CommandChannel from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent @@ -8,10 +11,6 @@ from graphon.node_events import NodeRunResult from graphon.runtime import ReadOnlyGraphRuntimeState from graphon.variables import StringVariable from graphon.variables.segments import Segment, StringSegment - -from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.workflow.system_variables import SystemVariableKey -from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from libs.datetime_utils import naive_utc_now diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035e64325b..539944d683 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,16 @@ from time import time from unittest.mock import Mock import pytest + +from core.app.app_config.entities import WorkflowUIBasedAppConfig +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.layers.pause_state_persist_layer import ( + PauseStatePersistenceLayer, + WorkflowResumptionContext, + _AdvancedChatAppGenerateEntityWrapper, + _WorkflowGenerateEntityWrapper, +) +from core.workflow.system_variables import SystemVariableKey from graphon.entities.pause_reason import SchedulingPause from graphon.graph_engine.entities.commands import GraphEngineCommand from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError @@ -15,16 +25,6 @@ from graphon.graph_events import ( ) from graphon.runtime import ReadOnlyVariablePool from graphon.variables.segments import Segment - 
-from core.app.app_config.entities import WorkflowUIBasedAppConfig -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity -from core.app.layers.pause_state_persist_layer import ( - PauseStatePersistenceLayer, - WorkflowResumptionContext, - _AdvancedChatAppGenerateEntityWrapper, - _WorkflowGenerateEntityWrapper, -) -from core.workflow.system_variables import SystemVariableKey from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py index 95931f4f8b..12d49be0f1 100644 --- a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -1,6 +1,5 @@ -from graphon.graph_events import GraphRunPausedEvent - from core.app.layers.suspend_layer import SuspendLayer +from graphon.graph_events import GraphRunPausedEvent class TestSuspendLayer: diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py index 7cf6eb4f31..1ac9a4d8c0 100644 --- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -1,8 +1,7 @@ from unittest.mock import Mock, patch -from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand - from core.app.layers.timeslice_layer import TimeSliceLayer +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import SchedulerCommand diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index aa9285789b..d3bd15b6f3 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -2,11 +2,10 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import Mock, patch -from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent -from graphon.runtime import VariablePool - from core.app.layers.trigger_post_layer import TriggerPostLayer from core.workflow.system_variables import build_system_variables +from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool from models.enums import WorkflowTriggerStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py index 58aa7d7478..c246f7b783 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.entities.queue_entities import QueueErrorEvent from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.errors.error import QuotaExceededError +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus diff --git 
a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 4aaa10a81a..1c1bf391d3 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,8 +2,6 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -28,6 +26,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index f22602a400..a20d89d807 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -5,9 +5,6 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest -from graphon.file import FileTransferMethod -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from core.app.app_config.entities import ( AppAdditionalFeatures, @@ -41,6 +38,9 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AudioTrunk +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 31b7313066..595d716666 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -17,11 +17,11 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.file import FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from graphon.file import FileTransferMethod, FileType from models.model import MessageFile, UploadFile diff --git 
a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py index 29df7eea86..21c761c579 100644 --- a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import patch -from graphon.model_runtime.entities.model_entities import ModelPropertyKey - from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager from core.app.app_config.entities import ModelConfigEntity +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from models.provider_ids import ModelProviderID diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index dc2d82ccd6..5c50cb78da 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -2,14 +2,14 @@ from datetime import UTC, datetime from unittest.mock import Mock import pytest -from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType -from graphon.node_events import NodeRunResult from core.app.workflow.layers.persistence import ( PersistenceWorkflowInfo, WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index 7be9d6ac1e..701863b927 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -8,13 +8,13 @@ from unittest.mock import MagicMock, patch from urllib.parse import parse_qs, urlparse import pytest -from graphon.file import File, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.app.file_access import DatabaseFileAccessController, FileAccessScope from core.app.workflow import file_runtime from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime from core.workflow.file_reference import build_file_reference +from graphon.file import File, FileTransferMethod, FileType from models import ToolFile, UploadFile @@ -26,8 +26,8 @@ def _build_file( extension: str | None = None, ) -> File: return File( - id="file-id", - type=FileType.IMAGE, + file_id="file-id", + file_type=FileType.IMAGE, transfer_method=transfer_method, reference=reference, remote_url=remote_url, @@ -351,7 +351,7 @@ def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.M assert runtime.multimodal_send_format == "url" - with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get: + with patch.object(file_runtime.graphon_ssrf_proxy, "get", return_value="response") as mock_get: assert runtime.http_get("http://example", follow_redirects=False) == "response" mock_get.assert_called_once_with("http://example", follow_redirects=False) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index 8497261d45..30a068f4c5 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ 
b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -1,15 +1,15 @@ from types import SimpleNamespace import pytest -from graphon.enums import BuiltinNodeTypes from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.workflow.node_factory import DifyNodeFactory +from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, id, config, graph_init_params, graph_runtime_state, **kwargs): - self.id = id + def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + self.id = node_id self.config = config self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py index a47d3db6f5..82552470a9 100644 --- a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -2,9 +2,8 @@ from __future__ import annotations from types import SimpleNamespace -from graphon.enums import BuiltinNodeTypes - from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerExtras: diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index d8a68f6d00..cacb4dd4fa 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -4,6 +4,10 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables from graphon.entities import WorkflowNodeExecution from graphon.entities.pause_reason import SchedulingPause from graphon.enums import ( @@ -29,10 +33,6 @@ from graphon.graph_events import ( from graphon.node_events import NodeRunResult from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity -from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.workflow.system_variables import SystemVariableKey, build_system_variables - class _RepoRecorder: def __init__(self) -> None: diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 5ff9774b52..7b433ab57b 100644 --- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -301,6 +301,7 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -308,8 +309,6 @@ class TestAppGeneratorTTSPublisher: TextPromptMessageContent, ) - from core.app.entities.queue_entities import 
QueueAgentMessageEvent - chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( @@ -337,11 +336,10 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from core.app.entities.queue_entities import QueueAgentMessageEvent from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta from graphon.model_runtime.entities.message_entities import AssistantPromptMessage - from core.app.entities.queue_entities import QueueAgentMessageEvent - chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index d338cadb77..deeac49bbc 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,15 +2,15 @@ import types from collections.abc import Generator import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.workflow.file_reference import parse_file_reference +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -430,7 +430,7 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", extension=".png", @@ -530,7 +530,7 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", extension=".pdf", diff --git a/api/tests/unit_tests/core/datasource/test_notion_provider.py b/api/tests/unit_tests/core/datasource/test_notion_provider.py index e4bd7d3bdf..d21b9e471b 100644 --- a/api/tests/unit_tests/core/datasource/test_notion_provider.py +++ b/api/tests/unit_tests/core/datasource/test_notion_provider.py @@ -775,9 +775,6 @@ class TestNotionExtractorLastEditedTime: "last_edited_time": "2024-11-27T18:00:00.000Z", } mock_request.return_value = mock_response - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query # Act extractor_page.update_last_edited_time(mock_document_model) @@ -863,9 +860,6 @@ class TestNotionExtractorIntegration: } mock_request.side_effect = [last_edited_response, block_response] - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query 
# Act documents = extractor.extract() @@ -919,10 +913,6 @@ class TestNotionExtractorIntegration: } mock_post.return_value = database_response - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter_by.return_value = mock_query - # Act documents = extractor.extract() diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index fbaf6d497d..0fca43cd0b 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest -from graphon.file import File, FileTransferMethod, FileType from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from graphon.file import File, FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index ff9fd0d8f3..ef8f360dbf 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -1,12 +1,11 @@ -from graphon.nodes.human_input.entities import FormInput, UserAction -from graphon.nodes.human_input.enums import FormInputType - from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputContent, HumanInputFormDefinition, HumanInputFormSubmissionData, ) +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index 2acd278a31..aeca2e3afd 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -8,9 +8,6 @@ drive provider mapping behavior. 
""" import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -19,6 +16,9 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: @@ -46,7 +46,7 @@ def test_simple_model_provider_entity_maps_from_provider_entity() -> None: # Assert assert simple_provider.provider == "openai" - assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.label.en_us == "OpenAI" assert simple_provider.supported_model_types == [ModelType.LLM] diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 8cf0409c4c..a28143026f 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,17 +6,6 @@ from typing import Any from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -35,6 +24,17 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -345,22 +345,26 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) ] ) - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="encrypted-old-key" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): - with patch( - "core.entities.provider_configuration.encrypter.encrypt_token", - side_effect=lambda tenant_id, value: f"enc::{value}", - ): - validated = 
configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, - credential_id="credential-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + ) assert validated["openai_api_key"] == "enc::restored-key" assert validated["region"] == "us" @@ -370,23 +374,15 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: ) -def test_validate_provider_credentials_opens_session_when_not_passed() -> None: +def test_validate_provider_credentials_without_credential_id() -> None: configuration = _build_provider_configuration() - mock_session = Mock() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.Session") as mock_session_cls: - with patch("core.entities.provider_configuration.db") as mock_db: - mock_db.engine = Mock() - mock_session_cls.return_value.__enter__.return_value = mock_session - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory - ): - validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} - mock_session_cls.assert_called_once() def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: @@ -717,18 +713,22 @@ def test_check_provider_credential_name_exists_and_model_setting_lookup() -> Non def test_validate_provider_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): - validated = configuration.validate_provider_credentials( - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + 
credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-key"} @@ -1060,37 +1060,35 @@ def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback( def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( encrypted_config='{"openai_api_key":"enc"}' ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) - assert validated == {"openai_api_key": "enc-new"} - - session = Mock() - mock_factory = Mock() - mock_factory.model_credentials_validate.return_value = {"region": "us"} - with _patched_session(session): + with _patched_session(mock_session): with patch( "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory ): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"region": "us"}, - ) + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) + assert validated == {"openai_api_key": "enc-new"} + + mock_factory2 = Mock() + mock_factory2.model_credentials_validate.return_value = {"region": "us"} + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) assert validated == {"region": "us"} @@ -1570,18 +1568,20 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = None + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = None mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_provider_credentials( - 
credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} @@ -1692,20 +1692,24 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() - session = Mock() - session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_session = Mock() + mock_session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config="{invalid-json" + ) mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): - with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): - validated = configuration.validate_custom_model_credentials( - model_type=ModelType.LLM, - model="gpt-4o", - credentials={"openai_api_key": HIDDEN_VALUE}, - credential_id="cred-1", - session=session, - ) + with _patched_session(mock_session): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + ) assert validated == {"openai_api_key": "enc-new"} diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index 8685d16283..a159d3ad4d 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,5 +1,4 @@ import pytest -from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -9,6 +8,7 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py index 216cda83c5..63e887f904 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_base.py +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from core.extension.extensible import ExtensionModule @@ -12,10 +14,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - 
def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) @@ -28,10 +30,10 @@ class TestExternalDataTool: # Create a concrete subclass to test init class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") @@ -43,10 +45,10 @@ class TestExternalDataTool: def test_validate_config_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): return super().validate_config(tenant_id, config) - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return "" with pytest.raises(NotImplementedError): @@ -55,10 +57,10 @@ class TestExternalDataTool: def test_query_raises_not_implemented(self): class ConcreteTool(ExternalDataTool): @classmethod - def validate_config(cls, tenant_id: str, config: dict): + def validate_config(cls, tenant_id: str, config: dict[str, Any]): pass - def query(self, inputs: dict, query: str | None = None) -> str: + def query(self, inputs: dict[str, Any], query: str | None = None) -> str: return super().query(inputs, query) tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") diff --git a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py index 86b461cf04..c1c1291281 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py +++ b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py @@ -13,7 +13,7 @@ class TestExternalDataFetch: app = Flask(__name__) return app - def test_fetch_success(self, app): + def test_fetch_success(self, app: Flask): with app.app_context(): fetcher = ExternalDataFetch() @@ -79,7 +79,7 @@ class TestExternalDataFetch: assert result_inputs == inputs assert result_inputs is not inputs # Should be a copy - def test_fetch_with_none_variable(self, app): + def test_fetch_with_none_variable(self, app: Flask): with app.app_context(): fetcher = ExternalDataFetch() tool = ExternalDataVariableEntity(variable="var1", type="type1", config={}) @@ -95,7 +95,7 @@ class TestExternalDataFetch: assert "var1" not in result_inputs assert result_inputs == {"in": "val"} - def test_query_external_data_tool(self, app): + def test_query_external_data_tool(self, app: Flask): fetcher = ExternalDataFetch() tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"}) diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index bb6e40e224..8cb0938575 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -3,9 +3,9 @@ from 
graphon.file import File, FileTransferMethod, FileType def test_file(): file = File( - id="test-file", + file_id="test-file", tenant_id="test-tenant-id", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="test-related-id", filename="image.png", @@ -25,27 +25,21 @@ def test_file(): assert file.size == 67 -def test_file_model_validate_accepts_legacy_tenant_id(): - data = { - "id": "test-file", - "tenant_id": "test-tenant-id", - "type": "image", - "transfer_method": "tool_file", - "related_id": "test-related-id", - "filename": "image.png", - "extension": ".png", - "mime_type": "image/png", - "size": 67, - "storage_key": "test-storage-key", - "url": "https://example.com/image.png", - # Extra legacy fields - "tool_file_id": "tool-file-123", - "upload_file_id": "upload-file-456", - "datasource_file_id": "datasource-file-789", - } +def test_file_constructor_accepts_legacy_tenant_id(): + file = File( + file_id="test-file", + tenant_id="test-tenant-id", + file_type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + tool_file_id="tool-file-123", + filename="image.png", + extension=".png", + mime_type="image/png", + size=67, + storage_key="test-storage-key", + url="https://example.com/image.png", + ) - file = File.model_validate(data) - - assert file.related_id == "test-related-id" + assert file.related_id == "tool-file-123" assert file.storage_key == "test-storage-key" assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/test_creators.py b/api/tests/unit_tests/core/helper/test_creators.py new file mode 100644 index 0000000000..df67d3f513 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_creators.py @@ -0,0 +1,106 @@ +"""Tests for the Creators Platform helper module.""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from yarl import URL + + +@pytest.fixture(autouse=True) +def _patch_creators_url(monkeypatch): + """Patch the module-level creators_platform_api_url for all tests.""" + monkeypatch.setattr( + "core.helper.creators.creators_platform_api_url", + URL("https://creators.example.com"), + ) + + +class TestUploadDSL: + @patch("core.helper.creators.httpx.post") + def test_returns_claim_code(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {"claim_code": "abc123"}} + mock_response.raise_for_status = MagicMock() + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + result = upload_dsl(b"app: demo", "demo.yaml") + + assert result == "abc123" + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + assert "anonymous-upload" in call_kwargs.args[0] + assert call_kwargs.kwargs["timeout"] == 30 + + @patch("core.helper.creators.httpx.post") + def test_raises_on_missing_claim_code(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.json.return_value = {"data": {}} + mock_response.raise_for_status = MagicMock() + mock_post.return_value = mock_response + + from core.helper.creators import upload_dsl + + with pytest.raises(ValueError, match="claim_code"): + upload_dsl(b"app: demo") + + @patch("core.helper.creators.httpx.post") + def test_raises_on_http_error(self, mock_post): + mock_response = MagicMock(spec=httpx.Response) + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server Error", + request=MagicMock(), + response=MagicMock(), + ) + mock_post.return_value = mock_response + + from 
core.helper.creators import upload_dsl + + with pytest.raises(httpx.HTTPStatusError): + upload_dsl(b"app: demo") + + +class TestGetRedirectUrl: + @patch("core.helper.creators.dify_config") + def test_without_oauth_client_id(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "" + + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + assert "dsl_claim_code=claim-abc" in url + assert "oauth_code" not in url + assert url.startswith("https://creators.example.com") + + @patch("core.helper.creators.dify_config") + def test_with_oauth_client_id(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "client-xyz" + + with patch( + "services.oauth_server.OAuthServerService.sign_oauth_authorization_code", + return_value="oauth-code-123", + ) as mock_sign: + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + mock_sign.assert_called_once_with("client-xyz", "user-1") + assert "dsl_claim_code=claim-abc" in url + assert "oauth_code=oauth-code-123" in url + + @patch("core.helper.creators.dify_config") + def test_strips_trailing_slash(self, mock_config): + mock_config.CREATORS_PLATFORM_API_URL = "https://creators.example.com/" + mock_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID = "" + + from core.helper.creators import get_redirect_url + + url = get_redirect_url("user-1", "claim-abc") + + assert url.startswith("https://creators.example.com?") + assert "creators.example.com/?" not in url diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index f3ef7fccd0..73e081a570 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -40,11 +40,11 @@ class TestObfuscatedToken: class TestEncryptToken: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_successful_encryption(self, mock_encrypt, mock_query): + def test_successful_encryption(self, mock_encrypt, mock_get): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -53,9 +53,9 @@ class TestEncryptToken: mock_encrypt.assert_called_with("test_token", "mock_public_key") @patch("extensions.ext_database.db.session.get") - def test_tenant_not_found(self, mock_query): + def test_tenant_not_found(self, mock_get): """Test error when tenant doesn't exist""" - mock_query.return_value = None + mock_get.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -122,12 +122,12 @@ class TestEncryptDecryptIntegration: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") - def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): + def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_get): """Test that encryption and decryption are consistent""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # Setup mock 
encryption/decryption original_token = "test_token_123" @@ -148,12 +148,12 @@ class TestSecurity: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_cross_tenant_isolation(self, mock_encrypt, mock_query): + def test_cross_tenant_isolation(self, mock_encrypt, mock_get): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -183,10 +183,10 @@ class TestSecurity: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_encryption_randomness(self, mock_encrypt, mock_query): + def test_encryption_randomness(self, mock_encrypt, mock_get): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -207,11 +207,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): + def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_get): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -221,11 +221,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): + def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_get): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -244,11 +244,11 @@ class TestEdgeCases: @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") - def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): + def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_get): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value = mock_tenant + mock_get.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/core/helper/test_moderation.py b/api/tests/unit_tests/core/helper/test_moderation.py index 4a84099b74..a0dfa86d20 100644 --- a/api/tests/unit_tests/core/helper/test_moderation.py +++ b/api/tests/unit_tests/core/helper/test_moderation.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from typing import cast import pytest -from graphon.model_runtime.errors.invoke import InvokeBadRequestError from pytest_mock import MockerFixture from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.helper.moderation import check_moderation +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from 
models.provider import ProviderType diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index 3b5c5e6597..d9fed9ae2a 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -1,11 +1,17 @@ from unittest.mock import MagicMock, patch +import httpx import pytest from core.helper.ssrf_proxy import ( SSRF_DEFAULT_MAX_RETRIES, + SSRFProxy, _get_user_provided_host_header, + _to_graphon_http_response, + graphon_ssrf_proxy, make_request, + max_retries_exceeded_error, + request_error, ) @@ -174,3 +180,56 @@ class TestFollowRedirectsParameter: call_kwargs = mock_client.request.call_args.kwargs assert call_kwargs.get("follow_redirects") is True + + +def test_to_graphon_http_response_preserves_httpx_response_fields() -> None: + response = httpx.Response( + 201, + headers={"X-Test": "1"}, + content=b"payload", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + wrapped = _to_graphon_http_response(response) + + assert wrapped.status_code == 201 + assert wrapped.headers == {"x-test": "1", "content-length": "7"} + assert wrapped.content == b"payload" + assert wrapped.url == "https://example.com/resource" + assert wrapped.reason_phrase == "Created" + assert wrapped.text == "payload" + + +def test_ssrf_proxy_exposes_expected_error_types() -> None: + proxy = SSRFProxy() + + assert proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert proxy.request_error is request_error + assert graphon_ssrf_proxy.max_retries_exceeded_error is max_retries_exceeded_error + assert graphon_ssrf_proxy.request_error is request_error + + +@pytest.mark.parametrize("method_name", ["get", "head", "post", "put", "delete", "patch"]) +def test_graphon_ssrf_proxy_wraps_module_requests(method_name: str) -> None: + response = httpx.Response( + 200, + headers={"X-Test": "1"}, + content=b"ok", + request=httpx.Request("GET", "https://example.com/resource"), + ) + + with patch(f"core.helper.ssrf_proxy.{method_name}", return_value=response) as mock_method: + wrapped = getattr(graphon_ssrf_proxy, method_name)( + "https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + + mock_method.assert_called_once_with( + url="https://example.com/resource", + max_retries=3, + headers={"X-Test": "1"}, + ) + assert wrapped.status_code == 200 + assert wrapped.url == "https://example.com/resource" + assert wrapped.content == b"ok" diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index b45f6fd9a7..6ed9ddb476 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -2,20 +2,6 @@ import json from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultWithStructuredOutput, - LLMUsage, -) -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import ( @@ -30,6 +16,20 @@ from 
core.llm_generator.output_parser.structured_output import ( remove_additional_properties, ) from core.model_manager import ModelInstance +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType class TestStructuredOutput: diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 7cdfb31189..c4e610d5b0 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -2,12 +2,13 @@ import json from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: @@ -96,6 +97,10 @@ class TestLLMGenerator: questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert len(questions) == 2 assert questions[0] == "Question 1?" 
+ assert mock_model_instance.invoke_llm.call_args.kwargs["model_parameters"] == { + "max_tokens": 2560, + "temperature": 0.0, + } def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: @@ -113,6 +118,97 @@ class TestLLMGenerator: questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert questions == [] + @patch("core.llm_generator.llm_generator.ModelManager.for_tenant") + def test_generate_suggested_questions_after_answer_with_custom_model_and_prompt(self, mock_for_tenant): + custom_model_instance = MagicMock() + custom_response = MagicMock() + custom_response.message.get_text_content.return_value = '["Question 1?"]' + custom_model_instance.invoke_llm.return_value = custom_response + + mock_for_tenant.return_value.get_model_instance.return_value = custom_model_instance + + questions = LLMGenerator.generate_suggested_questions_after_answer( + "tenant_id", + "histories", + instruction_prompt="custom prompt", + model_config={ + "provider": "openai", + "name": "gpt-4o", + "completion_params": {"temperature": 0.2}, + }, + ) + + assert questions == ["Question 1?"] + mock_for_tenant.return_value.get_model_instance.assert_called_once_with( + tenant_id="tenant_id", + model_type=ModelType.LLM, + provider="openai", + model="gpt-4o", + ) + + invoke_kwargs = custom_model_instance.invoke_llm.call_args.kwargs + assert invoke_kwargs["model_parameters"] == {"temperature": 0.2} + assert invoke_kwargs["stop"] == [] + assert "custom prompt" in invoke_kwargs["prompt_messages"][0].content + + @patch("core.llm_generator.llm_generator.ModelManager.for_tenant") + def test_generate_suggested_questions_after_answer_fallback_to_default_model(self, mock_for_tenant): + default_model_instance = MagicMock() + default_response = MagicMock() + default_response.message.get_text_content.return_value = '["Question 1?"]' + default_model_instance.invoke_llm.return_value = default_response + + mock_for_tenant.return_value.get_model_instance.side_effect = ValueError("invalid configured model") + mock_for_tenant.return_value.get_default_model_instance.return_value = default_model_instance + + questions = LLMGenerator.generate_suggested_questions_after_answer( + "tenant_id", + "histories", + model_config={ + "provider": "openai", + "name": "not-found-model", + "completion_params": {"temperature": 0.2}, + }, + ) + + assert questions == ["Question 1?"] + mock_for_tenant.return_value.get_default_model_instance.assert_called_once_with( + tenant_id="tenant_id", + model_type=ModelType.LLM, + ) + assert default_model_instance.invoke_llm.call_args.kwargs["model_parameters"] == { + "max_tokens": 2560, + "temperature": 0.0, + } + assert default_model_instance.invoke_llm.call_args.kwargs["stop"] == [] + + @patch("core.llm_generator.llm_generator.ModelManager.for_tenant") + def test_generate_suggested_questions_after_answer_drops_non_positive_max_tokens(self, mock_for_tenant): + custom_model_instance = MagicMock() + custom_response = MagicMock() + custom_response.message.get_text_content.return_value = '["Question 1?"]' + custom_model_instance.invoke_llm.return_value = custom_response + mock_for_tenant.return_value.get_model_instance.return_value = custom_model_instance + + questions = LLMGenerator.generate_suggested_questions_after_answer( + "tenant_id", + "histories", + model_config={ + "provider": "openai", + "name": "gpt-4o", + "completion_params": { + "temperature": 0.2, + 
"max_tokens": 0, + "stop": ["END"], + }, + }, + ) + + assert questions == ["Question 1?"] + invoke_kwargs = custom_model_instance.invoke_llm.call_args.kwargs + assert invoke_kwargs["model_parameters"] == {"temperature": 0.2} + assert invoke_kwargs["stop"] == ["END"] + def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity): payload = RuleGeneratePayload( instruction="test instruction", model_config=model_config_entity, no_variable=True @@ -395,7 +491,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}} @@ -421,7 +517,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() # Cause exception in node_type logic workflow.graph_dict = {"graph": {"nodes": []}} @@ -448,7 +544,7 @@ class TestLLMGenerator: def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity): with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} @@ -536,7 +632,7 @@ class TestLLMGenerator: instance.invoke_llm.return_value = mock_response with patch("extensions.ext_database.db.session") as mock_session: - mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + mock_session.return_value.scalar.return_value = MagicMock() workflow = MagicMock() workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}} diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py index 81f8da9a62..bbbffa2e69 100644 --- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py @@ -971,6 +971,23 @@ class TestHandlePostRequestNew: assert isinstance(item, SessionMessage) assert isinstance(item.message.root, JSONRPCError) assert item.message.root.id == 77 + assert item.message.root.error.message == "Session terminated by server" + + def test_404_on_initialization_includes_url_in_error(self): + t = _new_transport(url="http://example.com/mcp/server/abc123/mcp") + q: queue.Queue = queue.Queue() + msg = _make_request_msg("initialize", 1) + ctx = self._make_ctx(t, q, message=msg) + mock_resp = MagicMock() + mock_resp.status_code = 404 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert isinstance(item.message.root, JSONRPCError) + assert item.message.root.error.code == 32600 + assert "404 Not Found" in item.message.root.error.message + 
assert "http://example.com/mcp/server/abc123/mcp" in item.message.root.error.message def test_404_for_notification_no_error_sent(self): t = _new_transport() diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 9a815fb94d..57456085c3 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,7 +3,6 @@ from unittest.mock import Mock, patch import jsonschema import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -19,6 +18,7 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) +from graphon.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index 9a5fb319d7..f459250b8e 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -4,6 +4,8 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest + +from core.memory.token_buffer_memory import TokenBufferMemory from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -11,8 +13,6 @@ from graphon.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from core.memory.token_buffer_memory import TokenBufferMemory from models.model import AppMode # --------------------------------------------------------------------------- diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index 6a672fdfd5..c4fd970562 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -1,6 +1,7 @@ from unittest.mock import Mock import pytest + from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -12,12 +13,12 @@ from graphon.model_runtime.entities.provider_entities import ( ProviderCredentialSchema, ProviderEntity, ) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel -from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from 
graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index 3a97ad5c5d..4c668ee96b 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -10,6 +10,7 @@ This module tests all aspects of the content moderation system including: - Configuration validation """ +from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest @@ -28,7 +29,7 @@ class TestKeywordsModeration: """Test suite for custom keyword-based content moderation.""" @pytest.fixture - def keywords_config(self) -> dict: + def keywords_config(self) -> dict[str, Any]: """ Fixture providing a standard keywords moderation configuration. @@ -48,7 +49,7 @@ class TestKeywordsModeration: } @pytest.fixture - def keywords_moderation(self, keywords_config: dict) -> KeywordsModeration: + def keywords_moderation(self, keywords_config: dict[str, Any]) -> KeywordsModeration: """ Fixture providing a KeywordsModeration instance. @@ -64,7 +65,7 @@ class TestKeywordsModeration: config=keywords_config, ) - def test_validate_config_success(self, keywords_config: dict): + def test_validate_config_success(self, keywords_config: dict[str, Any]): """Test successful validation of keywords moderation configuration.""" # Should not raise any exception KeywordsModeration.validate_config("test-tenant", keywords_config) @@ -274,7 +275,7 @@ class TestOpenAIModeration: """Test suite for OpenAI-based content moderation.""" @pytest.fixture - def openai_config(self) -> dict: + def openai_config(self) -> dict[str, Any]: """ Fixture providing OpenAI moderation configuration. @@ -293,7 +294,7 @@ class TestOpenAIModeration: } @pytest.fixture - def openai_moderation(self, openai_config: dict) -> OpenAIModeration: + def openai_moderation(self, openai_config: dict[str, Any]) -> OpenAIModeration: """ Fixture providing an OpenAIModeration instance. 
@@ -309,7 +310,7 @@ class TestOpenAIModeration: config=openai_config, ) - def test_validate_config_success(self, openai_config: dict): + def test_validate_config_success(self, openai_config: dict[str, Any]): """Test successful validation of OpenAI moderation configuration.""" # Should not raise any exception OpenAIModeration.validate_config("test-tenant", openai_config) diff --git a/api/tests/unit_tests/core/ops/test_config_entity.py b/api/tests/unit_tests/core/ops/test_config_entity.py index 2cbff54c42..69650c85cc 100644 --- a/api/tests/unit_tests/core/ops/test_config_entity.py +++ b/api/tests/unit_tests/core/ops/test_config_entity.py @@ -1,16 +1,11 @@ -import pytest -from pydantic import ValidationError +from dify_trace_aliyun.config import AliyunConfig +from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from dify_trace_langfuse.config import LangfuseConfig +from dify_trace_langsmith.config import LangSmithConfig +from dify_trace_opik.config import OpikConfig +from dify_trace_weave.config import WeaveConfig -from core.ops.entities.config_entity import ( - AliyunConfig, - ArizeConfig, - LangfuseConfig, - LangSmithConfig, - OpikConfig, - PhoenixConfig, - TracingProviderEnum, - WeaveConfig, -) +from core.ops.entities.config_entity import TracingProviderEnum class TestTracingProviderEnum: @@ -27,349 +22,8 @@ class TestTracingProviderEnum: assert TracingProviderEnum.ALIYUN == "aliyun" -class TestArizeConfig: - """Test cases for ArizeConfig""" - - def test_valid_config(self): - """Test valid Arize configuration""" - config = ArizeConfig( - api_key="test_key", space_id="test_space", project="test_project", endpoint="https://custom.arize.com" - ) - assert config.api_key == "test_key" - assert config.space_id == "test_space" - assert config.project == "test_project" - assert config.endpoint == "https://custom.arize.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = ArizeConfig() - assert config.api_key is None - assert config.space_id is None - assert config.project is None - assert config.endpoint == "https://otlp.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = ArizeConfig(project="") - assert config.project == "default" - - def test_project_validation_none(self): - """Test project validation with None value""" - config = ArizeConfig(project=None) - assert config.project == "default" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = ArizeConfig(endpoint="") - assert config.endpoint == "https://otlp.arize.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation normalizes URL by removing path""" - config = ArizeConfig(endpoint="https://custom.arize.com/api/v1") - assert config.endpoint == "https://custom.arize.com" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="ftp://invalid.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - ArizeConfig(endpoint="invalid.com") - - -class TestPhoenixConfig: - """Test cases for PhoenixConfig""" - - def test_valid_config(self): - """Test valid Phoenix configuration""" - config = PhoenixConfig(api_key="test_key", project="test_project", 
endpoint="https://custom.phoenix.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.phoenix.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = PhoenixConfig() - assert config.api_key is None - assert config.project is None - assert config.endpoint == "https://app.phoenix.arize.com" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = PhoenixConfig(project="") - assert config.project == "default" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation with path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") - assert config.endpoint == "https://app.phoenix.arize.com/s/dify-integration" - - def test_endpoint_validation_without_path(self): - """Test endpoint validation without path""" - config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") - assert config.endpoint == "https://app.phoenix.arize.com" - - -class TestLangfuseConfig: - """Test cases for LangfuseConfig""" - - def test_valid_config(self): - """Test valid Langfuse configuration""" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host="https://custom.langfuse.com") - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == "https://custom.langfuse.com" - - def test_valid_config_with_path(self): - host = "https://custom.langfuse.com/api/v1" - config = LangfuseConfig(public_key="public_key", secret_key="secret_key", host=host) - assert config.public_key == "public_key" - assert config.secret_key == "secret_key" - assert config.host == host - - def test_default_values(self): - """Test default values are set correctly""" - config = LangfuseConfig(public_key="public", secret_key="secret") - assert config.host == "https://api.langfuse.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangfuseConfig() - - with pytest.raises(ValidationError): - LangfuseConfig(public_key="public") - - with pytest.raises(ValidationError): - LangfuseConfig(secret_key="secret") - - def test_host_validation_empty(self): - """Test host validation with empty value""" - config = LangfuseConfig(public_key="public", secret_key="secret", host="") - assert config.host == "https://api.langfuse.com" - - -class TestLangSmithConfig: - """Test cases for LangSmithConfig""" - - def test_valid_config(self): - """Test valid LangSmith configuration""" - config = LangSmithConfig(api_key="test_key", project="test_project", endpoint="https://custom.smith.com") - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.endpoint == "https://custom.smith.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = LangSmithConfig(api_key="key", project="project") - assert config.endpoint == "https://api.smith.langchain.com" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - LangSmithConfig() - - with pytest.raises(ValidationError): - LangSmithConfig(api_key="key") - - with pytest.raises(ValidationError): - LangSmithConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - 
LangSmithConfig(api_key="key", project="project", endpoint="http://insecure.com") - - -class TestOpikConfig: - """Test cases for OpikConfig""" - - def test_valid_config(self): - """Test valid Opik configuration""" - config = OpikConfig( - api_key="test_key", - project="test_project", - workspace="test_workspace", - url="https://custom.comet.com/opik/api/", - ) - assert config.api_key == "test_key" - assert config.project == "test_project" - assert config.workspace == "test_workspace" - assert config.url == "https://custom.comet.com/opik/api/" - - def test_default_values(self): - """Test default values are set correctly""" - config = OpikConfig() - assert config.api_key is None - assert config.project is None - assert config.workspace is None - assert config.url == "https://www.comet.com/opik/api/" - - def test_project_validation_empty(self): - """Test project validation with empty value""" - config = OpikConfig(project="") - assert config.project == "Default Project" - - def test_url_validation_empty(self): - """Test URL validation with empty value""" - config = OpikConfig(url="") - assert config.url == "https://www.comet.com/opik/api/" - - def test_url_validation_missing_suffix(self): - """Test URL validation requires /api/ suffix""" - with pytest.raises(ValidationError, match="URL should end with /api/"): - OpikConfig(url="https://custom.comet.com/opik/") - - def test_url_validation_invalid_scheme(self): - """Test URL validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - OpikConfig(url="ftp://custom.comet.com/opik/api/") - - -class TestWeaveConfig: - """Test cases for WeaveConfig""" - - def test_valid_config(self): - """Test valid Weave configuration""" - config = WeaveConfig( - api_key="test_key", - entity="test_entity", - project="test_project", - endpoint="https://custom.wandb.ai", - host="https://custom.host.com", - ) - assert config.api_key == "test_key" - assert config.entity == "test_entity" - assert config.project == "test_project" - assert config.endpoint == "https://custom.wandb.ai" - assert config.host == "https://custom.host.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = WeaveConfig(api_key="key", project="project") - assert config.entity is None - assert config.endpoint == "https://trace.wandb.ai" - assert config.host is None - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - WeaveConfig() - - with pytest.raises(ValidationError): - WeaveConfig(api_key="key") - - with pytest.raises(ValidationError): - WeaveConfig(project="project") - - def test_endpoint_validation_https_only(self): - """Test endpoint validation only allows HTTPS""" - with pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai") - - def test_host_validation_optional(self): - """Test host validation is optional but validates when provided""" - config = WeaveConfig(api_key="key", project="project", host=None) - assert config.host is None - - config = WeaveConfig(api_key="key", project="project", host="") - assert config.host == "" - - config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com") - assert config.host == "https://valid.host.com" - - def test_host_validation_invalid_scheme(self): - """Test host validation rejects invalid schemes when provided""" - with 
pytest.raises(ValidationError, match="URL scheme must be one of"): - WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com") - - -class TestAliyunConfig: - """Test cases for AliyunConfig""" - - def test_valid_config(self): - """Test valid Aliyun configuration""" - config = AliyunConfig( - app_name="test_app", - license_key="test_license_key", - endpoint="https://custom.tracing-analysis-dc-hz.aliyuncs.com", - ) - assert config.app_name == "test_app" - assert config.license_key == "test_license_key" - assert config.endpoint == "https://custom.tracing-analysis-dc-hz.aliyuncs.com" - - def test_default_values(self): - """Test default values are set correctly""" - config = AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - assert config.app_name == "dify_app" - - def test_missing_required_fields(self): - """Test that required fields are enforced""" - with pytest.raises(ValidationError): - AliyunConfig() - - with pytest.raises(ValidationError): - AliyunConfig(license_key="test_license") - - with pytest.raises(ValidationError): - AliyunConfig(endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_app_name_validation_empty(self): - """Test app_name validation with empty value""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com", app_name="" - ) - assert config.app_name == "dify_app" - - def test_endpoint_validation_empty(self): - """Test endpoint validation with empty value""" - config = AliyunConfig(license_key="test_license", endpoint="") - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com" - - def test_endpoint_validation_with_path(self): - """Test endpoint validation preserves path for Aliyun endpoints""" - config = AliyunConfig( - license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - ) - assert config.endpoint == "https://tracing-analysis-dc-hz.aliyuncs.com/api/v1/traces" - - def test_endpoint_validation_invalid_scheme(self): - """Test endpoint validation rejects invalid schemes""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="ftp://invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_endpoint_validation_no_scheme(self): - """Test endpoint validation rejects URLs without scheme""" - with pytest.raises(ValidationError, match="URL must start with https:// or http://"): - AliyunConfig(license_key="test_license", endpoint="invalid.tracing-analysis-dc-hz.aliyuncs.com") - - def test_license_key_required(self): - """Test that license_key is required and cannot be empty""" - with pytest.raises(ValidationError): - AliyunConfig(license_key="", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com") - - def test_valid_endpoint_format_examples(self): - """Test valid endpoint format examples from comments""" - valid_endpoints = [ - # cms2.0 public endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan.log.aliyuncs.com/apm/trace/opentelemetry", - # cms2.0 intranet endpoint - "https://proj-xtrace-123456-cn-heyuan.cn-heyuan-intranet.log.aliyuncs.com/apm/trace/opentelemetry", - # xtrace public endpoint - "http://tracing-cn-heyuan.arms.aliyuncs.com", - # xtrace intranet endpoint - "http://tracing-cn-heyuan-internal.arms.aliyuncs.com", - ] - - for endpoint in valid_endpoints: - config = AliyunConfig(license_key="test_license", endpoint=endpoint) - assert config.endpoint == endpoint - - class 
TestConfigIntegration: - """Integration tests for configuration classes""" + """Cross-provider configuration sanity checks""" def test_all_configs_can_be_instantiated(self): """Test that all config classes can be instantiated with valid data""" @@ -388,7 +42,6 @@ class TestConfigIntegration: def test_url_normalization_consistency(self): """Test that URL normalization works consistently across configs""" - # Test that paths are removed from endpoints arize_config = ArizeConfig(endpoint="https://arize.com/api/v1/test") phoenix_with_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com/s/dify-integration") phoenix_without_path_config = PhoenixConfig(endpoint="https://app.phoenix.arize.com") diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py index 543b278715..c24d3ac012 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -1,10 +1,9 @@ from types import SimpleNamespace from unittest.mock import patch -from graphon.model_runtime.entities.message_entities import UserPromptMessage - from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.entities.request import RequestInvokeSummary +from graphon.model_runtime.entities.message_entities import UserPromptMessage def test_system_model_helpers_forward_user_id(): diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index f8d0e127b1..88bf555594 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -6,15 +6,15 @@ from types import SimpleNamespace from unittest.mock import Mock, sentinel import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl import model_runtime as model_runtime_module from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_schema() -> AIModelEntity: @@ -56,7 +56,7 @@ class TestPluginModelRuntime: assert len(providers) == 1 assert providers[0].provider == "langgenius/openai/openai" assert providers[0].provider_name == "openai" - assert providers[0].label.en_US == "OpenAI" + assert providers[0].label.en_us == "OpenAI" client.fetch_model_providers.assert_called_once_with("tenant") def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index a812b01c5b..f1c4c7e700 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ 
b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -4,12 +4,6 @@ from enum import StrEnum import pytest from flask import Response -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) from pydantic import ValidationError from core.plugin.entities.endpoint import EndpointEntityWithInstance @@ -31,6 +25,12 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from core.tools.entities.common_entities import I18nObject +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) class TestEndpointEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index a3b1e5f6b0..704b82adc0 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,14 +17,6 @@ from unittest.mock import MagicMock, patch import httpx import pytest -from graphon.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -45,6 +37,14 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError @pytest.fixture(autouse=True) diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index 90730dff5a..00a4207786 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,12 +1,12 @@ from collections.abc import Generator import pytest -from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector +from graphon.file import File, FileTransferMethod, FileType class TestChunkMerger: @@ -466,7 +466,7 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): file_param = File( tenant_id="tenant-1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/file.png", storage_key="", @@ -499,14 +499,14 @@ class TestConverter: def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): file_one = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/a.txt", storage_key="", ) file_two = File( tenant_id="tenant-1", - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, 
transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/b.txt", storage_key="", diff --git a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py index 71144695bc..e0419d3266 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py +++ b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py @@ -323,6 +323,50 @@ class TestDeserializeResponse: with pytest.raises(ValueError, match="Invalid status line"): deserialize_response(raw_data) + def test_deserialize_response_preserves_duplicate_set_cookie_headers(self): + # Regression test for https://github.com/langgenius/dify/issues/35722 + # Multiple Set-Cookie headers must be preserved per RFC 9110, not collapsed + # into a single value by dict-style assignment. + raw_data = ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Set-Cookie: session=abc; Path=/; HttpOnly\r\n" + b"Set-Cookie: tracking=xyz; Path=/; Secure\r\n" + b"\r\n" + b"ok" + ) + + response = deserialize_response(raw_data) + + cookies = response.headers.getlist("Set-Cookie") + assert cookies == [ + "session=abc; Path=/; HttpOnly", + "tracking=xyz; Path=/; Secure", + ] + # Single-valued headers should still be readable normally. + assert response.headers.get("Content-Type") == "text/plain" + + def test_deserialize_response_preserves_duplicate_generic_headers(self): + # Any header name (not just Set-Cookie) may legitimately repeat; verify the + # parser preserves all values rather than overwriting earlier ones. + raw_data = b"HTTP/1.1 200 OK\r\nX-Custom: first\r\nX-Custom: second\r\n\r\n" + + response = deserialize_response(raw_data) + + assert response.headers.getlist("X-Custom") == ["first", "second"] + + def test_deserialize_response_does_not_inject_default_content_type(self): + # Flask's Response constructor adds a default Content-Type header. When the + # raw response has no Content-Type, the parsed response should not silently + # gain one from the framework default. 
+ raw_data = b"HTTP/1.1 204 No Content\r\nX-Trace-Id: abc\r\n\r\n" + + response = deserialize_response(raw_data) + + header_names = [name for name, _ in response.headers.items()] + assert "Content-Type" not in header_names + assert response.headers.get("X-Trace-Id") == "abc" + def test_roundtrip_response(self): # Test that serialize -> deserialize produces equivalent response original_response = Response( diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 2b280dd674..e536c0831f 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,6 +2,13 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from graphon.file import File, FileTransferMethod, FileType from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -11,13 +18,6 @@ from graphon.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) - -from configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation @@ -134,9 +134,9 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg files = [ File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="", @@ -245,9 +245,9 @@ def test_completion_prompt_jinja2_with_files(): completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -379,9 +379,9 @@ def test_chat_prompt_memory_with_files_and_query(): memory = MagicMock(spec=TokenBufferMemory) prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", @@ -413,9 +413,9 @@ def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", 
storage_key="", @@ -463,9 +463,9 @@ def test_chat_prompt_files_with_query_branch(): transform = AdvancedPromptTransform() model_config_mock = MagicMock(spec=ModelConfigEntity) file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image.jpg", storage_key="", diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 4a54649b28..28966242d8 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,19 +1,18 @@ from unittest.mock import MagicMock -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index a4b3960b0a..5d865d934c 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,3 +1,5 @@ +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, @@ -7,9 +9,6 @@ from graphon.model_runtime.entities.message_entities import ( UserPromptMessage, ) -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil - def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index e35ce2c48a..5308c8e7b3 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,16 +2,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.prompt.prompt_transform import PromptTransform +from graphon.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle # from graphon.model_runtime.entities.message_entities import UserPromptMessage # from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule # from 
graphon.model_runtime.entities.provider_entities import ProviderEntity -# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 3f188cfbb4..0dc74b33df 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,12 +2,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from graphon.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - TextPromptMessageContent, - UserPromptMessage, -) from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -24,6 +18,12 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py index 006b4e7345..1f3247590c 100644 --- a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -1,13 +1,12 @@ from unittest.mock import MagicMock, patch -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.errors.invoke import InvokeAuthorizationError - from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError def _doc(content: str) -> Document: diff --git a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py index bbdd476914..1e91c2dd88 100644 --- a/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py +++ b/api/tests/unit_tests/core/rag/datasource/keyword/jieba/test_jieba.py @@ -1,5 +1,6 @@ import json from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock import pytest @@ -28,15 +29,6 @@ class _Field: return ("in", self._name, tuple(values)) -class _FakeQuery: - def __init__(self): - self.where_calls: list[tuple] = [] - - def where(self, *conditions): - self.where_calls.append(conditions) - return self - - class _FakeExecuteResult: def __init__(self, segments: list[SimpleNamespace]): self._segments = segments @@ -57,7 +49,7 @@ class _FakeSelect: return self -def _dataset_keyword_table(data_source_type: str = "database", keyword_table_dict: dict | None = None): +def _dataset_keyword_table(data_source_type: str = 
"database", keyword_table_dict: dict[str, Any] | None = None): return SimpleNamespace( data_source_type=data_source_type, keyword_table_dict=keyword_table_dict, diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index 8b104597a8..b0ecad4d0c 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, call, patch from uuid import uuid4 @@ -20,7 +21,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -108,17 +109,6 @@ class _FakeExecuteResult: return _FakeExecuteScalarResult(self._data) -class _FakeSummaryQuery: - def __init__(self, summaries: list) -> None: - self._summaries = summaries - - def filter(self, *args, **kwargs): - return self - - def all(self) -> list: - return self._summaries - - class _FakeScalarsResult: def __init__(self, data: list) -> None: self._data = data diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/pgvector/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index 4e9ceddda9..7b6ee97f1c 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -21,6 +21,9 @@ def _register_fake_factory_module(monkeypatch, module_path: str, class_name: str def vector_factory_module(): import importlib + from core.rag.datasource.vdb import vector_backend_registry as reg + + reg.clear_vector_factory_cache() import core.rag.datasource.vdb.vector_factory as module return importlib.reload(module) @@ -41,61 +44,62 @@ def test_gen_index_struct_dict(vector_factory_module): @pytest.mark.parametrize( ("vector_type", "module_path", "class_name"), [ - ("CHROMA", "core.rag.datasource.vdb.chroma.chroma_vector", "ChromaVectorFactory"), - ("MILVUS", "core.rag.datasource.vdb.milvus.milvus_vector", "MilvusVectorFactory"), + ("CHROMA", "dify_vdb_chroma.chroma_vector", "ChromaVectorFactory"), + ("MILVUS", "dify_vdb_milvus.milvus_vector", "MilvusVectorFactory"), ( "ALIBABACLOUD_MYSQL", - "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector", + "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector", "AlibabaCloudMySQLVectorFactory", ), - ("MYSCALE", "core.rag.datasource.vdb.myscale.myscale_vector", "MyScaleVectorFactory"), - ("PGVECTOR", "core.rag.datasource.vdb.pgvector.pgvector", "PGVectorFactory"), - ("VASTBASE", "core.rag.datasource.vdb.pyvastbase.vastbase_vector", "VastbaseVectorFactory"), - ("PGVECTO_RS", "core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), - ("QDRANT", "core.rag.datasource.vdb.qdrant.qdrant_vector", "QdrantVectorFactory"), - ("RELYT", "core.rag.datasource.vdb.relyt.relyt_vector", "RelytVectorFactory"), + ("MYSCALE", "dify_vdb_myscale.myscale_vector", "MyScaleVectorFactory"), + ("PGVECTOR", "dify_vdb_pgvector.pgvector", "PGVectorFactory"), + ("VASTBASE", "dify_vdb_vastbase.vastbase_vector", 
"VastbaseVectorFactory"), + ("PGVECTO_RS", "dify_vdb_pgvecto_rs.pgvecto_rs", "PGVectoRSFactory"), + ("QDRANT", "dify_vdb_qdrant.qdrant_vector", "QdrantVectorFactory"), + ("RELYT", "dify_vdb_relyt.relyt_vector", "RelytVectorFactory"), ( "ELASTICSEARCH", - "core.rag.datasource.vdb.elasticsearch.elasticsearch_vector", + "dify_vdb_elasticsearch.elasticsearch_vector", "ElasticSearchVectorFactory", ), ( "ELASTICSEARCH_JA", - "core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector", + "dify_vdb_elasticsearch.elasticsearch_ja_vector", "ElasticSearchJaVectorFactory", ), - ("TIDB_VECTOR", "core.rag.datasource.vdb.tidb_vector.tidb_vector", "TiDBVectorFactory"), - ("WEAVIATE", "core.rag.datasource.vdb.weaviate.weaviate_vector", "WeaviateVectorFactory"), - ("TENCENT", "core.rag.datasource.vdb.tencent.tencent_vector", "TencentVectorFactory"), - ("ORACLE", "core.rag.datasource.vdb.oracle.oraclevector", "OracleVectorFactory"), + ("TIDB_VECTOR", "dify_vdb_tidb_vector.tidb_vector", "TiDBVectorFactory"), + ("WEAVIATE", "dify_vdb_weaviate.weaviate_vector", "WeaviateVectorFactory"), + ("TENCENT", "dify_vdb_tencent.tencent_vector", "TencentVectorFactory"), + ("ORACLE", "dify_vdb_oracle.oraclevector", "OracleVectorFactory"), ( "OPENSEARCH", - "core.rag.datasource.vdb.opensearch.opensearch_vector", + "dify_vdb_opensearch.opensearch_vector", "OpenSearchVectorFactory", ), - ("ANALYTICDB", "core.rag.datasource.vdb.analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), - ("COUCHBASE", "core.rag.datasource.vdb.couchbase.couchbase_vector", "CouchbaseVectorFactory"), - ("BAIDU", "core.rag.datasource.vdb.baidu.baidu_vector", "BaiduVectorFactory"), - ("VIKINGDB", "core.rag.datasource.vdb.vikingdb.vikingdb_vector", "VikingDBVectorFactory"), - ("UPSTASH", "core.rag.datasource.vdb.upstash.upstash_vector", "UpstashVectorFactory"), + ("ANALYTICDB", "dify_vdb_analyticdb.analyticdb_vector", "AnalyticdbVectorFactory"), + ("COUCHBASE", "dify_vdb_couchbase.couchbase_vector", "CouchbaseVectorFactory"), + ("BAIDU", "dify_vdb_baidu.baidu_vector", "BaiduVectorFactory"), + ("VIKINGDB", "dify_vdb_vikingdb.vikingdb_vector", "VikingDBVectorFactory"), + ("UPSTASH", "dify_vdb_upstash.upstash_vector", "UpstashVectorFactory"), ( "TIDB_ON_QDRANT", - "core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector", + "dify_vdb_tidb_on_qdrant.tidb_on_qdrant_vector", "TidbOnQdrantVectorFactory", ), - ("LINDORM", "core.rag.datasource.vdb.lindorm.lindorm_vector", "LindormVectorStoreFactory"), - ("OCEANBASE", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), - ("SEEKDB", "core.rag.datasource.vdb.oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), - ("OPENGAUSS", "core.rag.datasource.vdb.opengauss.opengauss", "OpenGaussFactory"), - ("TABLESTORE", "core.rag.datasource.vdb.tablestore.tablestore_vector", "TableStoreVectorFactory"), + ("LINDORM", "dify_vdb_lindorm.lindorm_vector", "LindormVectorStoreFactory"), + ("OCEANBASE", "dify_vdb_oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("SEEKDB", "dify_vdb_oceanbase.oceanbase_vector", "OceanBaseVectorFactory"), + ("OPENGAUSS", "dify_vdb_opengauss.opengauss", "OpenGaussFactory"), + ("TABLESTORE", "dify_vdb_tablestore.tablestore_vector", "TableStoreVectorFactory"), ( "HUAWEI_CLOUD", - "core.rag.datasource.vdb.huawei.huawei_cloud_vector", + "dify_vdb_huawei_cloud.huawei_cloud_vector", "HuaweiCloudVectorFactory", ), - ("MATRIXONE", "core.rag.datasource.vdb.matrixone.matrixone_vector", "MatrixoneVectorFactory"), - ("CLICKZETTA", 
"core.rag.datasource.vdb.clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), - ("IRIS", "core.rag.datasource.vdb.iris.iris_vector", "IrisVectorFactory"), + ("MATRIXONE", "dify_vdb_matrixone.matrixone_vector", "MatrixoneVectorFactory"), + ("CLICKZETTA", "dify_vdb_clickzetta.clickzetta_vector", "ClickzettaVectorFactory"), + ("IRIS", "dify_vdb_iris.iris_vector", "IrisVectorFactory"), + ("HOLOGRES", "dify_vdb_hologres.hologres_vector", "HologresVectorFactory"), ], ) def test_get_vector_factory_supported(vector_factory_module, monkeypatch, vector_type, module_path, class_name): @@ -111,22 +115,105 @@ def test_get_vector_factory_unsupported(vector_factory_module): vector_factory_module.Vector.get_vector_factory("unknown") +class _PluginChromaFactory: + """Stub used only for entry-point override test.""" + + +def test_get_vector_factory_entry_point_overrides_builtin(vector_factory_module, monkeypatch): + from importlib.metadata import EntryPoint + + from core.rag.datasource.vdb import vector_backend_registry as reg + + reg.clear_vector_factory_cache() + ep = EntryPoint( + name="chroma", + value=f"{__name__}:_PluginChromaFactory", + group="dify.vector_backends", + ) + + class _FakeGroups: + def select(self, *, group: str): + if group == "dify.vector_backends": + return (ep,) + return () + + monkeypatch.setattr(reg, "entry_points", lambda: _FakeGroups()) + + result_cls = vector_factory_module.Vector.get_vector_factory(vector_factory_module.VectorType.CHROMA) + assert result_cls is _PluginChromaFactory + + def test_vector_init_uses_default_and_custom_attributes(vector_factory_module): dataset = SimpleNamespace(id="dataset-1") - with ( - patch.object(vector_factory_module.Vector, "_get_embeddings", return_value="embeddings"), - patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"), - ): + with patch.object(vector_factory_module.Vector, "_init_vector", return_value="processor"): default_vector = vector_factory_module.Vector(dataset) custom_vector = vector_factory_module.Vector(dataset, attributes=["doc_id"]) - assert default_vector._attributes == ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + # `is_summary` and `original_chunk_id` must be in the default return-properties + # projection so summary index retrieval works on backends that honor the list + # as an explicit projection (e.g. Weaviate). See #34884. + assert default_vector._attributes == [ + "doc_id", + "dataset_id", + "document_id", + "doc_hash", + "doc_type", + "is_summary", + "original_chunk_id", + ] assert custom_vector._attributes == ["doc_id"] - assert default_vector._embeddings == "embeddings" + # ``_embeddings`` is now a lazy proxy that defers materializing the real + # embedding model until ``embed_*`` is invoked, so cleanup paths never + # trigger billing/feature-service calls during ``Vector(dataset)`` + # construction. See ``_LazyEmbeddings``. + assert isinstance(default_vector._embeddings, vector_factory_module._LazyEmbeddings) assert default_vector._vector_processor == "processor" +def test_lazy_embeddings_defer_real_load_until_first_embed_call(vector_factory_module, monkeypatch): + """``Vector(dataset)`` must not transitively call ``ModelManager`` during + construction. The real embedding model should only be materialized on the + first ``embed_*`` call (i.e. create / search paths) so cleanup paths + (``delete_by_ids`` / ``delete``) remain resilient to billing-API failures. 
+ """ + for_tenant_mock = MagicMock(side_effect=AssertionError("ModelManager.for_tenant must not be called eagerly")) + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock) + + dataset = SimpleNamespace( + tenant_id="tenant-1", + embedding_model_provider="openai", + embedding_model="text-embedding-3-small", + ) + + proxy = vector_factory_module._LazyEmbeddings(dataset) + + # Construction alone does not trigger ModelManager / FeatureService / BillingService. + for_tenant_mock.assert_not_called() + + # Exercising an embed_* method materializes the real model exactly once. + inner_model = MagicMock() + inner_model.embed_documents.return_value = [[0.1, 0.2]] + cached_embedding_mock = MagicMock(return_value=inner_model) + real_for_tenant = MagicMock() + real_for_tenant.get_model_instance.return_value = "embedding-model-instance" + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", MagicMock(return_value=real_for_tenant)) + monkeypatch.setattr(vector_factory_module, "CacheEmbedding", cached_embedding_mock) + + result = proxy.embed_documents(["hello"]) + + assert result == [[0.1, 0.2]] + cached_embedding_mock.assert_called_once_with("embedding-model-instance") + inner_model.embed_documents.assert_called_once_with(["hello"]) + + # Subsequent calls reuse the materialized model (no re-resolution). + inner_model.embed_documents.reset_mock() + cached_embedding_mock.reset_mock() + proxy.embed_documents(["world"]) + cached_embedding_mock.assert_not_called() + inner_model.embed_documents.assert_called_once_with(["world"]) + + def test_init_vector_prefers_dataset_index_struct(vector_factory_module, monkeypatch): calls = {"vector_type": None, "init_args": None} @@ -229,6 +316,33 @@ def test_create_batches_texts_and_skips_empty_input(vector_factory_module): vector._vector_processor.create.assert_not_called() +def test_create_skips_empty_text_documents_before_embedding(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_documents.return_value = [[0.1], [0.2]] + vector._vector_processor = MagicMock() + + docs = [ + Document(page_content="foo", metadata={"doc_id": "id-1"}), + Document(page_content="", metadata={"doc_id": "id-empty"}), + Document(page_content=" \n", metadata={"doc_id": "id-blank"}), + Document(page_content="bar", metadata={"doc_id": "id-2"}), + ] + + vector.create(texts=docs, request_id="r-1") + + vector._embeddings.embed_documents.assert_called_once_with(["foo", "bar"]) + vector._vector_processor.create.assert_called_once_with( + texts=[docs[0], docs[3]], embeddings=[[0.1], [0.2]], request_id="r-1" + ) + + vector._embeddings.embed_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.create(texts=[docs[1], docs[2]]) + vector._embeddings.embed_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + def test_create_multimodal_filters_missing_uploads(vector_factory_module, monkeypatch): class _Field: def in_(self, value): @@ -309,6 +423,48 @@ def test_add_texts_with_optional_duplicate_check(vector_factory_module): vector._vector_processor.create.assert_called_once() +def test_add_texts_skips_empty_text_documents(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_documents.return_value = [[0.1]] + vector._vector_processor = MagicMock() + + docs = [ + 
Document(page_content="keep", metadata={"doc_id": "id-1"}), + Document(page_content="", metadata={"doc_id": "id-empty"}), + ] + + vector.add_texts(docs, source="api") + + vector._embeddings.embed_documents.assert_called_once_with(["keep"]) + vector._vector_processor.create.assert_called_once_with(texts=[docs[0]], embeddings=[[0.1]], source="api") + + vector._embeddings.embed_documents.reset_mock() + vector._vector_processor.create.reset_mock() + vector.add_texts([docs[1]]) + vector._embeddings.embed_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + +def test_add_texts_filters_empty_documents_before_duplicate_check(vector_factory_module): + vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) + vector._embeddings = MagicMock() + vector._embeddings.embed_documents.return_value = [[0.1]] + vector._vector_processor = MagicMock() + vector._filter_duplicate_texts = MagicMock(return_value=[]) + + docs = [ + Document(page_content="keep", metadata={"doc_id": "id-1"}), + Document(page_content=" ", metadata={"doc_id": "id-empty"}), + ] + + vector.add_texts(docs, duplicate_check=True) + + vector._filter_duplicate_texts.assert_called_once_with([docs[0]]) + vector._embeddings.embed_documents.assert_not_called() + vector._vector_processor.create.assert_not_called() + + def test_vector_delegation_methods(vector_factory_module): vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) vector._embeddings = MagicMock() @@ -329,19 +485,11 @@ def test_vector_delegation_methods(vector_factory_module): def test_search_by_file_handles_missing_and_existing_upload(vector_factory_module, monkeypatch): - class _Field: - def __eq__(self, value): - return value - - upload_query = MagicMock() - upload_query.where.return_value = upload_query - vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) vector._embeddings = MagicMock() vector._vector_processor = MagicMock() mock_session = SimpleNamespace(get=lambda _model, _id: None) - monkeypatch.setattr(vector_factory_module, "UploadFile", SimpleNamespace(id=_Field())) monkeypatch.setattr(vector_factory_module, "db", SimpleNamespace(session=mock_session)) assert vector.search_by_file("file-1") == [] diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index a7b7c1595b..007a76aa66 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -721,6 +721,30 @@ class TestDatasetDocumentStoreMultimodelBinding: mock_db.session.add.assert_not_called() + def test_add_multimodel_documents_binding_with_none_document_id(self): + """Test that no bindings are added when document_id is None.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + + mock_attachment = MagicMock(spec=AttachmentDocument) + mock_attachment.metadata = {"doc_id": "attachment-1"} + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = 
MagicMock() + mock_db.session = mock_session + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id=None, + ) + + store.add_multimodel_documents_binding("seg-1", [mock_attachment]) + + mock_db.session.add.assert_not_called() + class TestDatasetDocumentStoreAddDocumentsUpdateChild: """Tests for add_documents when updating existing documents with children.""" diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index 3563186186..051a1455ae 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -12,11 +12,11 @@ from unittest.mock import Mock, patch import numpy as np import pytest -from graphon.model_runtime.entities.model_entities import ModelPropertyKey -from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 408cf14a51..4b8175b0b4 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,6 +49,10 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from graphon.model_runtime.errors.invoke import ( @@ -56,10 +60,6 @@ from graphon.model_runtime.errors.invoke import ( InvokeConnectionError, InvokeRateLimitError, ) -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 64eb89590a..b9f2449cfb 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -354,15 +354,44 @@ def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): WordExtractor("not-a-file", "tenant", "user") -def test_del_closes_temp_file(): +def test_close_closes_temp_file(): extractor = object.__new__(WordExtractor) + extractor._closed = False extractor.temp_file = MagicMock() - WordExtractor.__del__(extractor) + extractor.close() extractor.temp_file.close.assert_called_once() +def test_close_is_idempotent(): + extractor = object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + + extractor.close() + extractor.close() + + extractor.temp_file.close.assert_called_once() + + +async def _async_close() -> None: + return None + + +def test_close_closes_awaitable_close_result(): + extractor = 
object.__new__(WordExtractor) + extractor._closed = False + extractor.temp_file = MagicMock() + close_result = _async_close() + extractor.temp_file.close = MagicMock(return_value=close_result) + + extractor.close() + + assert close_result.cr_frame is None + extractor.temp_file.close.assert_called_once() + + def test_extract_images_handles_invalid_external_cases(monkeypatch): class FakeTargetRef: def __contains__(self, item): diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index d4b987c832..4ba4d54fa0 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -1,15 +1,16 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from graphon.model_runtime.entities.model_entities import ModelFeature from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: @@ -71,7 +72,9 @@ class TestParagraphIndexProcessor: with pytest.raises(ValueError, match="No rules found in process rule"): processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) - def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + def test_transform_validates_segmentation( + self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any] + ) -> None: rules_without_segmentation = SimpleNamespace(segmentation=None) with patch( @@ -84,7 +87,9 @@ class TestParagraphIndexProcessor: process_rule={"mode": "custom", "rules": {"enabled": True}}, ) - def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + def test_transform_builds_split_documents( + self, processor: ParagraphIndexProcessor, process_rule: dict[str, Any] + ) -> None: source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) splitter = Mock() splitter.split_documents.return_value = [ diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py index c241b44d52..8ef0e046ef 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -258,10 +258,10 @@ class TestParentChildIndexProcessor: session.commit.assert_called_once() def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: - segment_query = Mock() - 
segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="seg-1")] session = Mock() - session.query.return_value = segment_query + session.scalars.return_value = scalars_result session_ctx = MagicMock() session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 98c47bec8f..bfae9001b7 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, patch import pandas as pd @@ -77,7 +78,7 @@ class TestQAIndexProcessor: processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) def test_transform_preview_calls_formatter_once( - self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app ) -> None: document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) split_node = Document(page_content=".question", metadata={}) @@ -119,7 +120,7 @@ class TestQAIndexProcessor: mock_format.assert_called_once() def test_transform_non_preview_uses_thread_batches( - self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + self, processor: QAIndexProcessor, process_rule: dict[str, Any], fake_flask_app ) -> None: documents = [ Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}), @@ -220,10 +221,10 @@ class TestQAIndexProcessor: self, processor: QAIndexProcessor, dataset: Mock ) -> None: mock_segment = SimpleNamespace(id="seg-1") - mock_query = Mock() - mock_query.filter.return_value.all.return_value = [mock_segment] + scalars_result = Mock() + scalars_result.all.return_value = [mock_segment] mock_session = Mock() - mock_session.query.return_value = mock_query + mock_session.scalars.return_value = scalars_result session_context = MagicMock() session_context.__enter__.return_value = mock_session session_context.__exit__.return_value = False diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index 641c5d9ba0..b4bb343533 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,7 +53,6 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -64,6 +63,7 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document +from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument @@ -1484,11 +1484,8 @@ class TestIndexingRunnerProcessChunk: mock_dependencies["redis"].get.return_value = None 
- # Mock database query for segment updates - mock_query = MagicMock() - mock_dependencies["db"].session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.update.return_value = None + # Mock database update for segment status + mock_dependencies["db"].session.execute.return_value = None # Create a proper context manager mock mock_context = MagicMock() diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index c279b00d3b..8bc7dbf70d 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,7 +17,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_manager import ModelInstance from core.rag.index_processor.constant.doc_type import DocType @@ -29,6 +28,7 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index b98fec3854..b556ddf528 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,14 +1,12 @@ import threading from contextlib import contextmanager, nullcontext from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest from flask import Flask, current_app -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.model_runtime.entities.model_entities import ModelFeature -from sqlalchemy import column from core.app.app_config.entities import ( DatasetEntity, @@ -35,6 +33,8 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset from models.enums import CreatorUserRole @@ -46,7 +46,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -1106,11 +1106,11 @@ class TestRetrievalService: def test_deduplicate_documents_non_dify_provider(self): """ - Test deduplication with non-dify provider documents. + Test deduplication with non-dify provider documents that have no doc_id. 
Verifies: - - External provider documents use content-based deduplication - - Different providers are handled correctly + - External provider documents without doc_id use content-based deduplication + - Identical content from the same provider is collapsed to one result """ # Arrange doc1 = Document( @@ -1131,7 +1131,96 @@ class TestRetrievalService: # Assert # External documents without doc_id should use content-based dedup - assert len(result) >= 1 + assert len(result) == 1 + + def test_deduplicate_documents_non_dify_provider_with_doc_id_different_sources(self): + """ + Regression test for issue #35707. + + Two chunks from different source documents share identical text content but carry + different doc_ids. Before the fix, non-dify providers were forced into content-based + deduplication and the second chunk was silently dropped. After the fix, doc_id is used + as the dedup key for any provider that exposes it, so both chunks must be retained. + + Verifies: + - Non-dify provider documents with different doc_ids are NOT deduplicated even when + their page_content is identical. + """ + # Arrange — same content, different doc_ids, non-dify provider (e.g. Weaviate / Qdrant) + doc_a = Document( + page_content="Shared identical content", + metadata={"doc_id": "doc-from-file-a", "score": 0.85}, + provider="weaviate", + ) + doc_b = Document( + page_content="Shared identical content", + metadata={"doc_id": "doc-from-file-b", "score": 0.82}, + provider="weaviate", + ) + + # Act + result = RetrievalService._deduplicate_documents([doc_a, doc_b]) + + # Assert — both documents must be kept; losing either silently drops a source citation + assert len(result) == 2 + doc_ids = {doc.metadata["doc_id"] for doc in result} + assert doc_ids == {"doc-from-file-a", "doc-from-file-b"} + + def test_deduplicate_documents_non_dify_provider_with_same_doc_id(self): + """ + Test that non-dify provider documents sharing the same doc_id are deduplicated by + doc_id key (not by content), and the higher-scored duplicate is retained. + + Verifies: + - doc_id-based deduplication now applies to any provider, not only "dify" + - The document with the highest score wins when doc_ids collide + """ + # Arrange + doc_low = Document( + page_content="Content A", + metadata={"doc_id": "chunk-1", "score": 0.5}, + provider="qdrant", + ) + doc_high = Document( + page_content="Content A", + metadata={"doc_id": "chunk-1", "score": 0.9}, + provider="qdrant", + ) + + # Act + result = RetrievalService._deduplicate_documents([doc_low, doc_high]) + + # Assert + assert len(result) == 1 + assert result[0].metadata["score"] == 0.9 + + def test_deduplicate_documents_dify_provider_without_doc_id_falls_back_to_content(self): + """ + Test that a dify provider document without doc_id still falls back to content-based + deduplication (no regression from original behaviour). 
+ + Verifies: + - Absence of doc_id triggers content-based dedup regardless of provider + - First occurrence is kept when content is identical + """ + # Arrange — dify docs with no doc_id, same content + doc1 = Document( + page_content="Same content", + metadata={"score": 0.8}, + provider="dify", + ) + doc2 = Document( + page_content="Same content", + metadata={"score": 0.9}, + provider="dify", + ) + + # Act + result = RetrievalService._deduplicate_documents([doc1, doc2]) + + # Assert — collapsed to one; first-seen wins (no score comparison in content branch) + assert len(result) == 1 + assert result[0].metadata["score"] == 0.8 # ==================== Metadata Filtering Tests ==================== @@ -2022,7 +2111,7 @@ def create_mock_document_methods( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -2417,12 +2506,11 @@ class TestDatasetRetrievalKnowledgeRetrieval: mock_document.data_source_type = "upload_file" mock_document.doc_metadata = {} - mock_session.query.return_value.filter.return_value.all.return_value = [ - mock_dataset_from_db - ] - mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( - [mock_dataset_from_db, mock_document] - ) + mock_datasets = MagicMock() + mock_datasets.all.return_value = [mock_dataset_from_db] + mock_documents = MagicMock() + mock_documents.all.return_value = [mock_document] + mock_session.scalars.side_effect = [mock_datasets, mock_documents] # Act result = dataset_retrieval.knowledge_retrieval(request) @@ -4039,21 +4127,9 @@ class TestDatasetRetrievalAdditionalHelpers: def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None: session = Mock() - subquery_query = Mock() - subquery_query.where.return_value = subquery_query - subquery_query.group_by.return_value = subquery_query - subquery_query.having.return_value = subquery_query - subquery_query.subquery.return_value = SimpleNamespace( - c=SimpleNamespace( - dataset_id=column("dataset_id"), available_document_count=column("available_document_count") - ) - ) - - dataset_query = Mock() - dataset_query.outerjoin.return_value = dataset_query - dataset_query.where.return_value = dataset_query - dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] - session.query.side_effect = [subquery_query, dataset_query] + scalars_result = Mock() + scalars_result.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] + session.scalars.return_value = scalars_result session_ctx = MagicMock() session_ctx.__enter__.return_value = session @@ -4104,7 +4180,7 @@ def _doc( dataset_id: str = "dataset-1", document_id: str = "document-1", doc_id: str = "node-1", - extra: dict | None = None, + extra: dict[str, Any] | None = None, ) -> Document: metadata = { "score": score, @@ -4902,9 +4978,6 @@ class TestInternalHooksCoverage: _scalars(segments), _scalars(bindings), ] - query = Mock() - query.where.return_value = query - session.query.return_value = query session_ctx = MagicMock() session_ctx.__enter__.return_value = session session_ctx.__exit__.return_value = False @@ -4919,7 +4992,7 @@ class TestInternalHooksCoverage: ): retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) - query.update.assert_called_once() + session.execute.assert_called_once() mock_trace.assert_called_once() def test_retriever_variants(self, 
retrieval: DatasetRetrieval) -> None: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 48782515d0..aace419d15 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 @@ -55,7 +56,7 @@ def create_mock_document( doc_id: str, score: float = 0.8, provider: str = "dify", - additional_metadata: dict | None = None, + additional_metadata: dict[str, Any] | None = None, ) -> Document: """ Create a mock Document object for testing. @@ -450,12 +451,11 @@ class TestDatasetRetrievalKnowledgeRetrieval: mock_document.data_source_type = "upload_file" mock_document.doc_metadata = {} - mock_session.query.return_value.filter.return_value.all.return_value = [ - mock_dataset_from_db - ] - mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( - [mock_dataset_from_db, mock_document] - ) + mock_datasets = MagicMock() + mock_datasets.all.return_value = [mock_dataset_from_db] + mock_documents = MagicMock() + mock_documents.all.return_value = [mock_document] + mock_session.scalars.side_effect = [mock_datasets, mock_documents] # Act result = dataset_retrieval.knowledge_retrieval(request) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index 5a2ecb8220..43c521dcfd 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,8 +1,7 @@ from unittest.mock import Mock -from graphon.model_runtime.entities.llm_entities import LLMUsage - from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter +from graphon.model_runtime.entities.llm_entities import LLMUsage class TestFunctionCallMultiDatasetRouter: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index 539ac0f849..c56528cf55 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -1,13 +1,12 @@ from types import SimpleNamespace from unittest.mock import Mock, patch +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.model_runtime.entities.message_entities import PromptMessageRole from graphon.model_runtime.entities.model_entities import ModelType -from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish -from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter - class TestReactMultiDatasetRouter: def test_invoke_returns_none_when_no_tools(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e229d5fc1a..3d3322094e 100644 --- 
a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,10 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowType from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 7dbf78d0f0..05b4f3a053 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from graphon.enums import BuiltinNodeTypes - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 0fc82dda53..18ae9fafc8 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,11 +7,6 @@ from datetime import datetime from types import SimpleNamespace import pytest -from graphon.nodes.human_input.entities import ( - FormDefinition, - UserAction, -) -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -19,13 +14,18 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, MemberRecipient, ) +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 8ff0e40587..4248782d93 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -9,8 +9,6 @@ from typing import Any from unittest.mock 
import MagicMock import pytest -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction -from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( FormCreateParams, @@ -23,7 +21,7 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from core.workflow.human_input_compat import ( +from core.workflow.human_input_adapter import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -31,6 +29,8 @@ from core.workflow.human_input_compat import ( MemberRecipient, WebAppDeliveryMethod, ) +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index e5c3e85487..a08c5729cb 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -3,12 +3,12 @@ from unittest.mock import MagicMock from uuid import uuid4 import pytest -from graphon.entities import WorkflowExecution -from graphon.enums import WorkflowExecutionStatus, WorkflowType from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 5b4d26b780..6af7b02d4c 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -10,12 +10,6 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest -from graphon.entities import WorkflowNodeExecution -from graphon.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) from sqlalchemy import Engine, create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker @@ -29,6 +23,12 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( _find_first, _replace_or_append_offload, ) +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 84fe522388..abdbc72085 100644 --- 
a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 27729e7f06..5af1376a0a 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock -from graphon.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index ac65d0c02b..eab0176f41 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,15 +1,14 @@ import json from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig - from models.workflow import Workflow def test_file_to_dict(): file = File( - id="file1", + file_id="file1", tenant_id="tenant1", - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", storage_key="storage_key", diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index f5efb78b61..5a7e7e30a5 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock, patch import pytest import redis -from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration -from core.model_manager import LBModelManager +from core.model_manager import LBModelManager, ModelManager from extensions.ext_redis import redis_client +from graphon.model_runtime.entities.model_entities import ModelType @pytest.fixture @@ -40,6 +40,29 @@ def lb_model_manager(): 
return lb_model_manager + +def test_model_manager_with_cache_enabled_reuses_stored_credentials(): + """With ``enable_credentials_cache=True``, later calls for the same key return cached credentials.""" + provider_manager = MagicMock() + bundle = MagicMock() + bundle.configuration.provider.provider = "openai" + bundle.configuration.tenant_id = "tenant-1" + bundle.configuration.model_settings = [] + bundle.model_type_instance.model_type = ModelType.LLM + get_creds = MagicMock(return_value={"api_key": "first"}) + bundle.configuration.get_current_credentials = get_creds + provider_manager.get_provider_model_bundle.return_value = bundle + + manager = ModelManager(provider_manager, enable_credentials_cache=True) + first = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4") + assert first.credentials == {"api_key": "first"} + get_creds.assert_called_once() + + get_creds.return_value = {"api_key": "second"} + second = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4") + assert second.credentials == {"api_key": "first"} + get_creds.assert_called_once() + + def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager): # initialize redis client redis_client.initialize(redis.Redis()) diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 331166fe63..b19a21d7f4 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,15 +1,6 @@ from unittest.mock import Mock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -21,6 +12,15 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index ee26172459..02f12fb3b4 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -2,12 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest -from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel from models.provider_ids import ModelProviderID @@ -372,6 +372,78 @@ def 
test_get_configurations_binds_manager_runtime_to_provider_configuration( provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) +def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers, + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls, + patch( + "core.provider_manager.ProviderConfiguration", + return_value=provider_configuration, + ) as mock_provider_configuration, + ): + first = manager.get_configurations("tenant-id") + second = manager.get_configurations("tenant-id") + + assert first is second + mock_get_all_providers.assert_called_once_with("tenant-id") + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_provider_configuration.assert_called_once() + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity): + manager = _build_provider_manager(mocker) + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + provider_configuration_first = Mock() + provider_configuration_second = Mock() + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers, + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", 
return_value=provider_factory), + patch( + "core.provider_manager.ProviderConfiguration", + side_effect=[provider_configuration_first, provider_configuration_second], + ) as mock_provider_configuration, + ): + first = manager.get_configurations("tenant-id") + manager.clear_configurations_cache("tenant-id") + second = manager.get_configurations("tenant-id") + + assert first is not second + assert mock_get_all_providers.call_count == 2 + assert mock_provider_configuration.call_count == 2 + provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime) + provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): manager = _build_provider_manager(mocker) provider_configuration = Mock() @@ -498,8 +570,7 @@ def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> session.scalars.return_value = [openai_provider, gemini_provider] with ( - patch("core.provider_manager.db", SimpleNamespace(engine=object())), - patch("core.provider_manager.Session", return_value=_build_session_context(session)), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), ): result = ProviderManager._get_all_providers("tenant-id") @@ -523,8 +594,7 @@ def test_provider_grouping_helpers_group_records_by_provider_name(method_name: s session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record] with ( - patch("core.provider_manager.db", SimpleNamespace(engine=object())), - patch("core.provider_manager.Session", return_value=_build_session_context(session)), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), ): result = getattr(ProviderManager, method_name)("tenant-id") @@ -539,8 +609,7 @@ def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> session.scalars.return_value = [openai_preference, anthropic_preference] with ( - patch("core.provider_manager.db", SimpleNamespace(engine=object())), - patch("core.provider_manager.Session", return_value=_build_session_context(session)), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), ): result = ProviderManager._get_all_preferred_model_providers("tenant-id") @@ -554,13 +623,13 @@ def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_ with ( patch("core.provider_manager.redis_client.get", return_value=b"False"), patch("core.provider_manager.FeatureService.get_features") as mock_get_features, - patch("core.provider_manager.Session") as mock_session_cls, + patch("core.provider_manager.session_factory.create_session") as mock_create_session, ): result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") assert result == {} mock_get_features.assert_not_called() - mock_session_cls.assert_not_called() + mock_create_session.assert_not_called() def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None: @@ -570,14 +639,13 @@ def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_conf session.scalars.return_value = [openai_config, anthropic_config] with ( - patch("core.provider_manager.db", SimpleNamespace(engine=object())), patch("core.provider_manager.redis_client.get", return_value=None), patch("core.provider_manager.redis_client.setex") as mock_setex, patch( 
"core.provider_manager.FeatureService.get_features", return_value=SimpleNamespace(model_load_balancing_enabled=True), ), - patch("core.provider_manager.Session", return_value=_build_session_context(session)), + patch("core.provider_manager.session_factory.create_session", return_value=_build_session_context(session)), ): result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index 5d744f88c9..1ff81f6120 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -6,13 +6,13 @@ from typing import Any from unittest.mock import patch import pytest -from graphon.model_runtime.entities.message_entities import UserPromptMessage from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from graphon.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index ee0ce51eec..c7829fc0d7 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -6,8 +6,6 @@ from datetime import date from types import SimpleNamespace import pytest -from graphon.file import FileType -from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -29,6 +27,8 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.errors import ToolInvokeError +from graphon.file import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: diff --git a/api/tests/unit_tests/core/tools/test_custom_tool.py b/api/tests/unit_tests/core/tools/test_custom_tool.py index 79b8eaaa87..f35546b025 100644 --- a/api/tests/unit_tests/core/tools/test_custom_tool.py +++ b/api/tests/unit_tests/core/tools/test_custom_tool.py @@ -1,6 +1,7 @@ from __future__ import annotations from types import SimpleNamespace +from typing import Any import httpx import pytest @@ -14,7 +15,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvo from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -def _build_tool(*, openapi: dict | None = None) -> ApiTool: +def _build_tool(*, openapi: dict[str, Any] | None = None) -> ApiTool: entity = ToolEntity( identity=ToolIdentity( author="author", diff --git a/api/tests/unit_tests/core/tools/test_tool_engine.py b/api/tests/unit_tests/core/tools/test_tool_engine.py index 40c107667c..cd16557ef6 100644 --- a/api/tests/unit_tests/core/tools/test_tool_engine.py +++ b/api/tests/unit_tests/core/tools/test_tool_engine.py @@ -260,6 +260,28 @@ def test_agent_invoke_engine_meta_error(): assert error_meta.error 
== "meta failure" +def test_convert_tool_response_excludes_variable_messages(): + """Regression test for issue #34723. + + WorkflowTool._invoke yields VARIABLE, TEXT, and suppressed-JSON messages. + _convert_tool_response_to_str must skip VARIABLE messages so that the + returned string contains only the TEXT representation and not a + duplicated, garbled Pydantic repr of the same data. + """ + tool = _build_tool() + outputs = {"reports": "hello"} + messages = [ + tool.create_variable_message(variable_name="reports", variable_value="hello"), + tool.create_text_message('{"reports": "hello"}'), + tool.create_json_message(outputs, suppress_output=True), + ] + + result = ToolEngine._convert_tool_response_to_str(messages) + + assert result == '{"reports": "hello"}' + assert "variable_name" not in result + + def test_agent_invoke_tool_invoke_error(): tool = _build_tool(with_llm_parameter=True) callback = Mock() diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index 2889cb9db1..ccffdf16d1 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -12,9 +12,9 @@ from unittest.mock import MagicMock, Mock, patch import httpx import pytest -from graphon.file import FileTransferMethod from core.tools.tool_file_manager import ToolFileManager +from graphon.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index 8c0e7e9419..e13f430f9b 100644 --- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -2,7 +2,7 @@ from __future__ import annotations from types import SimpleNamespace from typing import Any -from unittest.mock import PropertyMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -12,11 +12,13 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +# Create a mock class for testing abstract/base classes class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController): def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): return None +# Factory function to create a "lightweight" controller for testing def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController: controller = object.__new__(ApiToolProviderController) controller.provider_id = provider_id @@ -29,6 +31,7 @@ def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderContr return controller +# Test pure logic: filtering and deduplication def test_tool_label_manager_filter_tool_labels(): filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"]) assert set(filtered) == {"search", "news"} @@ -36,22 +39,68 @@ def test_tool_label_manager_filter_tool_labels(): def test_tool_label_manager_update_tool_labels_db(): + """ + Test the database update logic for tool labels. + Focus: Verify that labels are filtered, de-duplicated, and safely handled within a database session. + """ + # 1. Setup expected data from the controller controller = _api_controller("api-1") - with patch("core.tools.tool_label_manager.db") as mock_db: + expected_id = controller.provider_id + expected_type = controller.provider_type + + # 2. 
Patching External Dependencies + # - We patch 'db' to prevent Flask from trying to access a real database. + # - We patch 'sessionmaker' to intercept and control the creation of SQLAlchemy sessions. + with ( + patch("core.tools.tool_label_manager.db"), + patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker, + ): + # 3. Constructing the "Mocking Chain" + # In the business logic, we use: with sessionmaker(db.engine).begin() as _session: + # We need to link our 'mock_session' to the end of this complex context manager chain: + # Step A: sessionmaker(db.engine) -> returns an object (mock_sessionmaker.return_value) + # Step B: .begin() -> returns a context manager (begin.return_value) + # Step C: with ... as _session: -> calls __enter__(), so _session is bound to __enter__.return_value + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + + # 4. Trigger the logic under test + # Input: ["search", "search", "invalid"] + # Logic: + # - "invalid" should be filtered out (not in default_tool_label_name_list). + # - The duplicate "search" should be merged (unique labels). ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) - mock_db.session.execute.assert_called_once() - # only one valid unique label should be inserted. - assert mock_db.session.add.call_count == 1 - mock_db.session.commit.assert_called_once() + # 5. Behavior Assertion: DELETE operation + # Verify that the manager first attempts to clear existing labels for this specific tool. + # This ensures the update is idempotent. + mock_session.execute.assert_called_once() + + # 6. Behavior Assertion: INSERT operation + # Verify that only ONE valid label ("search") was added after filtering and deduplication. + # If call_count == 1, filter_tool_labels() filtered and de-duplicated as expected. + assert mock_session.add.call_count == 1 + + # 7. State Assertion: Data Integrity & Isolation + # Inspect the actual object passed to session.add() to ensure it has correct properties. + # This confirms that the refactored data isolation (tool_id + tool_type) is in effect. + call_args = mock_session.add.call_args + added_label = call_args[0][0] # Retrieve the ToolLabelBinding instance + + assert added_label.label_name == "search", "The label name should be 'search' after filtering." + assert added_label.tool_id == expected_id, "The tool_id must match the provider_id for correct binding." + assert added_label.tool_type == expected_type, "Isolation failed: tool_type must match the provider type." 
+# Test error handling def test_tool_label_manager_update_tool_labels_unsupported(): with pytest.raises(ValueError, match="Unsupported tool type"): ToolLabelManager.update_tool_labels(object(), ["search"]) # type: ignore[arg-type] +# Test retrieval logic def test_tool_label_manager_get_tool_labels_for_builtin_and_db(): + # Mocking a property (@property) using PropertyMock with patch.object( _ConcreteBuiltinToolProviderController, "tool_labels", @@ -62,29 +111,67 @@ def test_tool_label_manager_get_tool_labels_for_builtin_and_db(): assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"] api = _api_controller("api-1") - with patch("core.tools.tool_label_manager.db") as mock_db: - mock_db.session.scalars.return_value.all.return_value = ["search", "news"] - labels = ToolLabelManager.get_tool_labels(api) - assert labels == ["search", "news"] + with ( + patch("core.tools.tool_label_manager.db"), + patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker, + ): + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + # Inject mock data into the query result: session.scalars(stmt).all() + mock_session.scalars.return_value.all.return_value = ["search", "news"] + + labels = ToolLabelManager.get_tool_labels(api) + assert labels == ["search", "news"] + + +def test_tool_label_manager_get_tool_labels_unsupported(): + """ + Negative Test: Ensure get_tool_labels raises ValueError for unsupported controller types. + This protects the internal API contract against accidental regressions during refactoring. + """ + # Passing a generic object() which doesn't match Api, Workflow, or Builtin controllers. with pytest.raises(ValueError, match="Unsupported tool type"): ToolLabelManager.get_tool_labels(object()) # type: ignore[arg-type] +# Test batch processing and mapping def test_tool_label_manager_get_tools_labels_batch(): assert ToolLabelManager.get_tools_labels([]) == {} api = _api_controller("api-1") wf = _workflow_controller("wf-1") + + # SimpleNamespace is a quick way to simulate SQLAlchemy row objects records = [ SimpleNamespace(tool_id="api-1", label_name="search"), SimpleNamespace(tool_id="api-1", label_name="news"), SimpleNamespace(tool_id="wf-1", label_name="utilities"), ] - with patch("core.tools.tool_label_manager.db") as mock_db: - mock_db.session.scalars.return_value.all.return_value = records + + with ( + patch("core.tools.tool_label_manager.db"), + patch("core.tools.tool_label_manager.sessionmaker") as mock_sessionmaker, + ): + mock_session = MagicMock() + mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session + + # Simulating the batch query result + mock_session.scalars.return_value.all.return_value = records + labels = ToolLabelManager.get_tools_labels([api, wf]) + + # Verify the final dictionary mapping assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]} + +def test_tool_label_manager_get_tools_labels_unsupported(): + """ + Negative Test: Ensure get_tools_labels raises ValueError if the list contains + unsupported controller types, even alongside valid ones. 
+ """ + api = _api_controller("api-1") + + # Passing a list with one valid controller and one invalid object() with pytest.raises(ValueError, match="Unsupported tool type"): ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item] diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 31b68f0b3f..c9b3dfb186 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -637,7 +637,7 @@ def test_list_default_builtin_providers_for_postgres_and_mysql(): for scheme in ("postgresql", "mysql"): session = Mock() session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] - session.query.return_value.where.return_value.all.return_value = provider_records + session.scalars.return_value = iter(provider_records) with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)): with patch("core.tools.tool_manager.db") as mock_db: @@ -925,3 +925,78 @@ def test_convert_tool_parameters_type_constant_branch(): ) assert constant == {"text": "fixed"} + + +def test_convert_tool_parameters_type_model_selector_from_legacy_top_level_config(): + model_param = ToolParameter.get_simple_instance( + name="vision_llm_model", + llm_description="vision model", + typ=ToolParameter.ToolParameterType.MODEL_SELECTOR, + required=True, + ) + model_param.form = ToolParameter.ToolParameterForm.FORM + variable_pool = Mock() + + runtime_parameters = ToolManager._convert_tool_parameters_type( + parameters=[model_param], + variable_pool=variable_pool, + tool_configurations={ + "vision_llm_model": { + "type": "constant", + "value": "", + "provider": "langgenius/tongyi/tongyi", + "model": "qwen3-vl-plus", + "model_type": "llm", + "mode": "chat", + } + }, + typ="workflow", + ) + + assert runtime_parameters == { + "vision_llm_model": { + "provider": "langgenius/tongyi/tongyi", + "model": "qwen3-vl-plus", + "model_type": "llm", + "mode": "chat", + } + } + + +def test_convert_tool_parameters_type_model_selector_from_constant_value_config(): + model_param = ToolParameter.get_simple_instance( + name="tts_model", + llm_description="tts model", + typ=ToolParameter.ToolParameterType.MODEL_SELECTOR, + required=True, + ) + model_param.form = ToolParameter.ToolParameterForm.FORM + variable_pool = Mock() + + runtime_parameters = ToolManager._convert_tool_parameters_type( + parameters=[model_param], + variable_pool=variable_pool, + tool_configurations={ + "tts_model": { + "type": "constant", + "value": { + "provider": "langgenius/tongyi/tongyi", + "model": "qwen3-tts-flash", + "model_type": "tts", + "language": "Chinese", + "voice": "Cherry", + }, + } + }, + typ="workflow", + ) + + assert runtime_parameters == { + "tts_model": { + "provider": "langgenius/tongyi/tongyi", + "model": "qwen3-tts-flash", + "model_type": "tts", + "language": "Chinese", + "voice": "Cherry", + } + } diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index 6454a5bcd1..5f34135af4 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest import core.tools.utils.message_transformer as mt @@ -13,7 +15,7 @@ class _FakeToolFile: class _FakeToolFileManager: """Fake ToolFileManager to capture the 
mimetype passed in.""" - last_call: dict | None = None + last_call: dict[str, Any] | None = None def __init__(self, *args, **kwargs): pass diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 52f262e1cf..44785f939c 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -10,9 +10,12 @@ from __future__ import annotations from decimal import Decimal from types import SimpleNamespace +from typing import Any from unittest.mock import Mock, patch import pytest + +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils from graphon.model_runtime.entities.model_entities import ModelPropertyKey from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -22,10 +25,8 @@ from graphon.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) -from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils - -def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: +def _mock_model_instance(*, schema: dict[str, Any] | None = None) -> SimpleNamespace: model_type_instance = Mock() model_type_instance.get_model_schema.return_value = ( SimpleNamespace(model_properties=schema or {}) if schema is not None else None diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index 40f91b12a0..99a90f3b67 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -1,4 +1,5 @@ from json.decoder import JSONDecodeError +from typing import Any from unittest.mock import Mock, patch import pytest @@ -16,7 +17,7 @@ def app(): return app -def test_parse_openapi_to_tool_bundle_operation_id(app): +def test_parse_openapi_to_tool_bundle_operation_id(app: Flask): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, @@ -62,7 +63,7 @@ def test_parse_openapi_to_tool_bundle_operation_id(app): assert tool_bundles[2].operation_id == "createResource" -def test_parse_openapi_to_tool_bundle_properties_all_of(app): +def test_parse_openapi_to_tool_bundle_properties_all_of(app: Flask): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, @@ -117,7 +118,7 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app): # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"} -def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): +def test_parse_openapi_to_tool_bundle_default_value_type_casting(app: Flask): """ Test that default values are properly cast to match parameter types. 
This addresses the issue where array default values like [] cause validation errors @@ -259,8 +260,8 @@ def test_parse_openapi_to_tool_bundle_server_env_and_refs(app): }, } - extra_info: dict = {} - warning: dict = {} + extra_info: dict[str, Any] = {} + warning: dict[str, Any] = {} with app.test_request_context(headers={"X-Request-Env": "prod"}): bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) @@ -298,7 +299,7 @@ def test_parse_swagger_to_openapi_branches(): } ) - warning: dict = {"seed": True} + warning: dict[str, Any] = {"seed": True} converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi( { "servers": [{"url": "https://x"}], diff --git a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py index 5691f33e65..6bb86ebe78 100644 --- a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py +++ b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py @@ -2,50 +2,50 @@ from __future__ import annotations import pytest -from core.tools.utils import system_oauth_encryption as oauth_encryption -from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter +from core.tools.utils import system_encryption as encryption +from core.tools.utils.system_encryption import EncryptionError, SystemEncrypter -def test_system_oauth_encrypter_roundtrip(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_roundtrip(): + encrypter = SystemEncrypter(secret_key="test-secret") payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"} - encrypted = encrypter.encrypt_oauth_params(payload) - decrypted = encrypter.decrypt_oauth_params(encrypted) + encrypted = encrypter.encrypt_params(payload) + decrypted = encrypter.decrypt_params(encrypted) assert encrypted assert dict(decrypted) == payload -def test_system_oauth_encrypter_decrypt_validates_input(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_decrypt_validates_input(): + encrypter = SystemEncrypter(secret_key="test-secret") with pytest.raises(ValueError, match="must be a string"): - encrypter.decrypt_oauth_params(123) # type: ignore[arg-type] + encrypter.decrypt_params(123) # type: ignore[arg-type] with pytest.raises(ValueError, match="cannot be empty"): - encrypter.decrypt_oauth_params("") + encrypter.decrypt_params("") -def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext(): - encrypter = SystemOAuthEncrypter(secret_key="test-secret") +def test_system_encrypter_raises_error_for_invalid_ciphertext(): + encrypter = SystemEncrypter(secret_key="test-secret") - with pytest.raises(OAuthEncryptionError, match="Decryption failed"): - encrypter.decrypt_oauth_params("not-base64") + with pytest.raises(EncryptionError, match="Decryption failed"): + encrypter.decrypt_params("not-base64") -def test_system_oauth_helpers_use_global_cached_instance(monkeypatch): - monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None) - monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret") +def test_system_helpers_use_global_cached_instance(monkeypatch): + monkeypatch.setattr(encryption, "_encrypter", None) + monkeypatch.setattr("core.tools.utils.system_encryption.dify_config.SECRET_KEY", "global-secret") - first = oauth_encryption.get_system_oauth_encrypter() - second = 
oauth_encryption.get_system_oauth_encrypter() + first = encryption.get_system_encrypter() + second = encryption.get_system_encrypter() assert first is second - encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"}) - assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"} + encrypted = encryption.encrypt_system_params({"k": "v"}) + assert encryption.decrypt_system_params(encrypted) == {"k": "v"} -def test_create_system_oauth_encrypter_factory(): - encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret") - assert isinstance(encrypter, SystemOAuthEncrypter) +def test_create_system_encrypter_factory(): + encrypter = encryption.create_system_encrypter(secret_key="factory-secret") + assert isinstance(encrypter, SystemEncrypter) diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index 0e3a7e623a..43f3fbd5c9 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,9 +1,9 @@ import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from graphon.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index 4767480a5a..5a585c609a 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -4,7 +4,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -14,6 +13,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from graphon.variables.input_entities import VariableEntity, VariableEntityType def _controller() -> WorkflowToolProviderController: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index c20edd7400..72a73dd936 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -11,7 +11,6 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest -from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -25,6 +24,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: diff --git 
a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index 78622b78b6..fb7dc52838 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -8,10 +8,10 @@ and select_trigger_debug_events orchestrator. from __future__ import annotations from datetime import datetime +from typing import Any from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes, NodeType from core.plugin.entities.request import TriggerInvokeEventResponse from core.trigger.constants import ( @@ -27,10 +27,11 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent +from graphon.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID -def _make_poller_args(node_config: dict | None = None) -> dict: +def _make_poller_args(node_config: dict[str, Any] | None = None) -> dict[str, Any]: return { "tenant_id": "t1", "user_id": "u1", diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 7406b88270..9e07ea1b6d 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,23 +1,30 @@ import dataclasses +from typing import Annotated import orjson import pytest +from pydantic import BaseModel, Discriminator, Tag + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from graphon.file import File, FileTransferMethod, FileType from graphon.runtime import VariablePool from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( ArrayAnySegment, + ArrayBooleanSegment, ArrayFileSegment, ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + BooleanSegment, FileSegment, FloatSegment, IntegerSegment, NoneSegment, ObjectSegment, Segment, - SegmentUnion, StringSegment, get_segment_discriminator, ) @@ -42,11 +49,26 @@ from graphon.variables.variables import ( StringVariable, Variable, ) -from pydantic import BaseModel +from models.utils.file_input_compat import rebuild_serialized_graph_files_without_lookup -from core.helper import encrypter -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_variables_to_pool +type SegmentUnion = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] + ), + Discriminator(get_segment_discriminator), +] 
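Note on the hunk above: `test_segment.py` stops importing `SegmentUnion` from `graphon.variables.segments` and rebuilds it test-locally as a callable-discriminated Pydantic union. A minimal sketch of that pattern follows; `StringSegment`/`IntegerSegment` here are hypothetical stand-ins rather than Dify's real segment classes, and a plain assignment replaces the PEP 695 `type` statement for pre-3.12 compatibility.

```python
from typing import Annotated, Any

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class StringSegment(BaseModel):
    value_type: str = "string"
    value: str


class IntegerSegment(BaseModel):
    value_type: str = "integer"
    value: int


def get_segment_discriminator(v: Any) -> str | None:
    # Pydantic passes raw input (dicts) during validation and model
    # instances during serialization; return the tag declared below.
    if isinstance(v, dict):
        return v.get("value_type")
    return getattr(v, "value_type", None)


SegmentUnion = Annotated[
    Annotated[StringSegment, Tag("string")] | Annotated[IntegerSegment, Tag("integer")],
    Discriminator(get_segment_discriminator),
]

adapter = TypeAdapter(SegmentUnion)
seg = adapter.validate_python({"value_type": "integer", "value": 3})
assert isinstance(seg, IntegerSegment) and seg.value == 3
```

The callable discriminator keeps tag resolution in one function, so adding a segment type to the test alias only requires one more `Annotated[..., Tag(...)]` arm, mirroring how the hunk enumerates every `SegmentType`.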
def _build_variable_pool( @@ -123,7 +145,7 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, @@ -160,7 +182,7 @@ class TestSegmentDumpAndLoad: assert restored == model def test_all_segments_serialization(self): - """Test serialization/deserialization of all segment types""" + """Test file-aware segment serialization through Dify's model boundary.""" # Create one instance of each segment type test_file = create_test_file() @@ -181,7 +203,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Segments(segments=all_segments) json_str = model.model_dump_json() - loaded = _Segments.model_validate_json(json_str) + loaded = _Segments.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all segments are preserved assert len(loaded.segments) == len(all_segments) @@ -202,7 +224,7 @@ class TestSegmentDumpAndLoad: assert loaded_segment.value == original.value def test_all_variables_serialization(self): - """Test serialization/deserialization of all variable types""" + """Test file-aware variable serialization through Dify's model boundary.""" # Create one instance of each variable type test_file = create_test_file() @@ -223,7 +245,7 @@ class TestSegmentDumpAndLoad: # Test serialization and deserialization model = _Variables(variables=all_variables) json_str = model.model_dump_json() - loaded = _Variables.model_validate_json(json_str) + loaded = _Variables.model_validate(rebuild_serialized_graph_files_without_lookup(orjson.loads(json_str))) # Verify all variables are preserved assert len(loaded.variables) == len(all_variables) diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index 37ecd2890b..d4e862220a 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,4 +1,5 @@ import pytest + from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import StringSegment from graphon.variables.types import ArrayValidation, SegmentType diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 09254e17a3..317fe99d37 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import Any import pytest + from graphon.file import File, FileTransferMethod, FileType from graphon.variables.segment_group import SegmentGroup from graphon.variables.segments import ( @@ -34,7 +35,7 @@ def create_test_file( """Factory function to create File objects for testing.""" return File( tenant_id="test-tenant", - type=file_type, + file_type=file_type, transfer_method=transfer_method, filename=filename, extension=extension, diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 75b01bf42e..dae5e1ce98 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,4 +1,6 @@ import pytest +from pydantic import ValidationError + from graphon.variables import ( ArrayFileVariable, ArrayVariable, @@ -10,7 +12,6 @@ from 
graphon.variables import ( StringVariable, ) from graphon.variables.variables import VariableBase -from pydantic import ValidationError def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 41627f5e0b..025d79b25d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,12 +5,13 @@ Shared fixtures for ObservabilityLayer tests. from unittest.mock import MagicMock, patch import pytest -from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider +from graphon.enums import BuiltinNodeTypes + @pytest.fixture def memory_span_exporter(): @@ -61,9 +62,8 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" - from graphon.nodes.tool.entities import ToolNodeData - from core.tools.entities.tool_entities import ToolProviderType + from graphon.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 99d131737e..5d6667257f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -3,17 +3,16 @@ from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom +from core.app.workflow.layers.llm_quota import LLMQuotaLayer +from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.entities.commands import CommandType from graphon.graph_events import NodeRunSucceededEvent from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom -from core.app.workflow.layers.llm_quota import LLMQuotaLayer -from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance - def _build_dify_context() -> DifyRunContext: return DifyRunContext( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 9cf72763ee..919f15efd0 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -13,10 +13,10 @@ Test coverage: from unittest.mock import patch import pytest -from graphon.enums import BuiltinNodeTypes from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer +from graphon.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py 
b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 88989db856..9f3e3b00b9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -1,18 +1,18 @@ -""" -Mock node factory for testing workflows with third-party service dependencies. +"""Mock node factory for third-party-service workflow tests. -This module provides a MockNodeFactory that automatically detects and mocks nodes -requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request). +The factory follows the same config adaptation path as production +`DifyNodeFactory.create_node()`, but swaps selected node classes for mock +implementations before instantiation. """ from typing import TYPE_CHECKING, Any +from core.workflow.human_input_adapter import adapt_node_config_for_graph +from core.workflow.node_factory import DifyNodeFactory from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node -from core.workflow.node_factory import DifyNodeFactory - from .test_mock_nodes import ( MockAgentNode, MockCodeNode, @@ -83,20 +83,20 @@ class MockNodeFactory(DifyNodeFactory): :param node_config: Node configuration dictionary :return: Node instance (real or mocked) """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config)) + node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_type = node_data.type # Check if this node type should be mocked if node_type in self._mock_node_types: - node_id = typed_node_config["id"] - # Create mock node instance mock_class = self._mock_node_types[node_type] + resolved_node_data = self._validate_resolved_node_data(mock_class, node_data) if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -105,8 +105,8 @@ class MockNodeFactory(DifyNodeFactory): ) elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -121,8 +121,8 @@ class MockNodeFactory(DifyNodeFactory): BuiltinNodeTypes.PARAMETER_EXTRACTOR, }: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -131,8 +131,8 @@ class MockNodeFactory(DifyNodeFactory): ) else: mock_instance = mock_class( - id=node_id, - config=typed_node_config, + node_id=node_id, + config=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -141,7 +141,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # For non-mocked node types, use parent implementation - return super().create_node(typed_node_config) + return super().create_node(node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 8b7fbd1b30..f9819c47ec 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,6 +10,10 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock +from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.nodes.agent import AgentNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -27,11 +31,6 @@ from graphon.nodes.template_transform import TemplateTransformNode from graphon.nodes.tool import ToolNode from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError -from core.model_manager import ModelInstance -from core.workflow.node_runtime import DifyToolNodeRuntime -from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode - if TYPE_CHECKING: from graphon.entities import GraphInitParams from graphon.runtime import GraphRuntimeState @@ -56,13 +55,14 @@ class MockNodeMixin: def __init__( self, - id: str, - config: Mapping[str, Any], + node_id: str, + config: Any, + *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", mock_config: Optional["MockConfig"] = None, **kwargs: Any, - ): + ) -> None: if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)): kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) @@ -97,7 +97,7 @@ class MockNodeMixin: kwargs.setdefault("message_transformer", MagicMock()) super().__init__( - id=id, + node_id=node_id, config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index 8311a1e847..75bc6d05f7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,6 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.entities import WorkflowStartReason from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -23,14 +30,6 @@ from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.nodes.start.entities import StartNodeData from graphon.nodes.start.start_node import StartNode from graphon.runtime import GraphRuntimeState, VariablePool - -from core.repositories.human_input_repository 
import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -140,8 +139,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( - id=start_config["id"], - config=start_config, + node_id=start_config["id"], + config=StartNodeData(title="Start", variables=[]), graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -155,8 +154,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_a_config = {"id": "human_a", "data": human_data.model_dump()} human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, + node_id=human_a_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -165,8 +164,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_b_config = {"id": "human_b", "data": human_data.model_dump()} human_b = HumanInputNode( - id=human_b_config["id"], - config=human_b_config, + node_id=human_b_config["id"], + config=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -183,8 +182,8 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor ) end_config = {"id": "end", "data": end_data.model_dump()} end_node = EndNode( - id=end_config["id"], - config=end_config, + node_id=end_config["id"], + config=end_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index b11f957677..7d23b63049 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -19,6 +19,11 @@ from functools import lru_cache from pathlib import Path from typing import Any +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.tools.utils.yaml_utils import _load_yaml_file +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from graphon.entities import GraphInitParams from graphon.graph import Graph from graphon.graph_engine import GraphEngine, GraphEngineConfig @@ -39,12 +44,6 @@ from graphon.variables import ( StringVariable, ) -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from core.workflow.system_variables import build_bootstrap_variables, build_system_variables -from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool - from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index 12aec6edf2..ba1e74f3e0 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -5,6 +5,7 @@ from graphon.graph_events import ( NodeRunStreamChunkEvent, ) +from .test_mock_config import MockConfigBuilder from .test_table_runner import TableTestRunner @@ -44,3 +45,51 @@ def test_tool_in_chatflow(): assert stream_chunk_events[0].chunk == "hello, dify!", ( f"Expected chunk to be 'hello, dify!', but got {stream_chunk_events[0].chunk}" ) + + +def test_answer_can_render_llm_structured_output_in_chatflow(): + runner = TableTestRunner() + + fixture_data = runner.workflow_runner.load_fixture("basic_chatflow") + nodes = fixture_data["workflow"]["graph"]["nodes"] + answer_node = next(node for node in nodes if node["id"] == "answer") + answer_node["data"]["answer"] = "{{#llm.structured_output#}}" + + mock_config = ( + MockConfigBuilder() + .with_node_output( + "llm", + { + "text": "plain text", + "structured_output": {"type": "greeting"}, + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + }, + "finish_reason": "stop", + }, + ) + .build() + ) + + graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( + fixture_data=fixture_data, + query="hello", + use_mock_factory=True, + mock_config=mock_config, + ) + + engine = GraphEngine( + workflow_id="test_workflow", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + ) + + events = list(engine.run()) + success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] + + assert success_events, "Workflow should complete successfully" + assert success_events[-1].outputs["answer"] == '{\n "type": "greeting"\n}' diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py index cbc920705c..1f4509af9a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -1,9 +1,8 @@ from unittest.mock import patch -from graphon.enums import BuiltinNodeTypes - from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer +from graphon.enums import BuiltinNodeTypes def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py index 59dd763b59..c86de7f6e6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -1,9 +1,8 @@ from types import SimpleNamespace from unittest.mock import Mock, patch -from graphon.model_runtime.entities.model_entities import ModelType - from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from graphon.model_runtime.entities.model_entities import ModelType def test_fetch_model_reuses_single_model_assembly(): diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py 
b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 7195471eb6..ae9dae0646 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,15 +2,15 @@ import time import uuid from unittest.mock import MagicMock -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph -from graphon.nodes.answer.answer_node import AnswerNode -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables from extensions.ext_database import db +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.nodes.answer.entities import AnswerNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params @@ -67,20 +67,15 @@ def test_execute_answer(): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - node_config = { - "id": "answer", - "data": { - "title": "123", - "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", - }, - } - node = AnswerNode( - id=str(uuid.uuid4()), + node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_config, + config=AnswerNodeData( + title="123", + type="answer", + answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + ), ) # Mock db.session.close() @@ -91,3 +86,80 @@ def test_execute_answer(): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin." 
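Note on the two tests added below (and the chatflow test earlier in this patch): they pin the rendered answer to an indented-JSON literal. Assuming the answer template serializes object segments with two-space-indented `json.dumps` (an inference; the tests only pin the literal), the expected string is reproduced by:

```python
import json

# Hypothetical reproduction of the pinned literal, not the node's real
# serializer call.
structured_output = {"type": "greeting"}
print(json.dumps(structured_output, indent=2))
# {
#   "type": "greeting"
# }
```

The second new test covers the miss path: when nothing resolves at the selector, the template falls back to emitting the raw selector text `1777539038857.structured_output` instead of failing.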
+ + +def test_execute_answer_renders_structured_output_object_as_json() -> None: + init_params = build_test_graph_init_params( + workflow_id="1", + graph_config={"nodes": [], "edges": []}, + tenant_id="1", + app_id="1", + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + variable_pool = VariablePool( + system_variables=build_system_variables(user_id="aaa", files=[]), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"}) + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + node = AnswerNode( + node_id=str(uuid.uuid4()), + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + config=AnswerNodeData( + title="123", + type="answer", + answer="{{#1777539038857.structured_output#}}", + ), + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["answer"] == '{\n "type": "greeting"\n}' + + +def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None: + init_params = build_test_graph_init_params( + workflow_id="1", + graph_config={"nodes": [], "edges": []}, + tenant_id="1", + app_id="1", + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + variable_pool = VariablePool( + system_variables=build_system_variables(user_id="aaa", files=[]), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + node = AnswerNode( + node_id=str(uuid.uuid4()), + graph_init_params=init_params, + graph_runtime_state=graph_runtime_state, + config=AnswerNodeData( + title="123", + type="answer", + answer="{{#1777539038857.structured_output#}}", + ), + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["answer"] == "1777539038857.structured_output" diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 343bcd3919..ec4cef1955 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,10 +1,10 @@ import pytest + +from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node -from core.workflow.node_factory import get_node_type_classes_mapping - # Ensures that all production node classes are imported and registered. 
_ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index b9371a34f4..ef0df55995 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,6 +1,7 @@ import types from collections.abc import Mapping +from core.workflow.node_factory import get_node_type_classes_mapping from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.base.node import Node @@ -13,8 +14,6 @@ from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) -from core.workflow.node_factory import get_node_type_classes_mapping - def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index d155124c50..ce0c9b79c6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,3 +1,4 @@ +from configs import dify_config from graphon.nodes.code.code_node import CodeNode from graphon.nodes.code.entities import CodeLanguage, CodeNodeData from graphon.nodes.code.exc import ( @@ -8,8 +9,6 @@ from graphon.nodes.code.exc import ( from graphon.nodes.code.limits import CodeNodeLimits from graphon.variables.types import SegmentType -from configs import dify_config - CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index fb03ae9998..d7ef781732 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,8 +1,8 @@ -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from core.workflow.nodes.datasource.entities import DatasourceNodeData +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: @@ -78,19 +78,16 @@ def test_datasource_node_delegates_to_manager_stream(mocker): mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) node = DatasourceNode( - id="n", - config={ - "id": "n", - "data": { - "type": "datasource", - "version": "1", - "title": "Datasource", - "provider_type": "plugin", - "provider_name": "p", - "plugin_id": "plug", - "datasource_name": "ds", - }, - }, + node_id="n", + config=DatasourceNodeData( + type="datasource", + version="1", + title="Datasource", + provider_type="plugin", + provider_name="p", + plugin_id="plug", + datasource_name="ds", + ), graph_init_params=gp, graph_runtime_state=gs, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py 
b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index a5026b40cf..be7cc073db 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,4 +1,8 @@ import pytest + +from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables from graphon.file.file_manager import file_manager from graphon.nodes.http_request import ( BodyData, @@ -12,10 +16,6 @@ from graphon.nodes.http_request.exc import AuthorizationConfigError from graphon.nodes.http_request.executor import Executor from graphon.runtime import VariablePool -from configs import dify_config -from core.helper.ssrf_proxy import ssrf_proxy -from core.workflow.system_variables import default_system_variables - HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 4705b3f76e..2e89a2da3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,17 +3,17 @@ from typing import Any import httpx import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file.file_manager import file_manager -from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_runtime import DifyFileReferenceFactory from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeData, HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -66,8 +66,8 @@ def test_get_default_config_uses_injected_http_request_config(): assert default_config["retry_config"]["max_retries"] == 7 -def test_get_default_config_with_malformed_http_request_config_raises_value_error(): - with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"): +def test_get_default_config_with_malformed_http_request_config_raises_type_error(): + with pytest.raises(TypeError, match="http_request_config must be an HttpRequestNodeConfig instance"): HttpRequestNode.get_default_config(filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"}) @@ -114,8 +114,8 @@ def _build_http_node( start_at=time.perf_counter(), ) return HttpRequestNode( - id="http-node", - config=node_config, + node_id="http-node", + config=HttpRequestNodeData.model_validate(node_data), 
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d16e1233ac..07430498e5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,7 +1,6 @@ +from core.workflow.human_input_adapter import EmailDeliveryConfig, EmailRecipients from graphon.runtime import VariablePool -from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients - def test_render_body_template_replaces_variable_values(): config = EmailDeliveryConfig( diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index a2cdbbf132..0659984c76 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -10,6 +10,28 @@ from typing import Any from unittest.mock import MagicMock import pytest +from pydantic import ValidationError + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + HumanInputFormRepository, +) +from core.workflow.human_input_adapter import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + EmailRecipientType, + ExternalRecipient, + MemberRecipient, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams from graphon.node_events import PauseRequestedEvent from graphon.node_events.node import StreamCompletedEvent @@ -28,28 +50,6 @@ from graphon.nodes.human_input.enums import ( ) from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool -from pydantic import ValidationError - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY -from core.repositories.human_input_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRecipientEntity, - HumanInputFormRepository, -) -from core.workflow.human_input_compat import ( - DeliveryMethodType, - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - EmailRecipientType, - ExternalRecipient, - MemberRecipient, - WebAppDeliveryMethod, - _WebAppDeliveryConfig, -) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now @@ -136,6 +136,26 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): entity.status_value = HumanInputFormStatus.SUBMITTED +def _build_human_input_node( + *, + node_id: str, + node_data: HumanInputNodeData | Mapping[str, Any], + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + runtime: DifyHumanInputNodeRuntime, +) -> HumanInputNode: + typed_node_data = ( + node_data if isinstance(node_data, HumanInputNodeData) else HumanInputNodeData.model_validate(node_data) + ) + return HumanInputNode( + node_id=node_id, + 
config=typed_node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + runtime=runtime, + ) + + class TestDeliveryMethod: """Test DeliveryMethod entity.""" @@ -239,7 +259,7 @@ class TestUserAction: data[field_name] = value with pytest.raises(ValidationError) as exc_info: - UserAction(**data) + UserAction.model_validate(data) errors = exc_info.value.errors() assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors) @@ -465,9 +485,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -530,9 +550,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -595,9 +615,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -671,9 +691,9 @@ class TestHumanInputNodeVariableResolution: runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, @@ -770,9 +790,9 @@ class TestHumanInputNodeRenderedContent: form_repository = InMemoryHumanInputFormRepository() runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] - node = HumanInputNode( - id=config["id"], - config=config, + node = _build_human_input_node( + node_id=config["id"], + node_data=config["data"], graph_init_params=graph_init_params, graph_runtime_state=runtime_state, runtime=runtime, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 52802c7ce1..4a9438b14f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,6 +1,9 @@ import datetime from types import SimpleNamespace +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from 
core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams from graphon.enums import BuiltinNodeTypes from graphon.graph_events import ( @@ -8,13 +11,10 @@ from graphon.graph_events import ( NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) +from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode from graphon.runtime import GraphRuntimeState, VariablePool - -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom -from core.workflow.node_runtime import DifyHumanInputNodeRuntime -from core.workflow.system_variables import default_system_variables from libs.datetime_utils import naive_utc_now @@ -26,6 +26,28 @@ class _FakeFormRepository: return self._form +def _create_human_input_node( + *, + config: dict, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + repo: _FakeFormRepository, +) -> HumanInputNode: + node_data = ( + config["data"] + if isinstance(config["data"], HumanInputNodeData) + else HumanInputNodeData.model_validate(config["data"]) + ) + return HumanInputNode( + node_id=config["id"], + config=node_data, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + ) + + def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( @@ -81,13 +103,11 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) @@ -146,13 +166,11 @@ def _build_timeout_node() -> HumanInputNode: ) repo = _FakeFormRepository(fake_form) - return HumanInputNode( - id="node-1", + return _create_human_input_node( config=config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=repo, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + repo=repo, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index bbfe350f7e..8ffce39cd6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -2,7 +2,10 @@ from collections.abc import Mapping from typing import Any import pytest + +from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams +from graphon.nodes.iteration.entities import IterationNodeData from graphon.nodes.iteration.exc import IterationGraphNotFoundError from graphon.nodes.iteration.iteration_node import IterationNode from graphon.runtime import ( @@ -11,8 +14,6 @@ from graphon.runtime import ( GraphRuntimeState, VariablePool, ) - -from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -44,17 
+45,14 @@ def _build_iteration_node( ) -> IterationNode: init_params = build_test_graph_init_params(graph_config=graph_config) return IterationNode( - id="iteration-node", - config={ - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration-node", "output"], - "start_node_id": start_node_id, - }, - }, + node_id="iteration-node", + config=IterationNodeData( + type="iteration", + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id=start_node_id, + ), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index f8802138b5..f254fc3d09 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -3,9 +3,6 @@ import uuid from unittest.mock import Mock import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -19,6 +16,9 @@ from core.workflow.nodes.knowledge_index.protocols import ( SummaryIndexServiceProtocol, ) from core.workflow.system_variables import SystemVariableKey, build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -93,6 +93,25 @@ def sample_chunks(): } +def _build_node( + *, + node_id: str, + node_data: KnowledgeIndexNodeData | dict[str, object], + graph_init_params, + graph_runtime_state, +) -> KnowledgeIndexNode: + return KnowledgeIndexNode( + node_id=node_id, + config=( + node_data + if isinstance(node_data, KnowledgeIndexNodeData) + else KnowledgeIndexNodeData.model_validate(node_data) + ), + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + class TestKnowledgeIndexNode: """ Test suite for KnowledgeIndexNode. 
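Note on the `_build_node` helper introduced above: it normalizes either a typed `KnowledgeIndexNodeData` or a raw `config["data"]` dict before constructing the node, so the many call sites rewritten in the hunks below can keep passing `node_data=config["data"]`. A self-contained sketch of the same coercion pattern, with an illustrative model rather than the real class:

```python
from pydantic import BaseModel


# Illustrative stand-in; the real helper coerces KnowledgeIndexNodeData.
class NodeData(BaseModel):
    title: str


def coerce(data: NodeData | dict[str, object]) -> NodeData:
    # Accept an already-typed model or a raw config["data"] dict and
    # return a typed model either way.
    return data if isinstance(data, NodeData) else NodeData.model_validate(data)


assert coerce({"title": "Index"}) == coerce(NodeData(title="Index"))
```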
@@ -115,9 +134,9 @@ class TestKnowledgeIndexNode: } # Act - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -143,9 +162,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -176,9 +195,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -212,9 +231,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -269,9 +288,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -332,9 +351,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -383,9 +402,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -440,9 +459,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -498,9 +517,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -536,9 +555,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -583,9 +602,9 @@ class TestKnowledgeIndexNode: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -623,9 +642,9 @@ class TestInvokeKnowledgeIndex: "data": sample_node_data.model_dump(), } - node = KnowledgeIndexNode( - id=node_id, - config=config, + node = _build_node( + 
node_id=node_id, + node_data=config["data"], graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index ab64be59ad..e923ee761b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -3,10 +3,6 @@ import uuid from unittest.mock import Mock import pytest -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.model_runtime.entities.llm_entities import LLMUsage -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -18,9 +14,17 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( SingleRetrievalConfig, ) from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import ( + KnowledgeRetrievalNode, + _normalize_metadata_filter_scalar, + _normalize_metadata_filter_sequence_item, +) from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source from core.workflow.system_variables import build_system_variables +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -85,6 +89,12 @@ def sample_node_data(): ) +def test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values() -> None: + assert _normalize_metadata_filter_scalar(3) == 3 + assert _normalize_metadata_filter_scalar(True) == "True" + assert _normalize_metadata_filter_sequence_item(4) == "4" + + class TestKnowledgeRetrievalNode: """ Test suite for KnowledgeRetrievalNode. 
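Note on `test_metadata_filter_normalizers_preserve_numeric_scalars_and_stringify_other_values` above: it pins three behaviors — numeric scalars pass through unchanged, booleans are stringified, and sequence items are always stringified. One implementation consistent with those assertions (a hedged sketch; the real private helpers live in `knowledge_retrieval_node.py` and are not shown in this patch):

```python
def _normalize_metadata_filter_scalar(value: object) -> object:
    # bool is a subclass of int in Python, so it must be checked first,
    # or True/False would slip through as numeric scalars.
    if isinstance(value, bool):
        return str(value)
    if isinstance(value, (int, float)):
        return value
    return str(value)


def _normalize_metadata_filter_sequence_item(value: object) -> str:
    # Items inside filter sequences are always compared as strings.
    return str(value)


assert _normalize_metadata_filter_scalar(3) == 3
assert _normalize_metadata_filter_scalar(True) == "True"
assert _normalize_metadata_filter_sequence_item(4) == "4"
```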
@@ -106,8 +116,8 @@ class TestKnowledgeRetrievalNode:

         # Act
         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -135,8 +145,8 @@ class TestKnowledgeRetrievalNode:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -194,8 +204,8 @@ class TestKnowledgeRetrievalNode:
         mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -238,8 +248,8 @@ class TestKnowledgeRetrievalNode:
         mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -274,8 +284,8 @@ class TestKnowledgeRetrievalNode:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -309,8 +319,8 @@ class TestKnowledgeRetrievalNode:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -350,8 +360,8 @@ class TestKnowledgeRetrievalNode:
         mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -389,8 +399,8 @@ class TestKnowledgeRetrievalNode:
         mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -470,8 +480,8 @@ class TestFetchDatasetRetriever:
         config = {"id": node_id, "data": node_data.model_dump()}

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -507,8 +517,8 @@ class TestFetchDatasetRetriever:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -562,8 +572,8 @@ class TestFetchDatasetRetriever:
         }

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -610,8 +620,8 @@ class TestFetchDatasetRetriever:
         mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme"))

         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -671,8 +681,8 @@ class TestFetchDatasetRetriever:
         node_id = str(uuid.uuid4())
         config = {"id": node_id, "data": node_data.model_dump()}
         node = KnowledgeRetrievalNode(
-            id=node_id,
-            config=config,
+            node_id=node_id,
+            config=KnowledgeRetrievalNodeData.model_validate(config["data"]),
             graph_init_params=mock_graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
index fdf1706765..388654f279 100644
--- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py
@@ -1,23 +1,42 @@
+from types import SimpleNamespace
 from unittest.mock import MagicMock

 import pytest
+
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
 from graphon.entities import GraphInitParams
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
+from graphon.nodes.list_operator.entities import ListOperatorNodeData
 from graphon.nodes.list_operator.node import ListOperatorNode
 from graphon.runtime import GraphRuntimeState
 from graphon.variables import ArrayNumberSegment, ArrayStringSegment

-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY
-

 class TestListOperatorNode:
     """Comprehensive tests for ListOperatorNode."""

+    @staticmethod
+    def _build_node(*, config, graph_init_params, graph_runtime_state):
+        return ListOperatorNode(
+            node_id="test",
+            config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config),
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+
+    @staticmethod
+    def _filter_by(comparison_operator: str, value: str) -> dict[str, object]:
+        return {
+            "enabled": True,
+            "conditions": [{"comparison_operator": comparison_operator, "value": value}],
+        }
+
     @pytest.fixture
     def mock_graph_runtime_state(self):
         """Create mock GraphRuntimeState."""
         mock_state = MagicMock(spec=GraphRuntimeState)
         mock_variable_pool = MagicMock()
+        mock_variable_pool.convert_template.side_effect = lambda value: SimpleNamespace(text=value)
         mock_state.variable_pool = mock_variable_pool
         return mock_state
@@ -45,9 +64,8 @@ class TestListOperatorNode:

         def _create_node(config, mock_variable):
             mock_graph_runtime_state.variable_pool.get.return_value = mock_variable
-            return ListOperatorNode(
-                id="test",
-                config={"id": "test", "data": config},
+            return self._build_node(
+                config=config,
                 graph_init_params=graph_init_params,
                 graph_runtime_state=mock_graph_runtime_state,
             )
@@ -64,9 +82,8 @@ class TestListOperatorNode:
             "limit": {"enabled": False},
         }

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -109,9 +126,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=[])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -128,11 +144,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "items"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "contains",
-                "value": "app",
-            },
+            "filter_by": self._filter_by("contains", "app"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -140,9 +152,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -157,11 +168,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "items"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "not contains",
-                "value": "app",
-            },
+            "filter_by": self._filter_by("not contains", "app"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -169,9 +176,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "cherry"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -186,11 +192,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "numbers"],
-            "filter_by": {
-                "enabled": True,
-                "condition": ">",
-                "value": "5",
-            },
+            "filter_by": self._filter_by(">", "5"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -198,9 +200,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9, 11])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -226,9 +227,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -254,9 +254,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["cherry", "apple", "banana"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -282,9 +281,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "banana", "cherry", "date"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -299,11 +297,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "numbers"],
-            "filter_by": {
-                "enabled": True,
-                "condition": ">",
-                "value": "3",
-            },
+            "filter_by": self._filter_by(">", "3"),
             "order_by": {
                 "enabled": True,
                 "value": "desc",
@@ -317,9 +311,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[1, 2, 3, 4, 5, 6, 7, 8, 9])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -341,9 +334,8 @@ class TestListOperatorNode:

         mock_graph_runtime_state.variable_pool.get.return_value = None

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -366,9 +358,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["first", "middle", "last"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -384,11 +375,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "items"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "start with",
-                "value": "app",
-            },
+            "filter_by": self._filter_by("start with", "app"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -396,9 +383,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "application", "banana", "apricot"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -413,11 +399,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "items"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "end with",
-                "value": "le",
-            },
+            "filter_by": self._filter_by("end with", "le"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -425,9 +407,8 @@ class TestListOperatorNode:
         mock_var = ArrayStringSegment(value=["apple", "banana", "pineapple", "table"])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -442,11 +423,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "numbers"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "=",
-                "value": "5",
-            },
+            "filter_by": self._filter_by("=", "5"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -454,9 +431,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[1, 3, 5, 5, 7, 9])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -471,11 +447,7 @@ class TestListOperatorNode:
         config = {
             "title": "Test",
             "variable": ["sys", "numbers"],
-            "filter_by": {
-                "enabled": True,
-                "condition": "≠",
-                "value": "5",
-            },
+            "filter_by": self._filter_by("≠", "5"),
             "order_by": {"enabled": False},
             "limit": {"enabled": False},
         }
@@ -483,9 +455,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[1, 3, 5, 7, 9])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
@@ -511,9 +482,8 @@ class TestListOperatorNode:
         mock_var = ArrayNumberSegment(value=[9, 3, 7, 1, 5])
         mock_graph_runtime_state.variable_pool.get.return_value = mock_var

-        node = ListOperatorNode(
-            id="test",
-            config={"id": "test", "data": config},
+        node = self._build_node(
+            config=config,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
         )
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
index c784f805c0..212ad07bd3 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
@@ -1,6 +1,8 @@
 from unittest import mock

 import pytest
+
+from core.model_manager import ModelInstance
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.model_runtime.entities import (
     ImagePromptMessageContent,
@@ -33,8 +35,6 @@ from graphon.nodes.llm.exc import (
 from graphon.runtime import VariablePool
 from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment

-from core.model_manager import ModelInstance
-

 def _build_model_schema(
     *,
@@ -71,8 +71,8 @@ def _build_image_file(
     mime_type: str = "image/png",
 ) -> File:
     return File(
-        id=file_id,
-        type=FileType.IMAGE,
+        file_id=file_id,
+        file_type=FileType.IMAGE,
         filename=f"{file_id}{extension}",
         transfer_method=FileTransferMethod.REMOTE_URL,
         remote_url=remote_url,
@@ -95,6 +95,8 @@ def variable_pool() -> VariablePool:
 def _fetch_prompt_messages_with_mocked_content(content):
     variable_pool = VariablePool.empty()
     model_instance = mock.MagicMock(spec=ModelInstance)
+    model_schema = mock.MagicMock()
+    model_schema.supports_prompt_content_type.side_effect = lambda content_type: content_type == "text"
     prompt_template = [
         LLMNodeChatModelMessage(
             text="You are a classifier.",
@@ -106,7 +108,7 @@ def _fetch_prompt_messages_with_mocked_content(content):
     with (
         mock.patch(
             "graphon.nodes.llm.llm_utils.fetch_model_schema",
-            return_value=mock.MagicMock(features=[]),
+            return_value=model_schema,
         ),
         mock.patch(
             "graphon.nodes.llm.llm_utils.handle_list_messages",
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index 7841bf05ad..c707cf28cd 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -5,6 +5,19 @@ from collections.abc import Sequence
 from unittest import mock

 import pytest
+
+from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom
+from core.app.llm.model_access import (
+    DifyCredentialsProvider,
+    DifyModelFactory,
+    build_dify_model_access,
+    fetch_model_config,
+)
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
+from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.workflow.system_variables import default_system_variables
 from graphon.entities import GraphInitParams
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.model_runtime.entities.common_entities import I18nObject
@@ -67,19 +80,6 @@ from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol
 from graphon.runtime import GraphRuntimeState, VariablePool
 from graphon.template_rendering import TemplateRenderError
 from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
-
-from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom
-from core.app.llm.model_access import (
-    DifyCredentialsProvider,
-    DifyModelFactory,
-    build_dify_model_access,
-    fetch_model_config,
-)
-from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
-from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
-from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
-from core.workflow.system_variables import default_system_variables
 from models.provider import ProviderType
 from tests.workflow_test_utils import build_test_graph_init_params
@@ -140,8 +140,8 @@ def _build_image_file(
     mime_type: str = "image/png",
 ) -> File:
     return File(
-        id=file_id,
-        type=FileType.IMAGE,
+        file_id=file_id,
+        file_type=FileType.IMAGE,
         filename=f"{file_id}{extension}",
         transfer_method=FileTransferMethod.REMOTE_URL,
         remote_url=remote_url,
@@ -205,14 +205,10 @@ def llm_node(
     mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
     mock_model_factory = mock.MagicMock(spec=ModelFactory)
     mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
-    node_config = {
-        "id": "1",
-        "data": llm_node_data.model_dump(),
-    }
     http_client = mock.MagicMock()
     node = LLMNode(
-        id="1",
-        config=node_config,
+        node_id="1",
+        config=llm_node_data,
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
         credentials_provider=mock_credentials_provider,
@@ -403,8 +399,8 @@ def test_dify_model_access_adapters_call_managers():

 def test_fetch_files_with_file_segment():
     file = File(
-        id="1",
-        type=FileType.IMAGE,
+        file_id="1",
+        file_type=FileType.IMAGE,
         filename="test.jpg",
         transfer_method=FileTransferMethod.LOCAL_FILE,
         related_id="1",
@@ -420,16 +416,16 @@ def test_fetch_files_with_array_file_segment():
     files = [
         File(
-            id="1",
-            type=FileType.IMAGE,
+            file_id="1",
+            file_type=FileType.IMAGE,
             filename="test1.jpg",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="1",
             storage_key="",
         ),
         File(
-            id="2",
-            type=FileType.IMAGE,
+            file_id="2",
+            file_type=FileType.IMAGE,
             filename="test2.jpg",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="2",
             storage_key="",
@@ -1174,14 +1170,10 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
     mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
     mock_model_factory = mock.MagicMock(spec=ModelFactory)
     mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol)
-    node_config = {
-        "id": "1",
-        "data": llm_node_data.model_dump(),
-    }
     http_client = mock.MagicMock()
     node = LLMNode(
-        id="1",
-        config=node_config,
+        node_id="1",
+        config=llm_node_data,
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
         credentials_provider=mock_credentials_provider,
@@ -1203,8 +1195,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
             mime_type="image/png",
         )
         mock_file = File(
-            id=str(uuid.uuid4()),
-            type=FileType.IMAGE,
+            file_id=str(uuid.uuid4()),
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.TOOL_FILE,
             related_id=str(uuid.uuid4()),
             filename="test-file.png",
@@ -1233,8 +1225,8 @@ class TestLLMNodeSaveMultiModalImageOutput:
             mime_type="image/jpg",
         )
         mock_file = File(
-            id=str(uuid.uuid4()),
-            type=FileType.IMAGE,
+            file_id=str(uuid.uuid4()),
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.TOOL_FILE,
             related_id=str(uuid.uuid4()),
             filename="test-file.png",
@@ -1291,8 +1283,8 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
         image_b64_data = base64.b64encode(image_raw_data).decode()

         mock_saved_file = File(
-            id=str(uuid.uuid4()),
-            type=FileType.IMAGE,
+            file_id=str(uuid.uuid4()),
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.TOOL_FILE,
             filename="test.png",
             extension=".png",
@@ -1457,7 +1449,6 @@ def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enable
             file_saver=file_saver,
             file_outputs=[],
             node_id="node-1",
-            node_type=LLMNode.node_type,
             reasoning_format="separated",
         )
     )
@@ -1514,7 +1505,6 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out
             file_saver=mock.MagicMock(spec=LLMFileSaver),
             file_outputs=[],
             node_id="node-1",
-            node_type=LLMNode.node_type,
             model_instance=_build_prepared_llm_mock(),
             reasoning_format="separated",
             request_start_time=1.0,
@@ -1552,7 +1542,6 @@ def test_handle_invoke_result_wraps_structured_output_parse_errors():
             file_saver=mock.MagicMock(spec=LLMFileSaver),
             file_outputs=[],
             node_id="node-1",
-            node_type=LLMNode.node_type,
             model_instance=model_instance,
         )
     )
diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
index 1c362a0a03..8f8ec49f14 100644
--- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py
@@ -6,6 +6,8 @@ from dataclasses import dataclass
 from typing import Any

 import pytest
+
+from factories.variable_factory import build_segment_with_type
 from graphon.model_runtime.entities import LLMMode
 from graphon.nodes.llm import ModelConfig, VisionConfig
 from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData
@@ -18,8 +20,6 @@ from graphon.nodes.parameter_extractor.exc import (
 from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from graphon.variables.types import SegmentType

-from factories.variable_factory import build_segment_with_type
-

 @dataclass
 class ValidTestCase:
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
index d86e0efe02..892f6cc586 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py
@@ -1,6 +1,8 @@
 from unittest.mock import MagicMock

 import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus
 from graphon.graph import Graph
 from graphon.nodes.base.entities import VariableSelector
@@ -8,11 +10,31 @@ from graphon.nodes.template_transform.entities import TemplateTransformNodeData
 from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode
 from graphon.runtime import GraphRuntimeState
 from graphon.template_rendering import TemplateRenderError
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from tests.workflow_test_utils import build_test_graph_init_params


+def _build_template_transform_node(
+    *,
+    node_data,
+    graph_init_params,
+    graph_runtime_state,
+    node_id: str = "test_node",
+    **kwargs,
+) -> TemplateTransformNode:
+    typed_node_data = (
+        node_data
+        if isinstance(node_data, TemplateTransformNodeData)
+        else TemplateTransformNodeData.model_validate(node_data)
+    )
+    return TemplateTransformNode(
+        node_id=node_id,
+        config=typed_node_data,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
+        **kwargs,
+    )
+
+
 class TestTemplateTransformNode:
     """Comprehensive test suite for TemplateTransformNode."""
@@ -59,9 +81,8 @@ class TestTemplateTransformNode:
     def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
         """Test that TemplateTransformNode initializes correctly."""
         mock_renderer = MagicMock()
-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -75,9 +96,8 @@ class TestTemplateTransformNode:
     def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
         """Test _get_title method."""
         mock_renderer = MagicMock()
-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -88,9 +108,8 @@ class TestTemplateTransformNode:
     def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params):
         """Test _get_description method."""
         mock_renderer = MagicMock()
-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -108,9 +127,8 @@ class TestTemplateTransformNode:
         }

         mock_renderer = MagicMock()
-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": node_data},
+        node = _build_template_transform_node(
+            node_data=node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -143,9 +161,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()

         with pytest.raises(ValueError, match="max_output_length must be a positive integer"):
-            TemplateTransformNode(
-                id="test_node",
-                config={"id": "test_node", "data": basic_node_data},
+            _build_template_transform_node(
+                node_data=basic_node_data,
                 graph_init_params=graph_init_params,
                 graph_runtime_state=mock_graph_runtime_state,
                 jinja2_template_renderer=mock_renderer,
@@ -170,9 +187,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!"

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -198,9 +214,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "Value: "

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": node_data},
+        node = _build_template_transform_node(
+            node_data=node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -218,9 +233,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error")

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -238,9 +252,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit"

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -260,9 +273,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "1234567890"

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": basic_node_data},
+        node = _build_template_transform_node(
+            node_data=basic_node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -302,9 +314,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)"

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": node_data},
+        node = _build_template_transform_node(
+            node_data=node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -375,8 +386,8 @@ class TestTemplateTransformNode:
         )

         assert mapping == {
-            "node_123.var1": ["sys", "input1"],
-            "node_123.empty_selector": [],
+            "node_123.var1": ("sys", "input1"),
+            "node_123.empty_selector": (),
         }

     def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self):
@@ -409,9 +420,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "This is a static message."

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": node_data},
+        node = _build_template_transform_node(
+            node_data=node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -448,9 +458,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "Total: $31.5"

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": node_data},
+        node = _build_template_transform_node(
+            node_data=node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -477,9 +486,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com"

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": node_data},
+        node = _build_template_transform_node(
+            node_data=node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
@@ -507,9 +515,8 @@ class TestTemplateTransformNode:
         mock_renderer = MagicMock()
         mock_renderer.render_template.return_value = "Tags: #python #ai #workflow "

-        node = TemplateTransformNode(
-            id="test_node",
-            config={"id": "test_node", "data": node_data},
+        node = _build_template_transform_node(
+            node_data=node_data,
             graph_init_params=graph_init_params,
             graph_runtime_state=mock_graph_runtime_state,
             jinja2_template_renderer=mock_renderer,
diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
index bd22a8e318..a846efbb43 100644
--- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py
@@ -1,14 +1,15 @@
 from unittest.mock import MagicMock

 import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from graphon.nodes.base.entities import VariableSelector
+from graphon.nodes.template_transform.entities import TemplateTransformNodeData
 from graphon.nodes.template_transform.template_transform_node import (
     DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH,
     TemplateTransformNode,
 )
 from graphon.runtime import GraphRuntimeState
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from tests.workflow_test_utils import build_test_graph_init_params

 from .template_transform_node_spec import TestTemplateTransformNode  # noqa: F401
@@ -37,15 +38,13 @@ def mock_graph_runtime_state():

 def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state):
     node = TemplateTransformNode(
-        id="test_node",
-        config={
-            "id": "test_node",
-            "data": {
-                "title": "Template Transform",
-                "variables": [],
-                "template": "hello",
-            },
-        },
+        node_id="test_node",
+        config=TemplateTransformNodeData(
+            title="Template Transform",
+            type="template-transform",
+            variables=[],
+            template="hello",
+        ),
         graph_init_params=graph_init_params,
         graph_runtime_state=mock_graph_runtime_state,
         jinja2_template_renderer=MagicMock(),
@@ -70,5 +69,5 @@ def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entri

     assert mapping == {
         "node_123.validated": ["sys", "input1"],
-        "node_123.raw": ["sys", "input2"],
+        "node_123.raw": ("sys", "input2"),
     }
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
index e11ebf6eb8..364408ead6 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
@@ -1,16 +1,15 @@
 from collections.abc import Mapping

 import pytest
-from graphon.entities import GraphInitParams
-from graphon.entities.base_node_data import BaseNodeData
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
-from graphon.enums import BuiltinNodeTypes
-from graphon.nodes.base.node import Node
-from graphon.runtime import GraphRuntimeState, VariablePool

 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from core.workflow.node_runtime import resolve_dify_run_context
 from core.workflow.system_variables import build_system_variables
+from graphon.entities import GraphInitParams
+from graphon.entities.base_node_data import BaseNodeData
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.base.node import Node
+from graphon.runtime import GraphRuntimeState, VariablePool
 from tests.workflow_test_utils import build_test_graph_init_params
@@ -42,17 +41,19 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
     return init_params, runtime_state


-def _build_node_config() -> NodeConfigDict:
-    return NodeConfigDictAdapter.validate_python(
-        {
-            "id": "node-1",
-            "data": {
-                "type": BuiltinNodeTypes.ANSWER,
-                "title": "Sample",
-                "foo": "bar",
-            },
-        }
-    )
+def _build_node_config() -> dict[str, object]:
+    return {
+        "id": "node-1",
+        "data": _SampleNodeData(
+            type=BuiltinNodeTypes.ANSWER,
+            title="Sample",
+            foo="bar",
+        ),
+    }
+
+
+def _build_node_data() -> _SampleNodeData:
+    return _build_node_config()["data"]  # type: ignore[return-value]


 def test_node_hydrates_data_during_initialization():
@@ -60,8 +61,8 @@ def test_node_hydrates_data_during_initialization():
     init_params, runtime_state = _build_context(graph_config)

     node = _SampleNode(
-        id="node-1",
-        config=_build_node_config(),
+        node_id="node-1",
+        config=_build_node_data(),
         graph_init_params=init_params,
         graph_runtime_state=runtime_state,
     )
@@ -86,8 +87,8 @@ def test_node_accepts_invoke_from_enum():
     )

     node = _SampleNode(
-        id="node-1",
-        config=_build_node_config(),
+        node_id="node-1",
+        config=_build_node_data(),
         graph_init_params=init_params,
         graph_runtime_state=runtime_state,
     )
@@ -117,13 +118,7 @@ def test_missing_generic_argument_raises_type_error():


 def test_base_node_data_keeps_dict_style_access_compatibility():
-    node_data = _SampleNodeData.model_validate(
-        {
-            "type": BuiltinNodeTypes.ANSWER,
-            "title": "Sample",
-            "foo": "bar",
-        }
-    )
+    node_data = _SampleNodeData(type=BuiltinNodeTypes.ANSWER, title="Sample", foo="bar")

     assert node_data["foo"] == "bar"
     assert node_data.get("foo") == "bar"
@@ -133,21 +128,19 @@ def test_node_hydration_preserves_compatibility_extra_fields():
     graph_config: dict[str, object] = {}
     init_params, runtime_state = _build_context(graph_config)
-    node_config = NodeConfigDictAdapter.validate_python(
-        {
-            "id": "node-1",
-            "data": {
-                "type": BuiltinNodeTypes.ANSWER,
-                "title": "Sample",
-                "foo": "bar",
-                "compat_flag": True,
-            },
-        }
-    )
+    node_config = {
+        "id": "node-1",
+        "data": _SampleNodeData(
+            type=BuiltinNodeTypes.ANSWER,
+            title="Sample",
+            foo="bar",
+            compat_flag=True,
+        ),
+    }

     node = _SampleNode(
-        id="node-1",
-        config=node_config,
+        node_id="node-1",
+        config=node_config["data"],
         graph_init_params=init_params,
         graph_runtime_state=runtime_state,
     )
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
index 555ff0c945..dd75b32593 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py
@@ -4,23 +4,25 @@ from unittest.mock import Mock, patch
 import pandas as pd
 import pytest
 from docx.oxml.text.paragraph import CT_P
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from graphon.entities import GraphInitParams
 from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
 from graphon.file import File, FileTransferMethod
 from graphon.node_events import NodeRunResult
 from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData
+from graphon.nodes.document_extractor.exc import TextExtractionError, UnsupportedFileTypeError
 from graphon.nodes.document_extractor.node import (
     _extract_text_from_docx,
     _extract_text_from_excel,
+    _extract_text_from_file,
     _extract_text_from_pdf,
     _extract_text_from_plain_text,
     _normalize_docx_zip,
 )
-from graphon.variables import ArrayFileSegment
+from graphon.variables import ArrayFileSegment, FileSegment
 from graphon.variables.segments import ArrayStringSegment
 from graphon.variables.variables import StringVariable
-
-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
 from tests.workflow_test_utils import build_test_graph_init_params
@@ -44,11 +46,10 @@ def document_extractor_node(graph_init_params):
         title="Test Document Extractor",
         variable_selector=["node_id", "variable_name"],
     )
-    node_config = {"id": "test_node_id", "data": node_data.model_dump()}
     http_client = Mock()
     node = DocumentExtractorNode(
-        id="test_node_id",
-        config=node_config,
+        node_id="test_node_id",
+        config=node_data,
         graph_init_params=graph_init_params,
         graph_runtime_state=Mock(),
         http_client=http_client,
@@ -341,7 +342,7 @@ def test_extract_text_from_excel_sheet_parse_error(mock_excel_file):
     # Mock ExcelFile
     mock_excel_instance = Mock()
     mock_excel_instance.sheet_names = ["GoodSheet", "BadSheet"]
-    mock_excel_instance.parse.side_effect = [df, Exception("Parse error")]
+    mock_excel_instance.parse.side_effect = [df, TypeError("Parse error")]
     mock_excel_file.return_value = mock_excel_instance

     file_content = b"fake_excel_mixed_content"
@@ -386,7 +387,7 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
     # Mock ExcelFile
     mock_excel_instance = Mock()
     mock_excel_instance.sheet_names = ["BadSheet1", "BadSheet2"]
-    mock_excel_instance.parse.side_effect = [Exception("Error 1"), Exception("Error 2")]
+    mock_excel_instance.parse.side_effect = [TypeError("Error 1"), TypeError("Error 2")]
     mock_excel_file.return_value = mock_excel_instance

     file_content = b"fake_excel_all_bad_sheets"
@@ -397,6 +398,12 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file):
     assert mock_excel_instance.parse.call_count == 2


+@patch("pandas.ExcelFile", side_effect=RuntimeError("broken workbook"))
+def test_extract_text_from_excel_wraps_workbook_open_errors(mock_excel_file):
+    with pytest.raises(TextExtractionError, match="Failed to extract text from Excel file: broken workbook"):
+        _extract_text_from_excel(b"broken")
+
+
 @patch("pandas.ExcelFile")
 def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
     """Test extracting text from Excel file with numeric column names."""
@@ -420,6 +427,103 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file):
     assert expected_manual == result


+@pytest.mark.parametrize(
+    ("extension", "mime_type"),
+    [
+        (".xlsx", "text/plain"),
+        (None, "application/vnd.ms-excel"),
+    ],
+)
+def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, extension, mime_type):
+    file = Mock(spec=File)
+    file.extension = extension
+    file.mime_type = mime_type
+
+    with (
+        patch(
+            "graphon.nodes.document_extractor.node._download_file_content",
+            return_value=b"excel",
+        ),
+        patch(
+            "graphon.nodes.document_extractor.node._extract_text_from_excel",
+            return_value="excel text",
+        ) as mock_extract,
+    ):
+        result = _extract_text_from_file(
+            document_extractor_node.http_client,
+            file,
+            unstructured_api_config=document_extractor_node._unstructured_api_config,
+        )
+
+    assert result == "excel text"
+    mock_extract.assert_called_once_with(b"excel")
+
+
+def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node):
+    file = Mock(spec=File)
+    file.extension = None
+    file.mime_type = None
+
+    with patch(
+        "graphon.nodes.document_extractor.node._download_file_content",
+        return_value=b"unknown",
+    ):
+        with pytest.raises(UnsupportedFileTypeError, match="Unable to determine file type"):
+            _extract_text_from_file(
+                document_extractor_node.http_client,
+                file,
+                unstructured_api_config=document_extractor_node._unstructured_api_config,
+            )
+
+
+def test_run_list_file_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+    document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+    file_list = Mock(spec=ArrayFileSegment)
+    file_list.value = [Mock(spec=File)]
+    mock_graph_runtime_state.variable_pool.get.return_value = file_list
+
+    with patch(
+        "graphon.nodes.document_extractor.node._extract_text_from_file",
+        side_effect=TextExtractionError("bad file"),
+    ):
+        result = document_extractor_node._run()
+
+    assert result.status == WorkflowNodeExecutionStatus.FAILED
+    assert result.error == "bad file"
+
+
+def test_run_single_file_segment_extraction_error_returns_failed(document_extractor_node, mock_graph_runtime_state):
+    document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+    file_segment = Mock(spec=FileSegment)
+    file_segment.value = Mock(spec=File)
+    mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+    with patch(
+        "graphon.nodes.document_extractor.node._extract_text_from_file",
+        side_effect=TextExtractionError("single file failed"),
+    ):
+        result = document_extractor_node._run()
+
+    assert result.status == WorkflowNodeExecutionStatus.FAILED
+    assert result.error == "single file failed"
+
+
+def test_run_single_file_segment_returns_string_output(document_extractor_node, mock_graph_runtime_state):
+    document_extractor_node.graph_runtime_state = mock_graph_runtime_state
+    file_segment = Mock(spec=FileSegment)
+    file_segment.value = Mock(spec=File)
+    mock_graph_runtime_state.variable_pool.get.return_value = file_segment
+
+    with patch(
+        "graphon.nodes.document_extractor.node._extract_text_from_file",
+        return_value="single file text",
+    ):
+        result = document_extractor_node._run()
+
+    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+    assert result.outputs == {"text": "single file text"}
+
+
 def _make_docx_zip(use_backslash: bool) -> bytes:
     """Helper to build a minimal in-memory DOCX zip.
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index 1b14f0ab13..aa9a1360b0 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -3,6 +3,11 @@ import uuid
 from unittest.mock import MagicMock, Mock

 import pytest
+
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
+from core.workflow.node_factory import DifyNodeFactory
+from core.workflow.system_variables import build_system_variables
+from extensions.ext_database import db
 from graphon.enums import WorkflowNodeExecutionStatus
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.graph import Graph
@@ -11,14 +16,23 @@ from graphon.nodes.if_else.if_else_node import IfElseNode
 from graphon.runtime import GraphRuntimeState, VariablePool
 from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition
 from graphon.variables import ArrayFileSegment
-
-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
-from core.workflow.node_factory import DifyNodeFactory
-from core.workflow.system_variables import build_system_variables
-from extensions.ext_database import db
 from tests.workflow_test_utils import build_test_graph_init_params


+def _build_if_else_node(
+    *,
+    node_data: IfElseNodeData | dict[str, object],
+    init_params,
+    graph_runtime_state,
+) -> IfElseNode:
+    return IfElseNode(
+        node_id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph_runtime_state=graph_runtime_state,
+        config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data),
+    )
+
+
 def test_execute_if_else_result_true():
     graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]}
@@ -61,9 +75,8 @@ def test_execute_if_else_result_true():
     )
     graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")

-    node_config = {
-        "id": "if-else",
-        "data": {
+    node = _build_if_else_node(
+        node_data={
             "title": "123",
             "type": "if-else",
             "logical_operator": "and",
@@ -104,13 +117,8 @@ def test_execute_if_else_result_true():
                 {"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
             ],
         },
-    }
-
-    node = IfElseNode(
-        id=str(uuid.uuid4()),
-        graph_init_params=init_params,
+        init_params=init_params,
         graph_runtime_state=graph_runtime_state,
-        config=node_config,
     )

     # Mock db.session.close()
@@ -155,9 +163,8 @@ def test_execute_if_else_result_false():
     )
     graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")

-    node_config = {
-        "id": "if-else",
-        "data": {
+    node = _build_if_else_node(
+        node_data={
             "title": "123",
             "type": "if-else",
             "logical_operator": "or",
@@ -174,13 +181,8 @@ def test_execute_if_else_result_false():
                 },
             ],
         },
-    }
-
-    node = IfElseNode(
-        id=str(uuid.uuid4()),
-        graph_init_params=init_params,
+        init_params=init_params,
         graph_runtime_state=graph_runtime_state,
-        config=node_config,
     )

     # Mock db.session.close()
@@ -222,11 +224,6 @@ def test_array_file_contains_file_name():
         ],
     )

-    node_config = {
-        "id": "if-else",
-        "data": node_data.model_dump(),
-    }
-
     # Create properly configured mock for graph_init_params
     graph_init_params = Mock()
     graph_init_params.workflow_id = "test_workflow"
@@ -242,17 +239,12 @@ def test_array_file_contains_file_name():
         }
     }

-    node = IfElseNode(
-        id=str(uuid.uuid4()),
-        graph_init_params=graph_init_params,
-        graph_runtime_state=Mock(),
-        config=node_config,
-    )
+    node = _build_if_else_node(node_data=node_data, init_params=graph_init_params, graph_runtime_state=Mock())

     node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
         value=[
             File(
-                type=FileType.IMAGE,
+                file_type=FileType.IMAGE,
                 transfer_method=FileTransferMethod.LOCAL_FILE,
                 related_id="1",
                 filename="ab",
@@ -334,11 +326,10 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
         "logical_operator": "and",
         "conditions": [condition.model_dump()],
     }
-    node = IfElseNode(
-        id=str(uuid.uuid4()),
-        graph_init_params=init_params,
+    node = _build_if_else_node(
+        node_data=node_data,
+        init_params=init_params,
         graph_runtime_state=graph_runtime_state,
-        config={"id": "if-else", "data": node_data},
     )

     # Mock db.session.close()
@@ -400,14 +391,10 @@ def test_execute_if_else_boolean_false_conditions():
         ],
     }

-    node = IfElseNode(
-        id=str(uuid.uuid4()),
-        graph_init_params=init_params,
+    node = _build_if_else_node(
+        node_data=node_data,
+        init_params=init_params,
         graph_runtime_state=graph_runtime_state,
-        config={
-            "id": "if-else",
-            "data": node_data,
-        },
     )

     # Mock db.session.close()
@@ -472,11 +459,10 @@ def test_execute_if_else_boolean_cases_structure():
             }
         ],
     }
-    node = IfElseNode(
-        id=str(uuid.uuid4()),
-        graph_init_params=init_params,
+    node = _build_if_else_node(
+        node_data=node_data,
+        init_params=init_params,
         graph_runtime_state=graph_runtime_state,
-        config={"id": "if-else", "data": node_data},
     )

     # Mock db.session.close()
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index d28c3e01e5..465a4c0ff4 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -1,6 +1,8 @@
 from unittest.mock import MagicMock

 import pytest
+
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
 from graphon.enums import WorkflowNodeExecutionStatus
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.nodes.list_operator.entities import (
@@ -16,7 +18,14 @@ from graphon.nodes.list_operator.exc import InvalidKeyError
 from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
 from graphon.variables import ArrayFileSegment

-from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
+
+def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode:
+    return ListOperatorNode(
+        node_id="test_node_id",
+        config=node_data,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=MagicMock(),
+    )


 @pytest.fixture
@@ -35,10 +44,6 @@ def list_operator_node():
         "title": "Test Title",
     }
     node_data = ListOperatorNodeData.model_validate(config)
-    node_config = {
-        "id": "test_node_id",
-        "data": node_data.model_dump(),
-    }
     # Create properly configured mock for graph_init_params
     graph_init_params = MagicMock()
     graph_init_params.workflow_id = "test_workflow"
@@ -54,12 +59,7 @@ def list_operator_node():
         }
     }

-    node = ListOperatorNode(
-        id="test_node_id",
-        config=node_config,
-        graph_init_params=graph_init_params,
-        graph_runtime_state=MagicMock(),
-    )
+    node = _build_list_operator_node(node_data, graph_init_params)
     node.graph_runtime_state = MagicMock()
     node.graph_runtime_state.variable_pool = MagicMock()
     return node
@@ -70,28 +70,28 @@ def test_filter_files_by_type(list_operator_node):
     files = [
         File(
             filename="image1.jpg",
-            type=FileType.IMAGE,
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related1",
             storage_key="",
         ),
         File(
             filename="document1.pdf",
-            type=FileType.DOCUMENT,
+            file_type=FileType.DOCUMENT,
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related2",
             storage_key="",
         ),
         File(
             filename="image2.png",
-            type=FileType.IMAGE,
+            file_type=FileType.IMAGE,
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related3",
             storage_key="",
         ),
         File(
             filename="audio1.mp3",
-            type=FileType.AUDIO,
+            file_type=FileType.AUDIO,
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related4",
             storage_key="",
@@ -136,7 +136,7 @@ def test_filter_files_by_type(list_operator_node):
 def test_get_file_extract_string_func():
     # Create a File object
     file = File(
-        type=FileType.DOCUMENT,
+        file_type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.LOCAL_FILE,
         filename="test_file.txt",
         extension=".txt",
@@ -156,7 +156,7 @@ def test_get_file_extract_string_func():

     # Test with empty values
     empty_file = File(
-        type=FileType.DOCUMENT,
+        file_type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.LOCAL_FILE,
         filename=None,
         extension=None,
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
index 833c303052..5655f80737 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
@@ -2,16 +2,16 @@ import json
 import time

 import pytest
+from pydantic import ValidationError as PydanticValidationError
+
+from core.workflow.system_variables import build_system_variables
+from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
 from graphon.nodes.start.entities import StartNodeData
 from graphon.nodes.start.start_node import StartNode
 from graphon.runtime import GraphRuntimeState
 from graphon.variables import build_segment, segment_to_variable
 from graphon.variables.input_entities import VariableEntity, VariableEntityType
 from graphon.variables.variables import Variable
-from pydantic import ValidationError as PydanticValidationError
-
-from core.workflow.system_variables import build_system_variables
-from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
 from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -22,10 +22,7 @@ def make_start_node(user_inputs, variables):
         inputs=user_inputs,
     )

-    config = {
-        "id": "start",
-        "data": StartNodeData(title="Start", variables=variables).model_dump(),
-    }
+    node_data = StartNodeData(title="Start", variables=variables)

     graph_runtime_state = GraphRuntimeState(
         variable_pool=variable_pool,
@@ -33,8 +30,8 @@ def make_start_node(user_inputs, variables):
     )

     return StartNode(
-        id="start",
-        config=config,
+        node_id="start",
+        config=node_data,
         graph_init_params=build_test_graph_init_params(
             workflow_id="wf",
             graph_config={},
@@ -109,7 +106,7 @@ def test_json_object_invalid_json_string():

     node = make_start_node(user_inputs, variables)

-    with pytest.raises(ValueError, match="JSON object for 'profile' must be an object"):
+    with pytest.raises(TypeError, match="JSON object for 'profile' must be an object"):
         node._run()
@@ -248,25 +245,22 @@
         inputs={"profile": {"age": 20, "name": "Tom"}},
     )

-    config = {
-        "id": "start",
-        "data": StartNodeData(
-            title="Start",
-            variables=[
-                VariableEntity(
-                    variable="profile",
-                    label="profile",
-                    type=VariableEntityType.JSON_OBJECT,
-                    required=True,
-                )
-            ],
-        ).model_dump(),
-    }
+    node_data = StartNodeData(
+        title="Start",
+        variables=[
+            VariableEntity(
+                variable="profile",
+                label="profile",
+                type=VariableEntityType.JSON_OBJECT,
+                required=True,
+            )
+        ],
+    )

     graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

     node = StartNode(
-        id="start",
-        config=config,
+        node_id="start",
+        config=node_data,
         graph_init_params=build_test_graph_init_params(
             workflow_id="wf",
             graph_config={},
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index 1587014802..284af68319 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -8,14 +8,15 @@ from typing import TYPE_CHECKING, Any
 from unittest.mock import MagicMock

 import pytest
+
+from core.workflow.system_variables import build_system_variables
 from graphon.file import File, FileTransferMethod, FileType
 from graphon.model_runtime.entities.llm_entities import LLMUsage
 from graphon.node_events import StreamChunkEvent, StreamCompletedEvent
+from graphon.nodes.tool.entities import ToolNodeData
 from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
 from graphon.runtime import GraphRuntimeState, VariablePool
 from graphon.variables.segments import ArrayFileSegment
-
-from core.workflow.system_variables import build_system_variables
 from tests.workflow_test_utils import build_test_graph_init_params

 if TYPE_CHECKING:  # pragma: no cover - imported for type checking only
@@ -108,8 +109,8 @@ def tool_node(monkeypatch) -> ToolNode:
     runtime = _StubToolRuntime()

     node = ToolNode(
-        id="node-instance",
-        config=config,
+        node_id="node-instance",
+        config=ToolNodeData.model_validate(config["data"]),
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         tool_file_manager_factory=tool_file_manager_factory,
@@ -118,13 +119,13 @@ def tool_node(monkeypatch) -> ToolNode:
     return node


-def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
+def _collect_events(generator: Generator) -> list[Any]:
     events: list[Any] = []
     try:
         while True:
             events.append(next(generator))
-    except StopIteration as stop:
-        return events, stop.value
+    except StopIteration:
+        return events


 def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]:
@@ -135,12 +136,15 @@ def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[li
         node_id=tool_node._node_id,
         tool_runtime=ToolRuntimeHandle(raw=object()),
     )
-    return _collect_events(generator)
+    events = _collect_events(generator)
+    completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
+    assert completed_events
+    return events, completed_events[-1].node_run_result.llm_usage


 def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
     file_obj = File(
-        type=FileType.DOCUMENT,
+        file_type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.TOOL_FILE,
         related_id="file-id",
         filename="demo.pdf",
@@ -195,7 +199,7 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode):

 def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
     file_obj = File(
-        type=FileType.DOCUMENT,
+        file_type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.TOOL_FILE,
         related_id="file-id",
         filename="demo.pdf",
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py
index c4dfc5a179..438af211f3 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py
@@ -6,11 +6,6 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock, patch

 import pytest
-from graphon.model_runtime.entities.llm_entities import LLMUsage
-from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType
-from graphon.nodes.tool.exc import ToolRuntimeInvocationError
-from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
-from graphon.runtime import VariablePool

 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
 from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
@@ -22,6 +17,11 @@ from core.tools.tool_manager import ToolManager
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from core.workflow.node_runtime import DifyToolNodeRuntime
 from core.workflow.system_variables import build_system_variables
+from graphon.model_runtime.entities.llm_entities import LLMUsage
+from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType
+from graphon.nodes.tool.exc import ToolRuntimeInvocationError
+from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
+from graphon.runtime import VariablePool
 from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
index 952e798430..e3b5e3b591 100644
--- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py
@@ -1,13 +1,12 @@
 from collections.abc import Mapping

-from graphon.entities import GraphInitParams
-from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
-from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from graphon.runtime import GraphRuntimeState
-
 from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
+from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
 from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
 from core.workflow.system_variables import build_system_variables
+from graphon.entities import GraphInitParams
+from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from graphon.runtime import GraphRuntimeState
 from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
@@ -28,29 +27,24 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
     return init_params, runtime_state


-def _build_node_config() -> NodeConfigDict:
-    return NodeConfigDictAdapter.validate_python(
-        {
-            "id": "node-1",
-            "data": {
-                "type": TRIGGER_PLUGIN_NODE_TYPE,
-                "title": "Trigger Event",
-                "plugin_id": "plugin-id",
-                "provider_id": "provider-id",
-                "event_name": "event-name",
-                "subscription_id": "subscription-id",
-                "plugin_unique_identifier": "plugin-unique-identifier",
-                "event_parameters": {},
-            },
-        }
+def _build_node_data() -> TriggerEventNodeData:
+    return TriggerEventNodeData(
+        type=TRIGGER_PLUGIN_NODE_TYPE,
+        title="Trigger Event",
+        plugin_id="plugin-id",
+        provider_id="provider-id",
+        event_name="event-name",
+        subscription_id="subscription-id",
+        plugin_unique_identifier="plugin-unique-identifier",
+        event_parameters={},
     )


 def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
     init_params, runtime_state = _build_context(graph_config={})
     node = TriggerEventNode(
-        id="node-1",
-        config=_build_node_config(),
+        node_id="node-1",
+        config=_build_node_data(),
         graph_init_params=init_params,
         graph_runtime_state=runtime_state,
     )
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py
index f1132af02b..617554ee17 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py
@@ -1,5 +1,4 @@
 import pytest
-from graphon.entities.exc import BaseNodeError

 from core.workflow.nodes.trigger_webhook.exc import (
     WebhookConfigError,
@@ -7,6 +6,7 @@ from core.workflow.nodes.trigger_webhook.exc import (
     WebhookNotFoundError,
     WebhookTimeoutError,
 )
+from graphon.entities.exc import BaseNodeError


 def test_webhook_node_error_inheritance():
diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
index cccd3fb676..07d03bec05 100644
--- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
+++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py
@@ -6,12 +6,9 @@ to FileVariable objects, fixing the "Invalid variable type: ObjectVariable" erro
 when passing files to downstream LLM nodes.
""" +from typing import Any from unittest.mock import Mock, patch -from graphon.entities import GraphInitParams -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.runtime import GraphRuntimeState, VariablePool - from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -21,6 +18,9 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_variable_pool @@ -30,11 +30,6 @@ def create_webhook_node( tenant_id: str = "test-tenant", ) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "webhook-node-1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="test-workflow", graph_config={}, @@ -56,8 +51,8 @@ def create_webhook_node( ) node = TriggerWebhookNode( - id="webhook-node-1", - config=node_config, + node_id="webhook-node-1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -66,10 +61,6 @@ def create_webhook_node( runtime_state.app_config = Mock() runtime_state.app_config.tenant_id = tenant_id - # Provide compatibility alias expected by node implementation - # Some nodes reference `self.node_id`; expose it as an alias to `self.id` for tests - node.node_id = node.id - return node @@ -97,7 +88,7 @@ def create_test_file_dict( } -def build_webhook_variable_pool(inputs: dict) -> VariablePool: +def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool: return build_test_variable_pool( variables=default_system_variables(), node_id="webhook-node-1", @@ -105,7 +96,7 @@ def build_webhook_variable_pool(inputs: dict) -> VariablePool: ) -def expected_factory_mapping(file_dict: dict) -> dict: +def expected_factory_mapping(file_dict: dict[str, Any]) -> dict[str, Any]: return {**file_dict, "upload_file_id": file_dict["related_id"]} diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 34c66a4f9f..b839490d3c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,11 +1,7 @@ +from typing import Any from unittest.mock import patch import pytest -from graphon.entities import GraphInitParams -from graphon.enums import WorkflowNodeExecutionStatus -from graphon.file import File, FileTransferMethod, FileType -from graphon.runtime import GraphRuntimeState, VariablePool -from graphon.variables import FileVariable, StringVariable from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE @@ -18,16 +14,16 @@ from core.workflow.nodes.trigger_webhook.entities import ( ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode from core.workflow.system_variables import default_system_variables +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType 
+from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import FileVariable, StringVariable from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: """Helper function to create a webhook node with proper initialization.""" - node_config = { - "id": "1", - "data": webhook_data.model_dump(), - } - graph_init_params = GraphInitParams( workflow_id="1", graph_config={}, @@ -47,8 +43,8 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) start_at=0, ) node = TriggerWebhookNode( - id="1", - config=node_config, + node_id="1", + config=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -56,13 +52,10 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) # Provide tenant_id for conversion path runtime_state.app_config = type("_AppCfg", (), {"tenant_id": "1"})() - # Compatibility alias for some nodes referencing `self.node_id` - node.node_id = node.id - return node -def build_webhook_variable_pool(inputs: dict) -> VariablePool: +def build_webhook_variable_pool(inputs: dict[str, Any]) -> VariablePool: return build_test_variable_pool( variables=default_system_variables(), node_id="1", @@ -224,7 +217,7 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="image.jpg", @@ -233,7 +226,7 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - type=FileType.DOCUMENT, + file_type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", filename="document.pdf", @@ -268,8 +261,19 @@ def test_webhook_node_run_with_file_params(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + storage_key=mapping.get("storage_key", ""), + remote_url=mapping.get("url"), + ) mock_file_factory.side_effect = _to_file result = node._run() @@ -283,7 +287,7 @@ def test_webhook_node_run_with_file_params(): def test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - type=FileType.IMAGE, + file_type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", filename="test.jpg", @@ -316,8 +320,19 @@ def test_webhook_node_run_mixed_parameters(): # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(*, mapping): - return File.model_validate(mapping) + def _to_file(*, mapping: dict[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + 
+                transfer_method=FileTransferMethod(mapping["transfer_method"]),
+                related_id=mapping.get("related_id"),
+                filename=mapping.get("filename"),
+                extension=mapping.get("extension"),
+                mime_type=mapping.get("mime_type"),
+                size=mapping.get("size", -1),
+                storage_key=mapping.get("storage_key", ""),
+                remote_url=mapping.get("url"),
+            )

         mock_file_factory.side_effect = _to_file
         result = node._run()
diff --git a/api/tests/unit_tests/core/workflow/test_human_input_adapter.py b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
new file mode 100644
index 0000000000..51049f8792
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_human_input_adapter.py
@@ -0,0 +1,415 @@
+from types import SimpleNamespace
+
+import pytest
+from pydantic import BaseModel
+
+from core.workflow.human_input_adapter import (
+    DeliveryMethodType,
+    EmailDeliveryConfig,
+    EmailDeliveryMethod,
+    EmailRecipients,
+    WebAppDeliveryMethod,
+    _WebAppDeliveryConfig,
+    adapt_human_input_node_data_for_graph,
+    adapt_node_config_for_graph,
+    adapt_node_data_for_graph,
+    is_human_input_webapp_enabled,
+    parse_human_input_delivery_methods,
+)
+from graphon.enums import BuiltinNodeTypes
+from graphon.nodes.base.variable_template_parser import VariableTemplateParser
+
+
+def test_email_delivery_config_helpers_render_and_sanitize_text() -> None:
+    variable_pool = SimpleNamespace(
+        convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42"))
+    )
+
+    rendered = EmailDeliveryConfig.render_body_template(
+        body="Open {{#url#}} and use {{#node.value#}}",
+        url="https://example.com",
+        variable_pool=variable_pool,
+    )
+    sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n<script>alert(1)</script> Team")
+    html = EmailDeliveryConfig.render_markdown_body(
+        "**Hello** [mail](mailto:test@example.com)"
+    )
+
+    assert rendered == "Open https://example.com and use 42"
+    assert sanitized == "Hello alert(1) Team"
+    assert "<strong>Hello</strong>" in html
+    assert '<a href="mailto:test@example.com"' in html
+  )
+}
+
+export default memo(CreateAppAttributionBootstrap)
diff --git a/web/app/components/custom/custom-page/__tests__/index.spec.tsx b/web/app/components/custom/custom-page/__tests__/index.spec.tsx
index 5514d21b50..1f3655a9f8 100644
--- a/web/app/components/custom/custom-page/__tests__/index.spec.tsx
+++ b/web/app/components/custom/custom-page/__tests__/index.spec.tsx
@@ -1,9 +1,10 @@
+import type { ReactElement } from 'react'
 import type { AppContextValue } from '@/context/app-context'
-import type { SystemFeatures } from '@/types/feature'
-import { render, screen } from '@testing-library/react'
+import { screen } from '@testing-library/react'
 import userEvent from '@testing-library/user-event'
 import { beforeEach, describe, expect, it, vi } from 'vitest'
 import { createMockProviderContextValue } from '@/__mocks__/provider-context'
+import { renderWithSystemFeatures } from '@/__tests__/utils/mock-system-features'
 import { contactSalesUrl, defaultPlan } from '@/app/components/billing/config'
 import { Plan } from '@/app/components/billing/type'
 import {
@@ -12,12 +13,19 @@ import {
   useAppContext,
   userProfilePlaceholder,
 } from '@/context/app-context'
-import { useGlobalPublicStore } from '@/context/global-public-context'
 import { useModalContext } from '@/context/modal-context'
 import { useProviderContext } from
'@/context/provider-context'
-import { defaultSystemFeatures } from '@/types/feature'
 import CustomPage from '../index'

+const render = (ui: ReactElement) => renderWithSystemFeatures(ui, {
+  systemFeatures: {
+    branding: {
+      enabled: true,
+      workspace_logo: 'https://example.com/workspace-logo.png',
+    },
+  },
+})
+
 const { mockToast } = vi.hoisted(() => {
   const mockToast = Object.assign(vi.fn(), {
     success: vi.fn(),
@@ -44,17 +52,13 @@ vi.mock('@/context/app-context', async (importOriginal) => {
     useAppContext: vi.fn(),
   }
 })
-vi.mock('@/context/global-public-context', () => ({
-  useGlobalPublicStore: vi.fn(),
-}))
-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
   toast: mockToast,
 }))

 const mockUseProviderContext = vi.mocked(useProviderContext)
 const mockUseModalContext = vi.mocked(useModalContext)
 const mockUseAppContext = vi.mocked(useAppContext)
-const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore)

 const createProviderContext = ({
   enableBilling = false,
@@ -93,15 +97,6 @@ const createAppContextValue = (): AppContextValue => ({
   isValidatingCurrentWorkspace: false,
 })

-const createSystemFeatures = (): SystemFeatures => ({
-  ...defaultSystemFeatures,
-  branding: {
-    ...defaultSystemFeatures.branding,
-    enabled: true,
-    workspace_logo: 'https://example.com/workspace-logo.png',
-  },
-})
-
 describe('CustomPage', () => {
   const setShowPricingModal = vi.fn()

@@ -113,10 +108,6 @@ {
       setShowPricingModal,
     } as unknown as ReturnType<typeof useModalContext>)
     mockUseAppContext.mockReturnValue(createAppContextValue())
-    mockUseGlobalPublicStore.mockImplementation(selector => selector({
-      systemFeatures: createSystemFeatures(),
-      setSystemFeatures: vi.fn(),
-    }))
   })

   // Integration coverage for the page and its child custom brand section.
diff --git a/web/app/components/custom/custom-page/index.tsx b/web/app/components/custom/custom-page/index.tsx
index 38f651882d..cd85a91230 100644
--- a/web/app/components/custom/custom-page/index.tsx
+++ b/web/app/components/custom/custom-page/index.tsx
@@ -20,7 +20,7 @@ const CustomPage = () => {
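
The spec above now renders through `renderWithSystemFeatures` instead of mocking `useGlobalPublicStore`. The helper's implementation is not part of this diff; the sketch below is a hypothetical shape for it, assuming it pre-seeds the TanStack Query cache that `systemFeaturesQueryOptions()` reads from (the hook diff later in this section shows that consumer side):

```tsx
// Hypothetical sketch only: the real helper lives in
// `@/__tests__/utils/mock-system-features` and is not shown in this diff.
import type { ReactElement } from 'react'
import type { SystemFeatures } from '@/types/feature'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { render } from '@testing-library/react'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import { defaultSystemFeatures } from '@/types/feature'

type Options = {
  // The specs pass deep partials (e.g. only `branding`), so the real helper
  // presumably deep-merges over the defaults; this sketch merges shallowly.
  systemFeatures?: Partial<SystemFeatures>
}

export function renderWithSystemFeatures(ui: ReactElement, options: Options = {}) {
  const client = new QueryClient()
  // Pre-seed the query so useSuspenseQuery(systemFeaturesQueryOptions()) resolves synchronously
  client.setQueryData(systemFeaturesQueryOptions().queryKey, {
    ...defaultSystemFeatures,
    ...options.systemFeatures,
  } as SystemFeatures)
  return render(<QueryClientProvider client={client}>{ui}</QueryClientProvider>)
}
```
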
{t('upgradeTip.title', { ns: 'custom' })}
{t('upgradeTip.des', { ns: 'custom' })}
-
setShowPricingModal()}>{t('upgradeBtn.encourageShort', { ns: 'billing' })}
+
setShowPricingModal()}>{t('upgradeBtn.encourageShort', { ns: 'billing' })}
)} diff --git a/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx b/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx index b54b800049..1ea03c7bc6 100644 --- a/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/chat-preview-card.tsx @@ -1,5 +1,5 @@ -import Button from '@/app/components/base/button' -import { cn } from '@/utils/classnames' +import { Button } from '@langgenius/dify-ui/button' +import { cn } from '@langgenius/dify-ui/cn' import PoweredByBrand from './powered-by-brand' type ChatPreviewCardProps = { @@ -22,14 +22,14 @@ const ChatPreviewCard = ({
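
These web/ hunks repeat one import migration: base components and utilities move from repo-local paths to the shared `@langgenius/dify-ui` package, with `Button` and `Switch` switching from default to named exports. The pattern in isolation:

```ts
// Before (repo-local modules, default exports for components):
// import Button from '@/app/components/base/button'
// import Switch from '@/app/components/base/switch'
// import { toast } from '@/app/components/base/ui/toast'
// import { cn } from '@/utils/classnames'

// After (shared package, named exports, one subpath per module):
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Switch } from '@langgenius/dify-ui/switch'
import { toast } from '@langgenius/dify-ui/toast'
```
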
-
Chatflow App
+
Chatflow App
-
+
-
Hello! How can I assist you today?
+
Hello! How can I assist you today?
-
Talk to Dify
+
Talk to Dify
diff --git a/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx b/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx index 8a0feffbc4..51c0db53d7 100644 --- a/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/powered-by-brand.tsx @@ -20,7 +20,7 @@ const PoweredByBrand = ({ return ( <> -
POWERED BY
+
POWERED BY
{previewLogo ? logo : } diff --git a/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx b/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx index 573c2e9c1c..b4f30827bc 100644 --- a/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx +++ b/web/app/components/custom/custom-web-app-brand/components/workflow-preview-card.tsx @@ -1,5 +1,5 @@ -import Button from '@/app/components/base/button' -import { cn } from '@/utils/classnames' +import { Button } from '@langgenius/dify-ui/button' +import { cn } from '@langgenius/dify-ui/cn' import PoweredByBrand from './powered-by-brand' type WorkflowPreviewCardProps = { @@ -22,14 +22,14 @@ const WorkflowPreviewCard = ({
-
Workflow App
+
Workflow App
-
RUN ONCE
-
RUN BATCH
+
RUN ONCE
+
RUN BATCH
@@ -44,7 +44,7 @@ const WorkflowPreviewCard = ({
diff --git a/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx b/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx index a2730ffd27..99cbc03b32 100644 --- a/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx +++ b/web/app/components/custom/custom-web-app-brand/hooks/__tests__/use-web-app-brand.spec.tsx @@ -1,9 +1,10 @@ import type { ChangeEvent } from 'react' import type { AppContextValue } from '@/context/app-context' import type { SystemFeatures } from '@/types/feature' -import { act, renderHook } from '@testing-library/react' +import { act } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { createMockProviderContextValue } from '@/__mocks__/provider-context' +import { renderHookWithSystemFeatures } from '@/__tests__/utils/mock-system-features' import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils' import { defaultPlan } from '@/app/components/billing/config' import { Plan } from '@/app/components/billing/type' @@ -13,12 +14,22 @@ import { useAppContext, userProfilePlaceholder, } from '@/context/app-context' -import { useGlobalPublicStore } from '@/context/global-public-context' import { useProviderContext } from '@/context/provider-context' import { updateCurrentWorkspace } from '@/service/common' -import { defaultSystemFeatures } from '@/types/feature' import useWebAppBrand from '../use-web-app-brand' +let currentBrandingOverrides: Partial = {} +const renderHook = (callback: (props: Props) => Result) => + renderHookWithSystemFeatures(callback, { + systemFeatures: { + branding: { + enabled: true, + workspace_logo: 'https://example.com/workspace-logo.png', + ...currentBrandingOverrides, + }, + }, + }) + const { mockNotify, mockToast } = vi.hoisted(() => { const mockNotify = vi.fn() const mockToast = Object.assign(mockNotify, { @@ -33,7 +44,7 @@ const { mockNotify, mockToast } = vi.hoisted(() => { return { mockNotify, mockToast } }) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: mockToast, })) vi.mock('@/service/common', () => ({ @@ -49,9 +60,6 @@ vi.mock('@/context/app-context', async (importOriginal) => { vi.mock('@/context/provider-context', () => ({ useProviderContext: vi.fn(), })) -vi.mock('@/context/global-public-context', () => ({ - useGlobalPublicStore: vi.fn(), -})) vi.mock('@/app/components/base/image-uploader/utils', () => ({ imageUpload: vi.fn(), getImageUploadErrorMessage: vi.fn(), @@ -60,7 +68,6 @@ vi.mock('@/app/components/base/image-uploader/utils', () => ({ const mockUpdateCurrentWorkspace = vi.mocked(updateCurrentWorkspace) const mockUseAppContext = vi.mocked(useAppContext) const mockUseProviderContext = vi.mocked(useProviderContext) -const mockUseGlobalPublicStore = vi.mocked(useGlobalPublicStore) const mockImageUpload = vi.mocked(imageUpload) const mockGetImageUploadErrorMessage = vi.mocked(getImageUploadErrorMessage) @@ -80,16 +87,6 @@ const createProviderContext = ({ }) } -const createSystemFeatures = (brandingOverrides: Partial = {}): SystemFeatures => ({ - ...defaultSystemFeatures, - branding: { - ...defaultSystemFeatures.branding, - enabled: true, - workspace_logo: 'https://example.com/workspace-logo.png', - ...brandingOverrides, - }, -}) - const createAppContextValue = (overrides: Partial = {}): AppContextValue => { const { currentWorkspace: currentWorkspaceOverride, ...restOverrides } = 
overrides const workspaceOverrides: Partial = currentWorkspaceOverride ?? {} @@ -122,21 +119,16 @@ const createAppContextValue = (overrides: Partial = {}): AppCon describe('useWebAppBrand', () => { let appContextValue: AppContextValue - let systemFeatures: SystemFeatures beforeEach(() => { vi.clearAllMocks() appContextValue = createAppContextValue() - systemFeatures = createSystemFeatures() + currentBrandingOverrides = {} mockUpdateCurrentWorkspace.mockResolvedValue(appContextValue.currentWorkspace) mockUseAppContext.mockImplementation(() => appContextValue) mockUseProviderContext.mockReturnValue(createProviderContext()) - mockUseGlobalPublicStore.mockImplementation(selector => selector({ - systemFeatures, - setSystemFeatures: vi.fn(), - })) mockGetImageUploadErrorMessage.mockReturnValue('upload error') }) @@ -174,10 +166,7 @@ describe('useWebAppBrand', () => { }) it('should fall back to an empty workspace logo when branding is disabled', () => { - systemFeatures = createSystemFeatures({ - enabled: false, - workspace_logo: '', - }) + currentBrandingOverrides = { enabled: false, workspace_logo: '' } const { result } = renderHook(() => useWebAppBrand()) diff --git a/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts b/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts index 44810de1fd..e24edab421 100644 --- a/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts +++ b/web/app/components/custom/custom-web-app-brand/hooks/use-web-app-brand.ts @@ -1,13 +1,14 @@ import type { ChangeEvent } from 'react' +import { toast } from '@langgenius/dify-ui/toast' +import { useSuspenseQuery } from '@tanstack/react-query' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { getImageUploadErrorMessage, imageUpload } from '@/app/components/base/image-uploader/utils' -import { toast } from '@/app/components/base/ui/toast' import { Plan } from '@/app/components/billing/type' import { useAppContext } from '@/context/app-context' -import { useGlobalPublicStore } from '@/context/global-public-context' import { useProviderContext } from '@/context/provider-context' import { updateCurrentWorkspace } from '@/service/common' +import { systemFeaturesQueryOptions } from '@/service/system-features' const MAX_LOGO_FILE_SIZE = 5 * 1024 * 1024 const CUSTOM_CONFIG_URL = '/workspaces/custom-config' @@ -19,7 +20,7 @@ const useWebAppBrand = () => { const [fileId, setFileId] = useState('') const [imgKey, setImgKey] = useState(() => Date.now()) const [uploadProgress, setUploadProgress] = useState(0) - const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) + const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) const isSandbox = enableBilling && plan.type === Plan.sandbox const uploading = uploadProgress > 0 && uploadProgress < 100 const webappLogo = currentWorkspace.custom_config?.replace_webapp_logo || '' diff --git a/web/app/components/custom/custom-web-app-brand/index.tsx b/web/app/components/custom/custom-web-app-brand/index.tsx index 0df5e043e9..53e6ff36ec 100644 --- a/web/app/components/custom/custom-web-app-brand/index.tsx +++ b/web/app/components/custom/custom-web-app-brand/index.tsx @@ -1,8 +1,8 @@ +import { Button } from '@langgenius/dify-ui/button' +import { cn } from '@langgenius/dify-ui/cn' +import { Switch } from '@langgenius/dify-ui/switch' import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' import Divider from 
'@/app/components/base/divider' -import Switch from '@/app/components/base/switch' -import { cn } from '@/utils/classnames' import ChatPreviewCard from './components/chat-preview-card' import WorkflowPreviewCard from './components/workflow-preview-card' import useWebAppBrand from './hooks/use-web-app-brand' @@ -31,19 +31,19 @@ const CustomWebAppBrand = () => { return (
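
Inside `use-web-app-brand.ts`, the same system-features migration appears at the hook level: the zustand selector is replaced by a suspense query. Isolated below as a sketch; `useSystemBranding` is a hypothetical name for illustration only:

```ts
import { useSuspenseQuery } from '@tanstack/react-query'
import { systemFeaturesQueryOptions } from '@/service/system-features'

// Before: synchronous read from the global public store
// const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)

// After: read through the query cache; the component suspends until loaded
function useSystemBranding() {
  const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
  return systemFeatures.branding
}
```
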
-
+
{t('webapp.removeBrand', { ns: 'custom' })}
-
{t('webapp.changeLogo', { ns: 'custom' })}
-
{t('webapp.changeLogoTip', { ns: 'custom' })}
+
{t('webapp.changeLogo', { ns: 'custom' })}
+
{t('webapp.changeLogoTip', { ns: 'custom' })}
{(!uploadDisabled && webappLogo && !webappBrandRemoved) && ( @@ -64,7 +64,7 @@ const CustomWebAppBrand = () => { className="relative mr-2" disabled={uploadDisabled} > - + { (webappLogo || fileId) ? t('change', { ns: 'custom' }) @@ -87,7 +87,7 @@ const CustomWebAppBrand = () => { className="relative mr-2" disabled={true} > - + {t('uploading', { ns: 'custom' })} ) @@ -118,8 +118,8 @@ const CustomWebAppBrand = () => { {uploadProgress === -1 && (
{t('uploadedFail', { ns: 'custom' })}
)} -
-
{t('overview.appInfo.preview', { ns: 'appOverview' })}
+
+
{t('overview.appInfo.preview', { ns: 'appOverview' })}
diff --git a/web/app/components/datasets/chunk.tsx b/web/app/components/datasets/chunk.tsx index e0d820a4f3..947f2859c4 100644 --- a/web/app/components/datasets/chunk.tsx +++ b/web/app/components/datasets/chunk.tsx @@ -50,11 +50,11 @@ export const QAPreview: FC = (props) => { return (
- +

{qa.question}

- +

{qa.answer}

diff --git a/web/app/components/datasets/common/credential-icon.tsx b/web/app/components/datasets/common/credential-icon.tsx index a5eb6bfdb6..f993eca585 100644 --- a/web/app/components/datasets/common/credential-icon.tsx +++ b/web/app/components/datasets/common/credential-icon.tsx @@ -1,6 +1,6 @@ +import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' import { useCallback, useMemo, useState } from 'react' -import { cn } from '@/utils/classnames' type CredentialIconProps = { avatarUrl?: string @@ -59,7 +59,7 @@ export const CredentialIcon: React.FC = ({ )} style={{ width: `${size}px`, height: `${size}px` }} > - + {firstLetter}
diff --git a/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx b/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx index f8f0ce6e12..1251eab9fb 100644 --- a/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx +++ b/web/app/components/datasets/common/document-picker/__tests__/index.spec.tsx @@ -5,34 +5,7 @@ import * as React from 'react' import { ChunkingMode, DataSourceType } from '@/models/datasets' import DocumentPicker from '../index' -// Mock portal-to-follow-elem - always render content for testing -vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ - PortalToFollowElem: ({ children, open }: { - children: React.ReactNode - open?: boolean - }) => ( -
- {children} -
- ), - PortalToFollowElemTrigger: ({ children, onClick }: { - children: React.ReactNode - onClick?: () => void - }) => ( -
- {children} -
- ), - // Always render content to allow testing document selection - PortalToFollowElemContent: ({ children, className }: { - children: React.ReactNode - className?: string - }) => ( -
- {children} -
- ), -})) +vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover')) // Mock useDocumentList hook with controllable return value let mockDocumentListData: { data: SimpleDocumentDetail[] } | undefined @@ -152,6 +125,10 @@ const renderComponent = (props: Partial { + fireEvent.click(screen.getByTestId('popover-trigger')) +} + describe('DocumentPicker', () => { beforeEach(() => { vi.clearAllMocks() @@ -165,7 +142,7 @@ describe('DocumentPicker', () => { it('should render without crashing', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render document name when provided', () => { @@ -273,7 +250,7 @@ describe('DocumentPicker', () => { onChange, }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle value with all fields', () => { @@ -318,13 +295,13 @@ describe('DocumentPicker', () => { it('should initialize with popup closed', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + expect(screen.getByTestId('popover')).toHaveAttribute('data-open', 'false') }) it('should open popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) // Verify click handler is called @@ -430,7 +407,7 @@ describe('DocumentPicker', () => { ) // The component should use the new callback - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should memoize handleChange callback with useCallback', () => { @@ -440,7 +417,7 @@ describe('DocumentPicker', () => { renderComponent({ onChange }) // Verify component renders correctly, callback memoization is internal - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -518,7 +495,7 @@ describe('DocumentPicker', () => { it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) // Trigger click should be handled @@ -591,7 +568,7 @@ describe('DocumentPicker', () => { renderComponent() // When loading, component should still render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should fetch documents on mount', () => { @@ -611,7 +588,7 @@ describe('DocumentPicker', () => { renderComponent() // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle undefined data response', () => { @@ -620,7 +597,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -732,13 +709,13 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle rapid toggle clicks', () => { renderComponent() - const trigger = 
screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') // Rapid clicks fireEvent.click(trigger) @@ -795,7 +772,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle document list mapping with various data_source_detail_dict states', () => { @@ -819,7 +796,7 @@ describe('DocumentPicker', () => { renderComponent() // Should not crash during mapping - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -829,13 +806,13 @@ describe('DocumentPicker', () => { it('should handle empty datasetId', () => { renderComponent({ datasetId: '' }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle UUID format datasetId', () => { renderComponent({ datasetId: '123e4567-e89b-12d3-a456-426614174000' }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -926,6 +903,7 @@ describe('DocumentPicker', () => { const onChange = vi.fn() renderComponent({ onChange }) + openPopover() fireEvent.click(screen.getByText('Document 2')) @@ -939,6 +917,7 @@ describe('DocumentPicker', () => { mockDocumentListData = { data: docs } renderComponent() + openPopover() // Documents should be rendered in the list expect(screen.getByText('Document 1')).toBeInTheDocument() @@ -978,14 +957,14 @@ describe('DocumentPicker', () => { // The mapping: d.data_source_detail_dict?.upload_file?.extension || '' // Should extract 'pdf' from the document - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render trigger with SearchInput integration', () => { renderComponent() // The trigger is always rendered - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + expect(screen.getByTestId('popover-trigger')).toBeInTheDocument() }) it('should integrate FileIcon component', () => { @@ -1001,7 +980,7 @@ describe('DocumentPicker', () => { }) // FileIcon should render an SVG icon for the file extension - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -1010,9 +989,10 @@ describe('DocumentPicker', () => { describe('Visual States', () => { it('should render portal content for document selection', () => { renderComponent() + openPopover() - // Portal content is rendered in our mock for testing - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Popover content is rendered after opening the trigger in our mock + expect(screen.getByTestId('popover-content')).toBeInTheDocument() }) }) }) diff --git a/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx b/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx index 7178e9f60c..c7eb2c740c 100644 --- a/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx +++ b/web/app/components/datasets/common/document-picker/__tests__/preview-document-picker.spec.tsx @@ -3,34 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' import PreviewDocumentPicker from 
'../preview-document-picker' -// Mock portal-to-follow-elem - always render content for testing -vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ - PortalToFollowElem: ({ children, open }: { - children: React.ReactNode - open?: boolean - }) => ( -
- {children} -
- ), - PortalToFollowElemTrigger: ({ children, onClick }: { - children: React.ReactNode - onClick?: () => void - }) => ( -
- {children} -
- ), - // Always render content to allow testing document selection - PortalToFollowElemContent: ({ children, className }: { - children: React.ReactNode - className?: string - }) => ( -
- {children} -
- ), -})) +vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover')) // Factory function to create mock DocumentItem const createMockDocumentItem = (overrides: Partial = {}): DocumentItem => ({ @@ -67,6 +40,10 @@ const renderComponent = (props: Partial { + fireEvent.click(screen.getByTestId('popover-trigger')) +} + describe('PreviewDocumentPicker', () => { beforeEach(() => { vi.clearAllMocks() @@ -77,7 +54,7 @@ describe('PreviewDocumentPicker', () => { it('should render without crashing', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should render document name from value prop', () => { @@ -110,7 +87,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) @@ -120,7 +97,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -131,22 +108,21 @@ describe('PreviewDocumentPicker', () => { const props = createDefaultProps() render() - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should apply className to trigger element', () => { renderComponent({ className: 'custom-class' }) - const trigger = screen.getByTestId('portal-trigger') - const innerDiv = trigger.querySelector('.custom-class') - expect(innerDiv).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('custom-class') }) it('should handle empty files array', () => { // Component should render without crashing with empty files renderComponent({ files: [] }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle single file', () => { @@ -155,7 +131,7 @@ describe('PreviewDocumentPicker', () => { files: [createMockDocumentItem({ id: 'single-doc', name: 'Single File' })], }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle multiple files', () => { @@ -164,7 +140,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(5), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should use value.extension for file icon', () => { @@ -172,7 +148,7 @@ describe('PreviewDocumentPicker', () => { value: createMockDocumentItem({ name: 'test.docx', extension: 'docx' }), }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -182,13 +158,13 @@ describe('PreviewDocumentPicker', () => { it('should initialize with popup closed', () => { renderComponent() - expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'false') + expect(screen.getByTestId('popover')).toHaveAttribute('data-open', 'false') }) it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const 
trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) expect(trigger).toBeInTheDocument() @@ -196,9 +172,10 @@ describe('PreviewDocumentPicker', () => { it('should render portal content for document selection', () => { renderComponent() + openPopover() - // Portal content is always rendered in our mock for testing - expect(screen.getByTestId('portal-content')).toBeInTheDocument() + // Popover content is rendered after opening the trigger in our mock + expect(screen.getByTestId('popover-content')).toBeInTheDocument() }) }) @@ -242,7 +219,7 @@ describe('PreviewDocumentPicker', () => { , ) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -265,7 +242,7 @@ describe('PreviewDocumentPicker', () => { , ) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -274,7 +251,7 @@ describe('PreviewDocumentPicker', () => { it('should toggle popup when trigger is clicked', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') fireEvent.click(trigger) expect(trigger).toBeInTheDocument() @@ -283,6 +260,7 @@ describe('PreviewDocumentPicker', () => { it('should render document list with files', () => { const files = createMockDocumentList(3) renderComponent({ files }) + openPopover() // Documents should be visible in the list expect(screen.getByText('Document 1')).toBeInTheDocument() @@ -295,6 +273,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 2')) @@ -306,7 +285,7 @@ describe('PreviewDocumentPicker', () => { it('should handle rapid toggle clicks', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') // Rapid clicks fireEvent.click(trigger) @@ -337,14 +316,14 @@ describe('PreviewDocumentPicker', () => { // Renders placeholder for missing name expect(screen.getByText('--')).toBeInTheDocument() // Portal wrapper renders - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle empty files array', () => { renderComponent({ files: [] }) // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle very long document names', () => { @@ -374,7 +353,7 @@ describe('PreviewDocumentPicker', () => { render() // Component should render without crashing - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle large number of files', () => { @@ -382,7 +361,7 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files: manyFiles }) // Component should accept large files array - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle files with same name but different extensions', () => { @@ -393,7 +372,7 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files }) // Component should handle duplicate names - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -427,7 
+406,7 @@ describe('PreviewDocumentPicker', () => { files: [createMockDocumentItem({ name: 'Single' })], }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle two files', () => { @@ -435,7 +414,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(2), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) it('should handle many files', () => { @@ -443,7 +422,7 @@ describe('PreviewDocumentPicker', () => { files: createMockDocumentList(50), }) - expect(screen.getByTestId('portal-elem')).toBeInTheDocument() + expect(screen.getByTestId('popover')).toBeInTheDocument() }) }) @@ -451,23 +430,22 @@ describe('PreviewDocumentPicker', () => { it('should apply custom className', () => { renderComponent({ className: 'my-custom-class' }) - const trigger = screen.getByTestId('portal-trigger') - expect(trigger.querySelector('.my-custom-class')).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('my-custom-class') }) it('should work without className', () => { renderComponent({ className: undefined }) - expect(screen.getByTestId('portal-trigger')).toBeInTheDocument() + expect(screen.getByTestId('popover-trigger')).toBeInTheDocument() }) it('should handle multiple class names', () => { renderComponent({ className: 'class-one class-two' }) - const trigger = screen.getByTestId('portal-trigger') - const element = trigger.querySelector('.class-one') - expect(element).toBeInTheDocument() - expect(element).toHaveClass('class-two') + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('class-one') + expect(trigger).toHaveClass('class-two') }) }) @@ -480,7 +458,7 @@ describe('PreviewDocumentPicker', () => { files: [], // Use empty files to avoid duplicate icons }) - const trigger = screen.getByTestId('portal-trigger') + const trigger = screen.getByTestId('popover-trigger') expect(trigger.querySelector('svg')).toBeInTheDocument() }) }) @@ -491,6 +469,7 @@ describe('PreviewDocumentPicker', () => { it('should render all documents in the list', () => { const files = createMockDocumentList(5) renderComponent({ files }) + openPopover() // All documents should be visible files.forEach((file) => { @@ -503,6 +482,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 1')) @@ -528,6 +508,7 @@ describe('PreviewDocumentPicker', () => { onChange={vi.fn()} />, ) + openPopover() expect(screen.getByText(/dataset\.preprocessDocument/)).toBeInTheDocument() }) }) @@ -537,9 +518,8 @@ describe('PreviewDocumentPicker', () => { it('should apply hover styles on trigger', () => { renderComponent() - const trigger = screen.getByTestId('portal-trigger') - const innerDiv = trigger.querySelector('.hover\\:bg-state-base-hover') - expect(innerDiv).toBeInTheDocument() + const trigger = screen.getByTestId('popover-trigger') + expect(trigger).toHaveClass('hover:bg-state-base-hover') }) it('should have truncate class for long names', () => { @@ -568,6 +548,7 @@ describe('PreviewDocumentPicker', () => { const files = createMockDocumentList(3) renderComponent({ files, onChange }) + openPopover() fireEvent.click(screen.getByText('Document 1')) @@ -582,10 +563,12 @@ describe('PreviewDocumentPicker', () => { ] renderComponent({ files: customFiles, 
onChange }) + openPopover() fireEvent.click(screen.getByText('Custom File 1')) expect(onChange).toHaveBeenCalledWith(customFiles[0]) + openPopover() fireEvent.click(screen.getByText('Custom File 2')) expect(onChange).toHaveBeenCalledWith(customFiles[1]) }) @@ -597,8 +580,11 @@ describe('PreviewDocumentPicker', () => { renderComponent({ files, onChange }) // Select multiple documents sequentially + openPopover() fireEvent.click(screen.getByText('Document 1')) + openPopover() fireEvent.click(screen.getByText('Document 3')) + openPopover() fireEvent.click(screen.getByText('Document 2')) expect(onChange).toHaveBeenCalledTimes(3) diff --git a/web/app/components/datasets/common/document-picker/document-list.tsx b/web/app/components/datasets/common/document-picker/document-list.tsx index 574792ee14..d2d8d1966c 100644 --- a/web/app/components/datasets/common/document-picker/document-list.tsx +++ b/web/app/components/datasets/common/document-picker/document-list.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' import type { DocumentItem } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' import * as React from 'react' import { useCallback } from 'react' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' type Props = { diff --git a/web/app/components/datasets/common/document-picker/index.tsx b/web/app/components/datasets/common/document-picker/index.tsx index b564b4d2b1..0566b590de 100644 --- a/web/app/components/datasets/common/document-picker/index.tsx +++ b/web/app/components/datasets/common/document-picker/index.tsx @@ -1,6 +1,12 @@ 'use client' import type { FC } from 'react' import type { DocumentItem, ParentMode, SimpleDocumentDetail } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@langgenius/dify-ui/popover' import { RiArrowDownSLine } from '@remixicon/react' import { useBoolean } from 'ahooks' import * as React from 'react' @@ -8,15 +14,9 @@ import { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { GeneralChunk, ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge' import Loading from '@/app/components/base/loading' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' import SearchInput from '@/app/components/base/search-input' import { ChunkingMode } from '@/models/datasets' import { useDocumentList } from '@/service/knowledge/use-document' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' import DocumentList from './document-list' @@ -61,7 +61,6 @@ const DocumentPicker: FC = ({ const [open, { set: setOpen, - toggle: togglePopup, }] = useBoolean(false) const ArrowIcon = RiArrowDownSLine @@ -77,34 +76,40 @@ const DocumentPicker: FC = ({ }, [parentMode, t]) return ( - - -
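
Both document-picker specs above replace their inline portal mocks with one shared module via `vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover'))`. That mock's source is not in this diff; the sketch below is a hypothetical reconstruction from what the specs assert — a `popover` wrapper exposing `data-open`, a clickable `popover-trigger` that carries `className` directly, and `popover-content` that renders only after opening:

```tsx
// Hypothetical sketch of `@/__mocks__/base-ui-popover`; the real file is not
// shown in this diff, and the real trigger may accept a render prop instead.
import type { ReactNode } from 'react'
import * as React from 'react'

const PopoverContext = React.createContext({ open: false, toggle: () => {} })

export function Popover({ children, open, onOpenChange }: {
  children?: ReactNode
  open?: boolean
  onOpenChange?: (open: boolean) => void
}) {
  const [inner, setInner] = React.useState(false)
  const isOpen = open ?? inner
  const toggle = () => {
    setInner(!isOpen)
    onOpenChange?.(!isOpen)
  }
  return (
    <PopoverContext.Provider value={{ open: isOpen, toggle }}>
      <div data-testid="popover" data-open={String(isOpen)}>{children}</div>
    </PopoverContext.Provider>
  )
}

export function PopoverTrigger({ children, className }: { children?: ReactNode, className?: string }) {
  const { toggle } = React.useContext(PopoverContext)
  // className lands on the trigger itself, matching expect(trigger).toHaveClass(...)
  return (
    <div data-testid="popover-trigger" className={className} onClick={toggle}>
      {children}
    </div>
  )
}

export function PopoverContent({ children, className }: { children?: ReactNode, className?: string }) {
  const { open } = React.useContext(PopoverContext)
  // Content only renders once the trigger has been clicked (see openPopover() in the specs)
  return open ? <div data-testid="popover-content" className={className}>{children}</div> : null
}
```
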
- -
-
- - {' '} - {name || '--'} - - -
-
- - - {isGeneralMode && t('chunkingMode.general', { ns: 'dataset' })} - {isQAMode && t('chunkingMode.qa', { ns: 'dataset' })} - {isParentChild && `${t('chunkingMode.parentChild', { ns: 'dataset' })} · ${parentModeLabel}`} - + + +
+
+ + {' '} + {name || '--'} + + +
+
+ + + {isGeneralMode && t('chunkingMode.general', { ns: 'dataset' })} + {isQAMode && t('chunkingMode.qa', { ns: 'dataset' })} + {isParentChild && `${t('chunkingMode.parentChild', { ns: 'dataset' })} · ${parentModeLabel}`} + +
-
- - + )} + /> +
{documentsList @@ -125,9 +130,8 @@ const DocumentPicker: FC = ({
)}
- - -
+ + ) } export default React.memo(DocumentPicker) diff --git a/web/app/components/datasets/common/document-picker/preview-document-picker.tsx b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx index dea0693983..597ceda9a5 100644 --- a/web/app/components/datasets/common/document-picker/preview-document-picker.tsx +++ b/web/app/components/datasets/common/document-picker/preview-document-picker.tsx @@ -1,18 +1,18 @@ 'use client' import type { FC } from 'react' import type { DocumentItem } from '@/models/datasets' +import { cn } from '@langgenius/dify-ui/cn' +import { + Popover, + PopoverContent, + PopoverTrigger, +} from '@langgenius/dify-ui/popover' import { RiArrowDownSLine } from '@remixicon/react' import { useBoolean } from 'ahooks' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' -import { cn } from '@/utils/classnames' import FileIcon from '../document-file-icon' import DocumentList from './document-list' @@ -35,7 +35,6 @@ const PreviewDocumentPicker: FC = ({ const [open, { set: setOpen, - toggle: togglePopup, }] = useBoolean(false) const ArrowIcon = RiArrowDownSLine @@ -45,29 +44,34 @@ const PreviewDocumentPicker: FC = ({ }, [onChange, setOpen]) return ( - - -
- -
-
- - {' '} - {name || '--'} - - + + +
+
+ + {' '} + {name || '--'} + + +
-
- - + )} + /> +
- {files?.length > 1 &&
{t('preprocessDocument', { ns: 'dataset', num: files.length })}
} + {files?.length > 1 &&
{t('preprocessDocument', { ns: 'dataset', num: files.length })}
} {files?.length > 0 ? ( = ({
)}
- - -
+ + ) } export default React.memo(PreviewDocumentPicker) diff --git a/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx b/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx index fcaca86e89..ba058860d3 100644 --- a/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx +++ b/web/app/components/datasets/common/document-status-with-action/__tests__/auto-disabled-document.spec.tsx @@ -1,6 +1,6 @@ +import { toast } from '@langgenius/dify-ui/toast' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { toast } from '@/app/components/base/ui/toast' import { useAutoDisabledDocuments } from '@/service/knowledge/use-document' import AutoDisabledDocument from '../auto-disabled-document' @@ -30,7 +30,7 @@ vi.mock('@/service/knowledge/use-document', () => ({ useInvalidDisabledDocument: vi.fn(() => mockInvalidDisabledDocument), })) -vi.mock('@/app/components/base/ui/toast', () => ({ +vi.mock('@langgenius/dify-ui/toast', () => ({ toast: { success: mockToastSuccess, }, diff --git a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx index c6c7e03bd1..c150ac823f 100644 --- a/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx +++ b/web/app/components/datasets/common/document-status-with-action/auto-disabled-document.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' +import { toast } from '@langgenius/dify-ui/toast' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { toast } from '@/app/components/base/ui/toast' import { useAutoDisabledDocuments, useDocumentEnable, useInvalidDisabledDocument } from '@/service/knowledge/use-document' import StatusWithAction from './status-with-action' diff --git a/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx index 002a2323a7..3c2536fa5e 100644 --- a/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx +++ b/web/app/components/datasets/common/document-status-with-action/status-with-action.tsx @@ -1,9 +1,9 @@ 'use client' import type { FC } from 'react' +import { cn } from '@langgenius/dify-ui/cn' import { RiAlertFill, RiCheckboxCircleFill, RiErrorWarningFill, RiInformation2Fill } from '@remixicon/react' import * as React from 'react' import Divider from '@/app/components/base/divider' -import { cn } from '@/utils/classnames' type Status = 'success' | 'error' | 'warning' | 'info' type Props = { @@ -46,7 +46,7 @@ const StatusAction: FC = ({ }) => { const { Icon, color } = getIcon(type) return ( -
+
 {
   it('should render without crashing', () => {
     const images = createMockImages(3)
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })

   it('should render all images when count is below limit', () => {
@@ -87,7 +87,7 @@ describe('ImageList', () => {
     const images = createMockImages(15)
     render()
     // More button should be visible
-    expect(screen.getByText(/\+6/)).toBeInTheDocument()
+    expect(screen.getByText(/\+6/))!.toBeInTheDocument()
   })
 })
@@ -97,33 +97,33 @@ describe('ImageList', () => {
     const { container } = render(
       ,
     )
-    expect(container.firstChild).toHaveClass('custom-class')
+    expect(container.firstChild)!.toHaveClass('custom-class')
   })

   it('should use default limit of 9', () => {
     const images = createMockImages(12)
     render()
     // Should show "+3" for remaining images
-    expect(screen.getByText(/\+3/)).toBeInTheDocument()
+    expect(screen.getByText(/\+3/))!.toBeInTheDocument()
   })

   it('should respect custom limit', () => {
     const images = createMockImages(10)
     render()
     // Should show "+5" for remaining images
-    expect(screen.getByText(/\+5/)).toBeInTheDocument()
+    expect(screen.getByText(/\+5/))!.toBeInTheDocument()
   })

   it('should handle size prop sm', () => {
     const images = createMockImages(2)
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })

   it('should handle size prop md', () => {
     const images = createMockImages(2)
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })
 })
@@ -146,9 +146,9 @@ describe('ImageList', () => {
     // Find and click an image thumbnail
     const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]')
     if (thumbnails.length > 0) {
-      fireEvent.click(thumbnails[0])
+      fireEvent.click(thumbnails[0]!)
       // Preview should open
-      expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
+      expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument()
     }
   })
@@ -159,12 +159,12 @@ describe('ImageList', () => {
     // Open preview
     const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]')
     if (thumbnails.length > 0) {
-      fireEvent.click(thumbnails[0])
+      fireEvent.click(thumbnails[0]!)
       // Close preview
       const closeButton = screen.getByTestId('close-preview')
       fireEvent.click(closeButton)
       // Preview should be closed
       expect(screen.queryByTestId('image-previewer')).not.toBeInTheDocument()
     }
@@ -174,7 +174,7 @@ describe('Edge Cases', () => {
   it('should handle empty images array', () => {
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })

   it('should not open preview when clicked image not found in list (index === -1)', () => {
@@ -185,7 +185,7 @@ describe('Edge Cases', () => {
     fireEvent.click(firstThumb)
     // Preview should open for valid image
-    expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
+    expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument()
     // Close preview
     fireEvent.click(screen.getByTestId('close-preview'))
@@ -197,7 +197,7 @@ describe('Edge Cases', () => {
     const validThumb = screen.getByTestId('file-thumb-https://example.com/image-1.png')
     fireEvent.click(validThumb)
-    expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
+    expect(screen.getByTestId('image-previewer'))!.toBeInTheDocument()
   })

   it('should return early when file sourceUrl is not found in limitedImages (index === -1)', () => {
@@ -223,7 +223,7 @@ describe('Edge Cases', () => {
   it('should handle single image', () => {
     const images = createMockImages(1)
     const { container } = render()
-    expect(container.firstChild).toBeInTheDocument()
+    expect(container.firstChild)!.toBeInTheDocument()
   })

   it('should not show More button when images count equals limit', () => {
@@ -236,13 +236,13 @@ describe('Edge Cases', () => {
     const images = createMockImages(5)
     render()
     // Should show "+5" for all images
-    expect(screen.getByText(/\+5/)).toBeInTheDocument()
+    expect(screen.getByText(/\+5/))!.toBeInTheDocument()
   })

   it('should handle limit larger than images count', () => {
     const images = createMockImages(5)
     render()
     // Should not show More button
     expect(screen.queryByText(/\+/)).not.toBeInTheDocument()
   })
 })
diff --git a/web/app/components/datasets/common/image-list/index.tsx b/web/app/components/datasets/common/image-list/index.tsx
index 48a990f94d..b04dd5afbd 100644
--- a/web/app/components/datasets/common/image-list/index.tsx
+++ b/web/app/components/datasets/common/image-list/index.tsx
@@ -1,8 +1,8 @@
 import type { ImageInfo } from '../image-previewer'
 import type { FileEntity } from '@/app/components/base/file-thumb'
+import { cn } from '@langgenius/dify-ui/cn'
 import { useCallback, useMemo, useState } from 'react'
 import FileThumb from '@/app/components/base/file-thumb'
-import { cn } from '@/utils/classnames'
 import ImagePreviewer from '../image-previewer'
 import More from './more'
diff --git a/web/app/components/datasets/common/image-list/more.tsx b/web/app/components/datasets/common/image-list/more.tsx
index be6b53a5a5..255e5e4d87 100644
--- a/web/app/components/datasets/common/image-list/more.tsx
+++ b/web/app/components/datasets/common/image-list/more.tsx
@@ -32,7 +32,7 @@ const More = ({ count, onClick }: MoreProps) => {
-
+
   )
 }
diff --git a/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx b/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx
index 44f0e95c08..1ff5cb33f8 100644
--- a/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/common/image-previewer/__tests__/index.spec.tsx
@@ -65,7 +65,7 @@ describe('ImagePreviewer', () => {
     })

     // Should render in portal
-    expect(document.body.querySelector('.image-previewer')).toBeInTheDocument()
+    expect(document.body.querySelector('.image-previewer'))!.toBeInTheDocument()
   })

   it('should render close button', async () => {
@@ -77,7 +77,7 @@ describe('ImagePreviewer', () => {
     })

     // Esc text should be visible
-    expect(screen.getByText('Esc')).toBeInTheDocument()
+    expect(screen.getByText('Esc'))!.toBeInTheDocument()
   })

   it('should show loading state initially', async () => {
@@ -92,7 +92,7 @@ describe('ImagePreviewer', () => {
     })

     // Loading component should be visible
-    expect(document.body.querySelector('.image-previewer')).toBeInTheDocument()
+    expect(document.body.querySelector('.image-previewer'))!.toBeInTheDocument()
   })
 })
@@ -107,7 +107,7 @@ describe('ImagePreviewer', () => {
     await waitFor(() => {
       // Should start at second image
-      expect(screen.getByText('image2.png')).toBeInTheDocument()
+      expect(screen.getByText('image2.png'))!.toBeInTheDocument()
     })
   })
@@ -120,7 +120,7 @@ describe('ImagePreviewer', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })
   })
 })
@@ -151,7 +151,7 @@ describe('ImagePreviewer', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })

     // Find and click next button (right arrow)
@@ -166,7 +166,7 @@ describe('ImagePreviewer', () => {
       })

       await waitFor(() => {
-        expect(screen.getByText('image2.png')).toBeInTheDocument()
+        expect(screen.getByText('image2.png'))!.toBeInTheDocument()
       })
     }
   })
@@ -180,7 +180,7 @@ describe('ImagePreviewer', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('image2.png')).toBeInTheDocument()
+      expect(screen.getByText('image2.png'))!.toBeInTheDocument()
     })

     // Find and click prev button (left arrow)
@@ -195,7 +195,7 @@ describe('ImagePreviewer', () => {
       })

       await waitFor(() => {
-        expect(screen.getByText('image1.png')).toBeInTheDocument()
+        expect(screen.getByText('image1.png'))!.toBeInTheDocument()
       })
     }
   })
@@ -213,7 +213,7 @@ describe('ImagePreviewer', () => {
       btn.className.includes('left-8'),
     )

-    expect(prevButton).toBeDisabled()
+    expect(prevButton)!.toBeDisabled()
   })

   it('should disable next button at last image', async () => {
@@ -229,7 +229,7 @@ describe('ImagePreviewer', () => {
       btn.className.includes('right-8'),
     )

-    expect(nextButton).toBeDisabled()
+    expect(nextButton)!.toBeDisabled()
   })
 })
@@ -258,7 +258,7 @@ describe('ImagePreviewer', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
+      expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument()
     })
   })
@@ -275,7 +275,7 @@ describe('ImagePreviewer', () => {
     await waitFor(() => {
       // Retry button should be visible
       const retryButton = document.querySelector('button.rounded-full')
-      expect(retryButton).toBeInTheDocument()
+      expect(retryButton)!.toBeInTheDocument()
     })
   })
 })
@@ -290,7 +290,7 @@ describe('ImagePreviewer', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })

     const buttons = document.querySelectorAll('button')
@@ -306,7 +306,7 @@ describe('ImagePreviewer', () => {
       // Should still be at first image
       await waitFor(() => {
-        expect(screen.getByText('image1.png')).toBeInTheDocument()
+        expect(screen.getByText('image1.png'))!.toBeInTheDocument()
       })
     }
   })
@@ -320,7 +320,7 @@ describe('ImagePreviewer', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('image3.png')).toBeInTheDocument()
+      expect(screen.getByText('image3.png'))!.toBeInTheDocument()
     })

     const buttons = document.querySelectorAll('button')
@@ -336,7 +336,7 @@ describe('ImagePreviewer', () => {
       // Should still be at last image
       await waitFor(() => {
-        expect(screen.getByText('image3.png')).toBeInTheDocument()
+        expect(screen.getByText('image3.png'))!.toBeInTheDocument()
       })
     }
   })
@@ -366,7 +366,7 @@ describe('ImagePreviewer', () => {
     // Wait for error state
     await waitFor(() => {
-      expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
+      expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument()
     })

     const retryButton = document.querySelector('button.rounded-full')
@@ -393,7 +393,7 @@ describe('ImagePreviewer', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
+      expect(screen.getByText(/Failed to load image/))!.toBeInTheDocument()
     })

     // Find and click the retry button (not the nav buttons)
@@ -402,7 +402,7 @@ describe('ImagePreviewer', () => {
       btn.className.includes('rounded-full') && !btn.className.includes('left-8') && !btn.className.includes('right-8'),
     )
-    expect(retryButton).toBeInTheDocument()
+    expect(retryButton)!.toBeInTheDocument()

     if (retryButton) {
       mockFetch.mockClear()
@@ -451,7 +451,7 @@ describe('ImagePreviewer', () => {
 describe('Edge Cases', () => {
   it('should handle single image', async () => {
     const onClose = vi.fn()
-    const images = [createMockImages()[0]]
+    const images = [createMockImages()[0]!]

     await act(async () => {
       render()
     })
@@ -466,8 +466,8 @@ describe('ImagePreviewer', () => {
       btn.className.includes('right-8'),
     )

-    expect(prevButton).toBeDisabled()
-    expect(nextButton).toBeDisabled()
+    expect(prevButton)!.toBeDisabled()
+    expect(nextButton)!.toBeDisabled()
   })

   it('should stop event propagation on container click', async () => {
@@ -500,7 +500,7 @@ describe('ImagePreviewer', () => {
     await waitFor(() => {
       // Should display dimensions (800 × 600 from MockImage)
-      expect(screen.getByText(/800.*600/)).toBeInTheDocument()
+      expect(screen.getByText(/800.*600/))!.toBeInTheDocument()
     })
   })
@@ -514,7 +514,7 @@ describe('ImagePreviewer', () => {
     await waitFor(() => {
       // Should display formatted file size
-      expect(screen.getByText('image1.png')).toBeInTheDocument()
+      expect(screen.getByText('image1.png'))!.toBeInTheDocument()
     })
   })
 })
diff --git a/web/app/components/datasets/common/image-previewer/index.tsx b/web/app/components/datasets/common/image-previewer/index.tsx
index 1164b4023a..42b4ce33a9 100644
--- a/web/app/components/datasets/common/image-previewer/index.tsx
+++ b/web/app/components/datasets/common/image-previewer/index.tsx
@@ -1,8 +1,8 @@
+import { Button } from '@langgenius/dify-ui/button'
 import { RiArrowLeftLine, RiArrowRightLine, RiCloseLine, RiRefreshLine } from '@remixicon/react'
 import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
 import { createPortal } from 'react-dom'
 import { useHotkeys } from 'react-hotkeys-hook'
-import Button from '@/app/components/base/button'
 import Loading from '@/app/components/base/loading'
 import { formatFileSize } from '@/utils/format'
@@ -137,7 +137,7 @@ const ImagePreviewer = ({
         return {
           ...prev,
           [image.url]: {
-            ...prev[image.url],
+            ...prev[image.url]!,
             status: 'loading',
           },
         }
@@ -155,11 +155,11 @@ const ImagePreviewer = ({
       onClick={e => e.stopPropagation()}
       tabIndex={-1}
     >
-
+
- {cachedImages[currentImage.url].status === 'loading' && ( + {cachedImages[currentImage!.url]!.status === 'loading' && ( )} - {cachedImages[currentImage.url].status === 'error' && ( -
- {`Failed to load image: ${currentImage.url}. Please try again.`} + {cachedImages[currentImage!.url]!.status === 'error' && ( +
+ {`Failed to load image: ${currentImage!.url}. Please try again.`}
)} - {cachedImages[currentImage.url].status === 'loaded' && ( + {cachedImages[currentImage!.url]!.status === 'loaded' && (
{currentImage.name} -
- {currentImage.name} +
+ {currentImage!.name} · - {`${cachedImages[currentImage.url].width} ×  ${cachedImages[currentImage.url].height}`} + {`${cachedImages[currentImage!.url]!.width} ×  ${cachedImages[currentImage!.url]!.height}`} · - {formatFileSize(currentImage.size)} + {formatFileSize(currentImage!.size)}
)} ), })) -vi.mock('@/app/components/base/tooltip', () => ({ - default: ({ popupContent }: { popupContent: React.ReactNode }) => ( -
{popupContent}
-    ),
-}))
-
 describe('RetrievalParamConfig', () => {
   const createDefaultConfig = (overrides?: Partial<RetrievalConfig>): RetrievalConfig => ({
     search_method: RETRIEVE_METHOD.semantic,
@@ -173,7 +167,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('rerank-switch')).toBeInTheDocument()
+    expect(screen.getByTestId('rerank-switch'))!.toBeInTheDocument()
   })

   it('should render model selector when reranking is enabled', () => {
@@ -186,7 +180,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('model-selector')).toBeInTheDocument()
+    expect(screen.getByTestId('model-selector'))!.toBeInTheDocument()
   })

   it('should not render model selector when reranking is disabled', () => {
@@ -212,8 +206,8 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('top-k-item')).toBeInTheDocument()
-    expect(screen.getByTestId('top-k-item')).toHaveAttribute('data-value', '5')
+    expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument()
+    expect(screen.getByTestId('top-k-item'))!.toHaveAttribute('data-value', '5')
   })

   it('should render score threshold item when reranking is enabled', () => {
@@ -226,7 +220,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument()
+    expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument()
   })

   it('should toggle reranking enable', () => {
@@ -349,7 +343,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip')).toBeInTheDocument()
+    expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip'))!.toBeInTheDocument()
   })

   it('should not show multimodal tip when showMultiModalTip is false', () => {
@@ -378,7 +372,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('rerank-switch')).toBeInTheDocument()
+    expect(screen.getByTestId('rerank-switch'))!.toBeInTheDocument()
   })

   it('should hide score threshold when reranking is disabled for full text search', () => {
@@ -410,7 +404,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument()
+    expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument()
   })
 })
@@ -451,7 +445,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('top-k-item')).toBeInTheDocument()
+    expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument()
   })

   it('should not render score threshold for keyword search', () => {
@@ -496,7 +490,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByText('dataset.weightedScore.title')).toBeInTheDocument()
+    expect(screen.getByText('dataset.weightedScore.title'))!.toBeInTheDocument()
   })

   it('should have RerankingModel option', () => {
@@ -508,7 +502,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByText('common.modelProvider.rerankModel.key')).toBeInTheDocument()
+    expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument()
   })

   it('should show model selector when RerankingModel mode is selected', () => {
@@ -520,7 +514,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('model-selector')).toBeInTheDocument()
+    expect(screen.getByTestId('model-selector'))!.toBeInTheDocument()
   })

   it('should show WeightedScore component when WeightedScore mode is selected', () => {
@@ -547,7 +541,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('weighted-score')).toBeInTheDocument()
+    expect(screen.getByTestId('weighted-score'))!.toBeInTheDocument()
     expect(screen.queryByTestId('model-selector')).not.toBeInTheDocument()
   })
@@ -565,7 +559,7 @@ describe('RetrievalParamConfig', () => {
     fireEvent.click(weightedScoreCard!)

     expect(mockOnChange).toHaveBeenCalled()
-    const calledWith = mockOnChange.mock.calls[0][0]
+    const calledWith = mockOnChange.mock.calls[0]![0]
     expect(calledWith.reranking_mode).toBe(RerankingModeEnum.WeightedScore)
     expect(calledWith.weights).toBeDefined()
   })
@@ -645,7 +639,7 @@ describe('RetrievalParamConfig', () => {
     fireEvent.click(screen.getByTestId('change-weights-btn'))

     expect(mockOnChange).toHaveBeenCalled()
-    const calledWith = mockOnChange.mock.calls[0][0]
+    const calledWith = mockOnChange.mock.calls[0]![0]
     expect(calledWith.weights.vector_setting.vector_weight).toBe(0.6)
     expect(calledWith.weights.keyword_setting.keyword_weight).toBe(0.4)
   })
@@ -659,8 +653,8 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('top-k-item')).toBeInTheDocument()
-    expect(screen.getByTestId('score-threshold-item')).toBeInTheDocument()
+    expect(screen.getByTestId('top-k-item'))!.toBeInTheDocument()
+    expect(screen.getByTestId('score-threshold-item'))!.toBeInTheDocument()
   })

   it('should update top_k for hybrid search', () => {
@@ -724,7 +718,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip')).toBeInTheDocument()
+    expect(screen.getByText('datasetSettings.form.retrievalSetting.multiModalTip'))!.toBeInTheDocument()
   })

   it('should not show multimodal tip for hybrid search with WeightedScore', () => {
@@ -799,7 +793,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByTestId('tooltip')).toBeInTheDocument()
+    expect(screen.getByLabelText('common.modelProvider.rerankModel.tip'))!.toBeInTheDocument()
   })
 })
@@ -814,7 +808,7 @@ describe('RetrievalParamConfig', () => {
       />,
     )

-    expect(screen.getByText('common.modelProvider.rerankModel.key')).toBeInTheDocument()
+    expect(screen.getByText('common.modelProvider.rerankModel.key'))!.toBeInTheDocument()
   })
 })
@@ -838,7 +832,7 @@ describe('RetrievalParamConfig', () => {
     fireEvent.click(weightedScoreCard!)

     expect(mockOnChange).toHaveBeenCalled()
-    const calledWith = mockOnChange.mock.calls[0][0]
+    const calledWith = mockOnChange.mock.calls[0]![0]
     expect(calledWith.weights).toBeDefined()
     expect(calledWith.weights.weight_type).toBe(WeightedScoreEnum.Customized)
   })
@@ -872,7 +866,7 @@ describe('RetrievalParamConfig', () => {
     fireEvent.click(weightedScoreCard!)

     expect(mockOnChange).toHaveBeenCalled()
-    const calledWith = mockOnChange.mock.calls[0][0]
+    const calledWith = mockOnChange.mock.calls[0]![0]
     expect(calledWith.weights.vector_setting.vector_weight).toBe(0.8)
   })
 })
diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx
index acb77acfa6..93392f2821 100644
--- a/web/app/components/datasets/common/retrieval-param-config/index.tsx
+++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx
@@ -1,18 +1,19 @@
 'use client'
 import type { FC } from 'react'
 import type { RetrievalConfig } from '@/types/app'
-import * as React from 'react'
+import { cn } from '@langgenius/dify-ui/cn'
+import { Switch } from '@langgenius/dify-ui/switch'
+import { toast } from '@langgenius/dify-ui/toast'
+import * as React from 'react'
 import { useCallback, useMemo } from 'react'
 import { useTranslation } from 'react-i18next'
 import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
 import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
+import { Infotip } from '@/app/components/base/infotip'
 import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item'
 import TopKItem from '@/app/components/base/param-item/top-k-item'
 import RadioCard from '@/app/components/base/radio-card'
-import Switch from '@/app/components/base/switch'
-import Tooltip from '@/app/components/base/tooltip'
-import { toast } from '@/app/components/base/ui/toast'
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
@@ -22,7 +23,6 @@ import {
   WeightedScoreEnum,
 } from '@/models/datasets'
 import { RETRIEVE_METHOD } from '@/types/app'
-import { cn } from '@/utils/classnames'
 import ProgressIndicator from '../../create/assets/progress-indicator.svg'
 import Reranking from '../../create/assets/rerank.svg'
@@ -121,17 +121,18 @@ const RetrievalParamConfig: FC<Props> = ({
             {canToggleRerankModalEnable && (
             )}
- {t('modelProvider.rerankModel.key', { ns: 'common' })} - {t('modelProvider.rerankModel.tip', { ns: 'common' })}
- } - /> + {t('modelProvider.rerankModel.key', { ns: 'common' })} + + {t('modelProvider.rerankModel.tip', { ns: 'common' })} +
{ @@ -152,11 +153,11 @@ const RetrievalParamConfig: FC = ({ /> {showMultiModalTip && (
-
+
- + {t('form.retrievalSetting.multiModalTip', { ns: 'datasetSettings' })}
@@ -246,11 +247,11 @@ const RetrievalParamConfig: FC<Props> = ({
                       ...value.weights!,
                       vector_setting: {
                         ...value.weights!.vector_setting,
-                        vector_weight: v.value[0],
+                        vector_weight: v.value[0]!,
                       },
                       keyword_setting: {
                         ...value.weights!.keyword_setting,
-                        keyword_weight: v.value[1],
+                        keyword_weight: v.value[1]!,
                       },
                     },
                   })
@@ -276,11 +277,11 @@ const RetrievalParamConfig: FC<Props> = ({
           {showMultiModalTip && (
-
+
- + {t('form.retrievalSetting.multiModalTip', { ns: 'datasetSettings' })}
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
index 3a0d5f6d63..b5c73f3422 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx
@@ -52,7 +52,7 @@ const toastMocks = vi.hoisted(() => {
   }
 })

-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
   toast: toastMocks.api,
 }))

@@ -113,7 +113,7 @@ describe('CreateFromDSLModal', () => {
       { wrapper: createWrapper() },
     )

-    expect(screen.getByText('app.importFromDSL')).toBeInTheDocument()
+    expect(screen.getByText('app.importFromDSL'))!.toBeInTheDocument()
   })

   it('should not render modal content when show is false', () => {
@@ -139,8 +139,8 @@ describe('CreateFromDSLModal', () => {
       { wrapper: createWrapper() },
     )

-    expect(screen.getByText('app.importFromDSLFile')).toBeInTheDocument()
-    expect(screen.getByText('app.importFromDSLUrl')).toBeInTheDocument()
+    expect(screen.getByText('app.importFromDSLFile'))!.toBeInTheDocument()
+    expect(screen.getByText('app.importFromDSLUrl'))!.toBeInTheDocument()
   })

   it('should render cancel and import buttons', () => {
@@ -152,8 +152,8 @@ describe('CreateFromDSLModal', () => {
       { wrapper: createWrapper() },
     )

-    expect(screen.getByText('app.newApp.Cancel')).toBeInTheDocument()
-    expect(screen.getByText('app.newApp.import')).toBeInTheDocument()
+    expect(screen.getByText('app.newApp.Cancel'))!.toBeInTheDocument()
+    expect(screen.getByText('app.newApp.import'))!.toBeInTheDocument()
   })

   it('should render uploader when file tab is active', () => {
@@ -166,7 +166,7 @@ describe('CreateFromDSLModal', () => {
       { wrapper: createWrapper() },
     )

-    expect(screen.getByText('app.dslUploader.button')).toBeInTheDocument()
+    expect(screen.getByText('app.dslUploader.button'))!.toBeInTheDocument()
   })

   it('should render URL input when URL tab is active', () => {
@@ -179,8 +179,8 @@ describe('CreateFromDSLModal', () => {
       { wrapper: createWrapper() },
     )

-    expect(screen.getByText('DSL URL')).toBeInTheDocument()
-    expect(screen.getByPlaceholderText('app.importFromDSLUrlPlaceholder')).toBeInTheDocument()
+    expect(screen.getByText('DSL URL'))!.toBeInTheDocument()
+    expect(screen.getByPlaceholderText('app.importFromDSLUrlPlaceholder'))!.toBeInTheDocument()
   })
 })
@@ -195,7 +195,7 @@ describe('CreateFromDSLModal', () => {
       { wrapper: createWrapper() },
     )

     // File tab content should be visible
-    expect(screen.getByText('app.dslUploader.button')).toBeInTheDocument()
+    expect(screen.getByText('app.dslUploader.button'))!.toBeInTheDocument()
   })

   it('should use provided activeTab prop', () => {
@@ -208,7 +208,7 @@ describe('CreateFromDSLModal', () => {
       { wrapper: createWrapper() },
     )

-    expect(screen.getByText('DSL URL')).toBeInTheDocument()
+    expect(screen.getByText('DSL URL'))!.toBeInTheDocument()
   })

   it('should use provided dslUrl prop', () => {
@@ -223,7 +223,7 @@ describe('CreateFromDSLModal', () => {
     )

     const input = screen.getByPlaceholderText('app.importFromDSLUrlPlaceholder')
-    expect(input).toHaveValue('https://example.com/test.pipeline')
+    expect(input)!.toHaveValue('https://example.com/test.pipeline')
   })

   it('should call onClose when cancel button is clicked', () => {
@@ -252,12 +252,12 @@ describe('CreateFromDSLModal', () => {
     )

     // Initially file tab is active
-    expect(screen.getByText('app.dslUploader.button')).toBeInTheDocument()
+    expect(screen.getByText('app.dslUploader.button'))!.toBeInTheDocument()

     fireEvent.click(screen.getByText('app.importFromDSLUrl'))

     // URL input should be visible
-    expect(screen.getByText('DSL URL')).toBeInTheDocument()
+    expect(screen.getByText('DSL URL'))!.toBeInTheDocument()
   })

   it('should update URL value when typing', () => {
@@ -273,7 +273,7 @@ describe('CreateFromDSLModal', () => {
     const input = screen.getByPlaceholderText('app.importFromDSLUrlPlaceholder')
     fireEvent.change(input, { target: { value: 'https://example.com/test.pipeline' } })

-    expect(input).toHaveValue('https://example.com/test.pipeline')
+    expect(input)!.toHaveValue('https://example.com/test.pipeline')
   })

   it('should have disabled import button when no file is selected in file tab', () => {
@@ -287,7 +287,7 @@ describe('CreateFromDSLModal', () => {
     )

     const importButton = screen.getByText('app.newApp.import').closest('button')
-    expect(importButton).toBeDisabled()
+    expect(importButton)!.toBeDisabled()
   })

   it('should have disabled import button when no URL is entered in URL tab', () => {
@@ -301,7 +301,7 @@ describe('CreateFromDSLModal', () => {
     )

     const importButton = screen.getByText('app.newApp.import').closest('button')
-    expect(importButton).toBeDisabled()
+    expect(importButton)!.toBeDisabled()
   })

   it('should enable import button when URL is entered', () => {
@@ -445,7 +445,7 @@ describe('CreateFromDSLModal', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('app.newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
+      expect(screen.getByText('app.newApp.appCreateDSLErrorTitle'))!.toBeInTheDocument()
     })

     vi.useRealTimers()
@@ -663,14 +663,14 @@ describe('CreateFromDSLModal', () => {
     // File tab with no file - disabled
     let importButton = screen.getByText('app.newApp.import').closest('button')
-    expect(importButton).toBeDisabled()
+    expect(importButton)!.toBeDisabled()

     // Switch to URL tab by clicking on it
     fireEvent.click(screen.getByText('app.importFromDSLUrl'))

     // Still disabled (no URL)
     importButton = screen.getByText('app.newApp.import').closest('button')
-    expect(importButton).toBeDisabled()
+    expect(importButton)!.toBeDisabled()

     // Add URL value - should enable
     const input = screen.getByPlaceholderText('app.importFromDSLUrlPlaceholder')
@@ -902,7 +902,7 @@ describe('CreateFromDSLModal', () => {
     // Wait for file to be displayed
     await waitFor(() => {
-      expect(screen.getByText('test.pipeline')).toBeInTheDocument()
+      expect(screen.getByText('test.pipeline'))!.toBeInTheDocument()
     })

     // Now remove the file by clicking delete button (inside ActionButton)
@@ -912,7 +912,7 @@ describe('CreateFromDSLModal', () => {
       fireEvent.click(deleteButton)
       // File should be removed - uploader prompt should show again
       await waitFor(() => {
-        expect(screen.getByText('app.dslUploader.button')).toBeInTheDocument()
+        expect(screen.getByText('app.dslUploader.button'))!.toBeInTheDocument()
       })
     }
   })
@@ -968,7 +968,7 @@ describe('CreateFromDSLModal', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('app.newApp.Confirm')).toBeInTheDocument()
+      expect(screen.getByText('app.newApp.Confirm'))!.toBeInTheDocument()
     })

     fireEvent.click(screen.getByText('app.newApp.Confirm'))
@@ -1016,7 +1016,7 @@ describe('CreateFromDSLModal', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('app.newApp.Confirm')).toBeInTheDocument()
+      expect(screen.getByText('app.newApp.Confirm'))!.toBeInTheDocument()
     })

     fireEvent.click(screen.getByText('app.newApp.Confirm'))
@@ -1056,7 +1056,7 @@ describe('CreateFromDSLModal', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('app.newApp.Confirm')).toBeInTheDocument()
+      expect(screen.getByText('app.newApp.Confirm'))!.toBeInTheDocument()
     })

     fireEvent.click(screen.getByText('app.newApp.Confirm'))
@@ -1103,7 +1103,7 @@ describe('CreateFromDSLModal', () => {
     })

     await waitFor(() => {
-      expect(screen.getByText('app.newApp.Confirm')).toBeInTheDocument()
+      expect(screen.getByText('app.newApp.Confirm'))!.toBeInTheDocument()
     })

     fireEvent.click(screen.getByText('app.newApp.Confirm'))
@@ -1144,13 +1144,13 @@ describe('CreateFromDSLModal', () => {
     // Error modal should be visible
     await waitFor(() => {
-      expect(screen.getByText('app.newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
+      expect(screen.getByText('app.newApp.appCreateDSLErrorTitle'))!.toBeInTheDocument()
     })

     // There are two Cancel buttons now (one in main modal footer, one in error modal)
     // Find the Cancel button in the error modal context
     const cancelButtons = screen.getAllByText('app.newApp.Cancel')
-    fireEvent.click(cancelButtons[cancelButtons.length - 1])
+    fireEvent.click(cancelButtons[cancelButtons.length - 1]!)

     vi.useRealTimers()
   })
@@ -1166,14 +1166,14 @@ describe('Header', () => {
   describe('Rendering', () => {
     it('should render title', () => {
       render(
       )
-      expect(screen.getByText('app.importFromDSL')).toBeInTheDocument()
+      expect(screen.getByText('app.importFromDSL'))!.toBeInTheDocument()
     })

     it('should render close icon', () => {
       render(
       )
       // Check for close icon container
       const closeButton = document.querySelector('[class*="cursor-pointer"]')
-      expect(closeButton).toBeInTheDocument()
+      expect(closeButton)!.toBeInTheDocument()
     })
   })
@@ -1205,8 +1205,8 @@ describe('Tab', () => {
       />,
     )

-    expect(screen.getByText('app.importFromDSLFile')).toBeInTheDocument()
-    expect(screen.getByText('app.importFromDSLUrl')).toBeInTheDocument()
+    expect(screen.getByText('app.importFromDSLFile'))!.toBeInTheDocument()
+    expect(screen.getByText('app.importFromDSLUrl'))!.toBeInTheDocument()
   })
 })
@@ -1223,7 +1223,7 @@ describe('Tab', () => {
     fireEvent.click(screen.getByText('app.importFromDSLFile'))
     // Tab uses bind() which passes the key as first argument and event as second
     expect(setCurrentTab).toHaveBeenCalled()
-    expect(setCurrentTab.mock.calls[0][0]).toBe(CreateFromDSLModalTab.FROM_FILE)
+    expect(setCurrentTab.mock.calls[0]![0]).toBe(CreateFromDSLModalTab.FROM_FILE)
   })

   it('should call setCurrentTab when clicking URL tab', () => {
@@ -1238,7 +1238,7 @@ describe('Tab', () => {
     fireEvent.click(screen.getByText('app.importFromDSLUrl'))
     // Tab uses bind() which passes the key as first argument and event as second
     expect(setCurrentTab).toHaveBeenCalled()
-    expect(setCurrentTab.mock.calls[0][0]).toBe(CreateFromDSLModalTab.FROM_URL)
+    expect(setCurrentTab.mock.calls[0]![0]).toBe(CreateFromDSLModalTab.FROM_URL)
   })
 })
@@ -1259,7 +1259,7 @@ describe('TabItem', () => {
       />,
     )

-    expect(screen.getByText('Test Tab')).toBeInTheDocument()
+    expect(screen.getByText('Test Tab'))!.toBeInTheDocument()
   })

   it('should render active indicator when active', () => {
@@ -1273,7 +1273,7 @@ describe('TabItem', () => {
     // Active indicator is the bottom border div
     const indicator = document.querySelector('[class*="bg-util-colors-blue"]')
-    expect(indicator).toBeInTheDocument()
+    expect(indicator)!.toBeInTheDocument()
   })

   it('should not render active indicator when inactive', () => {
@@ -1348,8 +1348,8 @@ describe('Uploader', () => {
       />,
     )

-    expect(screen.getByText('app.dslUploader.button')).toBeInTheDocument()
-    expect(screen.getByText('app.dslUploader.browse')).toBeInTheDocument()
+    expect(screen.getByText('app.dslUploader.button'))!.toBeInTheDocument()
+    expect(screen.getByText('app.dslUploader.browse'))!.toBeInTheDocument()
   })

   it('should render file info when file is selected', () => {
@@ -1362,8 +1362,8 @@ describe('Uploader', () => {
       />,
     )

-    expect(screen.getByText('test.pipeline')).toBeInTheDocument()
-    expect(screen.getByText('PIPELINE')).toBeInTheDocument()
+    expect(screen.getByText('test.pipeline'))!.toBeInTheDocument()
+    expect(screen.getByText('PIPELINE'))!.toBeInTheDocument()
   })

   it('should apply custom className', () => {
@@ -1375,7 +1375,7 @@ describe('Uploader', () => {
       />,
     )

-    expect(container.firstChild).toHaveClass('custom-class')
+    expect(container.firstChild)!.toHaveClass('custom-class')
   })
 })
@@ -1484,7 +1484,7 @@ describe('Uploader', () => {
     dropArea.dispatchEvent(dragOverEvent)

     // Event should be handled without errors
-    expect(dropArea).toBeInTheDocument()
+    expect(dropArea)!.toBeInTheDocument()
   })

   it('should handle dragLeave event and reset dragging state when target is dragRef', async () => {
@@ -1512,8 +1512,8 @@ describe('Uploader', () => {
     })

     // The dragRef div appears when dragging is true
-    const dragRefDiv = document.querySelector('[class*="absolute left-0 top-0"]')
-    expect(dragRefDiv).toBeInTheDocument()
+    const dragRefDiv = document.querySelector('[class*="absolute top-0 left-0"]')
+    expect(dragRefDiv)!.toBeInTheDocument()

     // When dragLeave happens on the dragRef element, setDragging(false) is called
     if (dragRefDiv) {
@@ -1648,7 +1648,7 @@ describe('Uploader', () => {
     // Get the file input
     const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement
-    expect(fileInput).toBeInTheDocument()
+    expect(fileInput)!.toBeInTheDocument()

     // Spy on input click before triggering selectHandle
     const clickSpy = vi.spyOn(fileInput, 'click').mockImplementation(() => {
     })
@@ -1688,7 +1688,7 @@ describe('Uploader', () => {
     })

     // Now the dragRef div should exist
-    const dragRefDiv = document.querySelector('[class*="absolute left-0 top-0"]')
+    const dragRefDiv = document.querySelector('[class*="absolute top-0 left-0"]')

     // When dragEnter happens on dragRef itself, setDragging should NOT be called
     if (dragRefDiv) {
@@ -1715,7 +1715,7 @@ describe('Uploader', () => {
     // Find and click delete button
     const deleteButton = document.querySelector('button')
-    expect(deleteButton).toBeInTheDocument()
+    expect(deleteButton)!.toBeInTheDocument()

     if (deleteButton) {
       fireEvent.click(deleteButton)
@@ -1745,7 +1745,7 @@ describe('DSLConfirmModal', () => {
       />,
     )

-    expect(screen.getByText('app.newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
+    expect(screen.getByText('app.newApp.appCreateDSLErrorTitle'))!.toBeInTheDocument()
   })

   it('should render version information', () => {
@@ -1760,8 +1760,8 @@ describe('DSLConfirmModal', () => {
       />,
     )

-    expect(screen.getByText('0.9.0')).toBeInTheDocument()
-    expect(screen.getByText('1.0.0')).toBeInTheDocument()
+    expect(screen.getByText('0.9.0'))!.toBeInTheDocument()
+    expect(screen.getByText('1.0.0'))!.toBeInTheDocument()
   })

   it('should render cancel and confirm buttons', () => {
@@ -1772,8 +1772,8 @@ describe('DSLConfirmModal', () => {
       />,
     )

-    expect(screen.getByText('app.newApp.Cancel')).toBeInTheDocument()
-    expect(screen.getByText('app.newApp.Confirm')).toBeInTheDocument()
+    expect(screen.getByText('app.newApp.Cancel'))!.toBeInTheDocument()
+    expect(screen.getByText('app.newApp.Confirm'))!.toBeInTheDocument()
   })

   it('should render with default empty versions', () => {
@@ -1785,7 +1785,7 @@ describe('DSLConfirmModal', () => {
     )

     // Should not crash with default empty strings
-    expect(screen.getByText('app.newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
+    expect(screen.getByText('app.newApp.appCreateDSLErrorTitle'))!.toBeInTheDocument()
   })

   it('should disable confirm button when confirmDisabled is true', () => {
@@ -1798,7 +1798,7 @@ describe('DSLConfirmModal', () => {
     )

     const confirmButton = screen.getByText('app.newApp.Confirm').closest('button')
-    expect(confirmButton).toBeDisabled()
+    expect(confirmButton)!.toBeDisabled()
   })
 })
@@ -1879,7 +1879,7 @@ describe('DSLConfirmModal', () => {
     )

     // Component should render without crashing
-    expect(screen.getByText('app.newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
+    expect(screen.getByText('app.newApp.appCreateDSLErrorTitle'))!.toBeInTheDocument()
   })

   it('should use default confirmDisabled when not provided', () => {
@@ -1988,7 +1988,7 @@ describe('CreateFromDSLModal Integration', () => {
     // Verify error modal is shown
     await waitFor(() => {
-      expect(screen.getByText('app.newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
+      expect(screen.getByText('app.newApp.appCreateDSLErrorTitle'))!.toBeInTheDocument()
     })

     vi.useRealTimers()
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
index f553f491e5..f4263ab439 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/uploader.spec.tsx
@@ -16,7 +16,7 @@ const { mockToast } = vi.hoisted(() => {
   return { mockToast }
 })

-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
   toast: mockToast,
 }))
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx
index db675228c4..d0c97e185c 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx
@@ -1,5 +1,5 @@
+import { Button } from '@langgenius/dify-ui/button'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Modal from '@/app/components/base/modal'

 type DSLConfirmModalProps = {
@@ -27,7 +27,7 @@ const DSLConfirmModal = ({
     >
{t('newApp.appCreateDSLErrorTitle', { ns: 'app' })}
-
+
{t('newApp.appCreateDSLErrorPart1', { ns: 'app' })}
{t('newApp.appCreateDSLErrorPart2', { ns: 'app' })}

@@ -43,7 +43,7 @@ const DSLConfirmModal = ({
- +
) diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx index 5e4cb88c27..cf759880be 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx +++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx @@ -12,10 +12,10 @@ const Header = ({ const { t } = useTranslation() return ( -
+
{t('importFromDSL', { ns: 'app' })}
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx
index 2a0f7f2cf0..b95fa4d5db 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/__tests__/use-dsl-import.spec.tsx
@@ -49,7 +49,7 @@ const toastMocks = vi.hoisted(() => {
   }
 })

-vi.mock('@/app/components/base/ui/toast', () => ({
+vi.mock('@langgenius/dify-ui/toast', () => ({
   toast: toastMocks.api,
 }))
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
index 992d0526be..e085ffe1bc 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts
@@ -1,8 +1,8 @@
 'use client'
+import { toast } from '@langgenius/dify-ui/toast'
 import { useDebounceFn } from 'ahooks'
 import { useCallback, useMemo, useRef, useState } from 'react'
 import { useTranslation } from 'react-i18next'
-import { toast } from '@/app/components/base/ui/toast'
 import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks'
 import { DSLImportMode, DSLImportStatus } from '@/models/app'
 import { useRouter } from '@/next/navigation'
diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx
index 079ea90687..4f5b5c23fd 100644
--- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx
+++ b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx
@@ -1,8 +1,8 @@
 'use client'
+import { Button } from '@langgenius/dify-ui/button'
 import { useKeyPress } from 'ahooks'
 import { noop } from 'es-toolkit/function'
 import { useTranslation } from 'react-i18next'
-import Button from '@/app/components/base/button'
 import Input from '@/app/components/base/input'
 import Modal from '@/app/components/base/modal'
 import DSLConfirmModal from './dsl-confirm-modal'
@@ -78,7 +78,7 @@ const CreateFromDSLModal = ({
         )}
         {currentTab === CreateFromDSLModalTab.FROM_URL && (
-
+
DSL URL
+
{ tabs.map(tab => ( = ({ file, updateFile, className }) => {
- {dragging &&
} + {dragging &&
}
)} {file && ( @@ -104,10 +104,10 @@ const Uploader: FC = ({ file, updateFile, className }) => {
- + {file.name} -
+
PIPELINE · {formatFileSize(file.size)} diff --git a/web/app/components/datasets/create-from-pipeline/footer.tsx b/web/app/components/datasets/create-from-pipeline/footer.tsx index ae1bb48394..dee48163d2 100644 --- a/web/app/components/datasets/create-from-pipeline/footer.tsx +++ b/web/app/components/datasets/create-from-pipeline/footer.tsx @@ -39,11 +39,11 @@ const Footer = () => { }, [invalidDatasetList]) return ( -
+
diff --git a/web/app/components/datasets/create-from-pipeline/index.tsx b/web/app/components/datasets/create-from-pipeline/index.tsx index 0c788f613d..f5b5c2ec8b 100644 --- a/web/app/components/datasets/create-from-pipeline/index.tsx +++ b/web/app/components/datasets/create-from-pipeline/index.tsx @@ -9,7 +9,7 @@ const CreateFromPipeline = () => {
- +