diff --git a/.agents/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md index 98a94592ab..a7cae67e8f 100644 --- a/.agents/skills/component-refactoring/SKILL.md +++ b/.agents/skills/component-refactoring/SKILL.md @@ -63,7 +63,7 @@ pnpm analyze-component --json ```typescript // ❌ Before: Complex state logic in component -const Configuration: FC = () => { +function Configuration() { const [modelConfig, setModelConfig] = useState(...) const [datasetConfigs, setDatasetConfigs] = useState(...) const [completionParams, setCompletionParams] = useState({}) @@ -85,7 +85,7 @@ export const useModelConfig = (appId: string) => { } // Component becomes cleaner -const Configuration: FC = () => { +function Configuration() { const { modelConfig, setModelConfig } = useModelConfig(appId) return
...
} @@ -189,8 +189,6 @@ const Template = useMemo(() => { **Dify Convention**: - This skill is for component decomposition, not query/mutation design. -- When refactoring data fetching, follow `web/AGENTS.md`. -- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling. - Do not introduce deprecated `useInvalid` / `useReset`. - Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state. diff --git a/.agents/skills/component-refactoring/references/complexity-patterns.md b/.agents/skills/component-refactoring/references/complexity-patterns.md index 5a0a268f38..2873630d4b 100644 --- a/.agents/skills/component-refactoring/references/complexity-patterns.md +++ b/.agents/skills/component-refactoring/references/complexity-patterns.md @@ -60,8 +60,10 @@ const Template = useMemo(() => { **After** (complexity: ~3): ```typescript +import type { ComponentType } from 'react' + // Define lookup table outside component -const TEMPLATE_MAP: Record>> = { +const TEMPLATE_MAP: Record>> = { [AppModeEnum.CHAT]: { [LanguagesSupported[1]]: TemplateChatZh, [LanguagesSupported[7]]: TemplateChatJa, diff --git a/.agents/skills/component-refactoring/references/component-splitting.md b/.agents/skills/component-refactoring/references/component-splitting.md index 78a3389100..81c007e005 100644 --- a/.agents/skills/component-refactoring/references/component-splitting.md +++ b/.agents/skills/component-refactoring/references/component-splitting.md @@ -65,10 +65,10 @@ interface ConfigurationHeaderProps { onPublish: () => void } -const ConfigurationHeader: FC = ({ +function ConfigurationHeader({ isAdvancedMode, onPublish, -}) => { +}: ConfigurationHeaderProps) { const { t } = useTranslation() return ( @@ -136,7 +136,7 @@ const AppInfo = () => { } // ✅ After: Separate view components -const AppInfoExpanded: FC = ({ appDetail, onAction }) => { +function AppInfoExpanded({ appDetail, onAction }: AppInfoViewProps) { return (
{/* Clean, focused expanded view */} @@ -144,7 +144,7 @@ const AppInfoExpanded: FC = ({ appDetail, onAction }) => { ) } -const AppInfoCollapsed: FC = ({ appDetail, onAction }) => { +function AppInfoCollapsed({ appDetail, onAction }: AppInfoViewProps) { return (
{/* Clean, focused collapsed view */} @@ -203,12 +203,12 @@ interface AppInfoModalsProps { onSuccess: () => void } -const AppInfoModals: FC = ({ +function AppInfoModals({ appDetail, activeModal, onClose, onSuccess, -}) => { +}: AppInfoModalsProps) { const handleEdit = async (data) => { /* logic */ } const handleDuplicate = async (data) => { /* logic */ } const handleDelete = async () => { /* logic */ } @@ -296,7 +296,7 @@ interface OperationItemProps { onAction: (id: string) => void } -const OperationItem: FC = ({ operation, onAction }) => { +function OperationItem({ operation, onAction }: OperationItemProps) { return (
{operation.icon} @@ -435,7 +435,7 @@ interface ChildProps { onSubmit: () => void } -const Child: FC = ({ value, onChange, onSubmit }) => { +function Child({ value, onChange, onSubmit }: ChildProps) { return (
onChange(e.target.value)} /> diff --git a/.agents/skills/component-refactoring/references/hook-extraction.md b/.agents/skills/component-refactoring/references/hook-extraction.md index 0d567eb2a6..6fad2c8885 100644 --- a/.agents/skills/component-refactoring/references/hook-extraction.md +++ b/.agents/skills/component-refactoring/references/hook-extraction.md @@ -112,13 +112,13 @@ export const useModelConfig = ({ ```typescript // Before: 50+ lines of state management -const Configuration: FC = () => { +function Configuration() { const [modelConfig, setModelConfig] = useState(...) // ... lots of related state and effects } // After: Clean component -const Configuration: FC = () => { +function Configuration() { const { modelConfig, setModelConfig, @@ -159,8 +159,6 @@ const Configuration: FC = () => { When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns. -- Follow `web/AGENTS.md` first. -- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling. - Do not introduce deprecated `useInvalid` / `useReset`. - Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks. diff --git a/.agents/skills/e2e-cucumber-playwright/SKILL.md b/.agents/skills/e2e-cucumber-playwright/SKILL.md index de6b58f26d..dd7d204678 100644 --- a/.agents/skills/e2e-cucumber-playwright/SKILL.md +++ b/.agents/skills/e2e-cucumber-playwright/SKILL.md @@ -23,7 +23,7 @@ Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS - `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter 3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved. 4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved. -5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern. +5. Re-check official Playwright or Cucumber docs with the available documentation tools before introducing a new framework pattern. ## Local Rules diff --git a/.agents/skills/frontend-code-review/references/performance.md b/.agents/skills/frontend-code-review/references/performance.md index 2d60072f5c..0c33db46d0 100644 --- a/.agents/skills/frontend-code-review/references/performance.md +++ b/.agents/skills/frontend-code-review/references/performance.md @@ -9,18 +9,18 @@ Category: Performance When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks. -## Complex prop memoization +## Complex prop stability -IsUrgent: True +IsUrgent: False Category: Performance ### Description -Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders. +Only require stable object, array, or map props when there is a clear reason: the child is memoized, the value participates in effect/query dependencies, the value is part of a stable-reference API contract, or profiling/local behavior shows avoidable re-renders. Do not request `useMemo` for every inline object by default; `how-to-write-component` treats memoization as a targeted optimization. Update this file when adding, editing, or removing Performance rules so the catalog remains accurate. -Wrong: +Risky: ```tsx ``` -Right: +Better when stable identity matters: ```tsx const config = useMemo(() => ({ diff --git a/.agents/skills/frontend-query-mutation/SKILL.md b/.agents/skills/frontend-query-mutation/SKILL.md deleted file mode 100644 index 10c49d222e..0000000000 --- a/.agents/skills/frontend-query-mutation/SKILL.md +++ /dev/null @@ -1,46 +0,0 @@ ---- -name: frontend-query-mutation -description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers. ---- - -# Frontend Query & Mutation - -## Intent - -- Keep contract as the single source of truth in `web/contract/*`. -- Prefer contract-shaped `queryOptions()` and `mutationOptions()`. -- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough. -- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination. -- Keep abstractions minimal to preserve TypeScript inference. - -## Workflow - -1. Identify the change surface. - - Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape. - - Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations. - - Read both references when a task spans contract shape and runtime behavior. -2. Implement the smallest abstraction that fits the task. - - Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site. - - Extract a small shared query helper only when multiple call sites share the same extra options. - - Create or keep feature hooks only for real orchestration or shared domain behavior. - - When touching thin `web/service/use-*` wrappers, migrate them away when feasible. -3. Preserve Dify conventions. - - Keep contract inputs in `{ params, query?, body? }` shape. - - Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior. - - Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required. - -## Files Commonly Touched - -- `web/contract/console/*.ts` -- `web/contract/marketplace.ts` -- `web/contract/router.ts` -- `web/service/client.ts` -- legacy `web/service/use-*.ts` files when migrating wrappers away -- component and hook call sites using `consoleQuery` or `marketplaceQuery` - -## References - -- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference. -- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules. - -Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs. diff --git a/.agents/skills/frontend-query-mutation/agents/openai.yaml b/.agents/skills/frontend-query-mutation/agents/openai.yaml deleted file mode 100644 index 79e7e7d214..0000000000 --- a/.agents/skills/frontend-query-mutation/agents/openai.yaml +++ /dev/null @@ -1,4 +0,0 @@ -interface: - display_name: "Frontend Query & Mutation" - short_description: "Dify TanStack Query, oRPC, and default option patterns" - default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations." diff --git a/.agents/skills/frontend-query-mutation/references/contract-patterns.md b/.agents/skills/frontend-query-mutation/references/contract-patterns.md deleted file mode 100644 index 25ccfc81d7..0000000000 --- a/.agents/skills/frontend-query-mutation/references/contract-patterns.md +++ /dev/null @@ -1,129 +0,0 @@ -# Contract Patterns - -## Table of Contents - -- Intent -- Minimal structure -- Core workflow -- Query usage decision rule -- Mutation usage decision rule -- Thin hook decision rule -- Anti-patterns -- Contract rules -- Type export - -## Intent - -- Keep contract as the single source of truth in `web/contract/*`. -- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract. -- Keep abstractions minimal and preserve TypeScript inference. - -## Minimal Structure - -```text -web/contract/ -├── base.ts -├── router.ts -├── marketplace.ts -└── console/ - ├── billing.ts - └── ...other domains -web/service/client.ts -``` - -## Core Workflow - -1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`. - - Use `base.route({...}).output(type<...>())` as the baseline. - - Add `.input(type<...>())` only when the request has `params`, `query`, or `body`. - - For `GET` without input, omit `.input(...)`; do not use `.input(type())`. -2. Register contract in `web/contract/router.ts`. - - Import directly from domain files and nest by API prefix. -3. Consume from UI call sites via oRPC query utilities. - -```typescript -import { useQuery } from '@tanstack/react-query' -import { consoleQuery } from '@/service/client' - -const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({ - staleTime: 5 * 60 * 1000, - throwOnError: true, - select: invoice => invoice.url, -})) -``` - -## Query Usage Decision Rule - -1. Default to direct `*.queryOptions(...)` usage at the call site. -2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook. -3. Create or keep feature hooks only for orchestration. - - Combine multiple queries or mutations. - - Share domain-level derived state or invalidation helpers. - - Prefer `web/features/{domain}/hooks/*` for feature-owned workflows. -4. Treat `web/service/use-{domain}.ts` as legacy. - - Do not create new thin service wrappers for oRPC contracts. - - When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough. - -```typescript -const invoicesBaseQueryOptions = () => - consoleQuery.billing.invoices.queryOptions({ retry: false }) - -const invoiceQuery = useQuery({ - ...invoicesBaseQueryOptions(), - throwOnError: true, -}) -``` - -## Mutation Usage Decision Rule - -1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`. -2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic. - -```typescript -const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions()) -``` - -## Thin Hook Decision Rule - -Remove thin hooks when they only rename a single oRPC query or mutation helper. -Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API. -Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`. - -Use: - -```typescript -const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions()) -``` - -Keep: - -```typescript -const applyTagBindingsMutation = useApplyTagBindingsMutation() -``` - -`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough. - -## Anti-Patterns - -- Do not wrap `useQuery` with `options?: Partial`. -- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case. -- Do not create thin `use-*` passthrough hooks for a single endpoint. -- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`. -- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs. -- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection. - -## Contract Rules - -- Input structure: always use `{ params, query?, body? }`. -- No-input `GET`: omit `.input(...)`; do not use `.input(type())`. -- Path params: use `{paramName}` in the path and match it in the `params` object. -- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`. -- No barrel files: import directly from specific files. -- Types: import from `@/types/` and use the `type()` helper. -- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools. - -## Type Export - -```typescript -export type ConsoleInputs = InferContractRouterInputs -``` diff --git a/.agents/skills/frontend-query-mutation/references/runtime-rules.md b/.agents/skills/frontend-query-mutation/references/runtime-rules.md deleted file mode 100644 index 91b484d438..0000000000 --- a/.agents/skills/frontend-query-mutation/references/runtime-rules.md +++ /dev/null @@ -1,172 +0,0 @@ -# Runtime Rules - -## Table of Contents - -- Conditional queries -- oRPC default options -- Cache invalidation -- Key API guide -- `mutate` vs `mutateAsync` -- Legacy migration - -## Conditional Queries - -Prefer contract-shaped `queryOptions(...)`. -When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions. -Use `enabled` only for extra business gating after the input itself is already valid. - -```typescript -import { skipToken, useQuery } from '@tanstack/react-query' - -// Disable the query by skipping input construction. -function useAccessMode(appId: string | undefined) { - return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({ - input: appId - ? { params: { appId } } - : skipToken, - })) -} - -// Avoid runtime-only guards that bypass type checking. -function useBadAccessMode(appId: string | undefined) { - return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({ - input: { params: { appId: appId! } }, - enabled: !!appId, - })) -} -``` - -## oRPC Default Options - -Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation. - -Place defaults at the query utility creation point in `web/service/client.ts`: - -```typescript -export const consoleQuery = createTanstackQueryUtils(consoleClient, { - path: ['console'], - experimental_defaults: { - tags: { - create: { - mutationOptions: { - onSuccess: (tag, _variables, _result, context) => { - context.client.setQueryData( - consoleQuery.tags.list.queryKey({ - input: { - query: { - type: tag.type, - }, - }, - }), - (oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags, - ) - }, - }, - }, - }, - }, -}) -``` - -Rules: - -- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders. -- Do not create a wrapper function solely to host `createTanstackQueryUtils`. -- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`. -- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility. -- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation. - -## Cache Invalidation - -Bind shared invalidation in oRPC defaults when it is tied to a contract operation. -Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default. -Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate. - -Use: - -- `.key()` for namespace or prefix invalidation -- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData` -- `queryClient.invalidateQueries(...)` in mutation `onSuccess` - -Do not use deprecated `useInvalid` from `use-base.ts`. - -```typescript -// Feature orchestration owns cache invalidation only when defaults are not enough. -export const useUpdateAccessMode = () => { - const queryClient = useQueryClient() - - return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({ - onSuccess: () => { - queryClient.invalidateQueries({ - queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(), - }) - }, - })) -} - -// Component only adds UI behavior. -updateAccessMode({ appId, mode }, { - onSuccess: () => toast.success('...'), -}) - -// Avoid putting invalidation knowledge in the component. -mutate({ appId, mode }, { - onSuccess: () => { - queryClient.invalidateQueries({ - queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(), - }) - }, -}) -``` - -## Key API Guide - -- `.key(...)` - - Use for partial matching operations. - - Prefer it for invalidation, refetch, and cancel patterns. - - Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })` -- `.queryKey(...)` - - Use for a specific query's full key. - - Prefer it for exact cache addressing and direct reads or writes. -- `.mutationKey(...)` - - Use for a specific mutation's full key. - - Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping. - -## `mutate` vs `mutateAsync` - -Prefer `mutate` by default. -Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies. - -Rules: - -- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`. -- Every `await mutateAsync(...)` must be wrapped in `try/catch`. -- Do not use `mutateAsync` when callbacks already express the flow clearly. - -```typescript -// Default case. -mutation.mutate(data, { - onSuccess: result => router.push(result.url), -}) - -// Promise semantics are required. -try { - const order = await createOrder.mutateAsync(orderData) - await confirmPayment.mutateAsync({ orderId: order.id, token }) - router.push(`/orders/${order.id}`) -} -catch (error) { - toast.error(error instanceof Error ? error.message : 'Unknown error') -} -``` - -## Legacy Migration - -When touching old code, migrate it toward these rules: - -| Old pattern | New pattern | -|---|---| -| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration | -| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook | -| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` | -| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` | diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 105c979c58..21c46d75bc 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -5,7 +5,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com # Dify Frontend Testing Skill -This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. +This skill enables Codex to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices. > **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`). @@ -24,35 +24,27 @@ Apply this skill when the user: **Do NOT apply** when: - User is asking about backend/API tests (Python/pytest) -- User is asking about E2E tests (Playwright/Cypress) +- User is asking about E2E tests (Cucumber + Playwright under `e2e/`) - User is only asking conceptual questions without code context ## Quick Reference -### Tech Stack - -| Tool | Version | Purpose | -|------|---------|---------| -| Vitest | 4.0.16 | Test runner | -| React Testing Library | 16.0 | Component testing | -| jsdom | - | Test environment | -| nock | 14.0 | HTTP mocking | -| TypeScript | 5.x | Type safety | - ### Key Commands +Run these commands from `web/`. From the repository root, prefix them with `pnpm -C web`. + ```bash # Run all tests pnpm test # Watch mode -pnpm test:watch +pnpm test --watch # Run specific file pnpm test path/to/file.spec.tsx # Generate coverage report -pnpm test:coverage +pnpm test --coverage # Analyze component complexity pnpm analyze-component @@ -228,7 +220,10 @@ Every test should clearly separate: ### 2. Black-Box Testing - Test observable behavior, not implementation details -- Use semantic queries (getByRole, getByLabelText) +- Use semantic queries (`getByRole` with accessible `name`, `getByLabelText`, `getByPlaceholderText`, `getByText`, and scoped `within(...)`) +- Treat `getByTestId` as a last resort. If a control cannot be found by role/name, label, landmark, or dialog scope, fix the component accessibility first instead of adding or relying on `data-testid`. +- Remove production `data-testid` attributes when semantic selectors can cover the behavior. Keep them only for non-visual mocked boundaries, editor/browser shims such as Monaco, canvas/chart output, or third-party widgets with no accessible DOM in the test environment. +- Do not assert decorative icons by test id. Assert the named control that contains them, or mark decorative icons `aria-hidden`. - Avoid testing internal state directly - **Prefer pattern matching over hardcoded strings** in assertions: diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index 8c2f1c0c58..7723e4df21 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -56,7 +56,7 @@ See [Zustand Store Testing](#zustand-store-testing) section for full details. | Location | Purpose | |----------|---------| -| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) | +| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `zustand`, clipboard, FloatingPortal, Monaco, localStorage`) | | `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) | | `web/__mocks__/` | Reusable mock factories shared across multiple test files | | Test file | Test-specific mocks, inline with `vi.mock()` | @@ -216,28 +216,21 @@ describe('Component', () => { }) ``` -### 5. HTTP Mocking with Nock +### 5. HTTP and `fetch` Mocking ```typescript -import nock from 'nock' - -const GITHUB_HOST = 'https://api.github.com' -const GITHUB_PATH = '/repos/owner/repo' - -const mockGithubApi = (status: number, body: Record, delayMs = 0) => { - return nock(GITHUB_HOST) - .get(GITHUB_PATH) - .delay(delayMs) - .reply(status, body) -} - describe('GithubComponent', () => { - afterEach(() => { - nock.cleanAll() + beforeEach(() => { + vi.clearAllMocks() }) it('should display repo info', async () => { - mockGithubApi(200, { name: 'dify', stars: 1000 }) + vi.mocked(globalThis.fetch).mockResolvedValueOnce( + new Response(JSON.stringify({ name: 'dify', stars: 1000 }), { + status: 200, + headers: { 'Content-Type': 'application/json' }, + }), + ) render() @@ -247,7 +240,12 @@ describe('GithubComponent', () => { }) it('should handle API error', async () => { - mockGithubApi(500, { message: 'Server error' }) + vi.mocked(globalThis.fetch).mockResolvedValueOnce( + new Response(JSON.stringify({ message: 'Server error' }), { + status: 500, + headers: { 'Content-Type': 'application/json' }, + }), + ) render() @@ -258,6 +256,8 @@ describe('GithubComponent', () => { }) ``` +Prefer mocking `@/service/*` modules or spying on `global.fetch` / `ky` clients with deterministic responses. Do not introduce an HTTP interception dependency such as `nock` or MSW unless it is already declared in the workspace or adding it is part of the task. + ### 6. Context Providers ```typescript @@ -332,7 +332,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => { 1. **Don't mock Zustand store modules** - Use real stores with `setState()` 1. Don't mock components you can import directly 1. Don't create overly simplified mocks that miss conditional logic -1. Don't forget to clean up nock after each test +1. Don't leave HTTP mocks or service mock state leaking between tests 1. Don't use `any` types in mocks without necessity ### Mock Decision Tree diff --git a/.agents/skills/frontend-testing/references/workflow.md b/.agents/skills/frontend-testing/references/workflow.md index bc4ed8285a..27755d42a7 100644 --- a/.agents/skills/frontend-testing/references/workflow.md +++ b/.agents/skills/frontend-testing/references/workflow.md @@ -227,12 +227,12 @@ Failing tests compound: **Fix failures immediately before proceeding.** -## Integration with Claude's Todo Feature +## Integration with Codex's Todo Feature -When using Claude for multi-file testing: +When using Codex for multi-file testing: -1. **Ask Claude to create a todo list** before starting -1. **Request one file at a time** or ensure Claude processes incrementally +1. **Create a todo list** before starting +1. **Process one file at a time** 1. **Verify each test passes** before asking for the next 1. **Mark todos complete** as you progress diff --git a/.agents/skills/how-to-write-component/SKILL.md b/.agents/skills/how-to-write-component/SKILL.md new file mode 100644 index 0000000000..ac77112993 --- /dev/null +++ b/.agents/skills/how-to-write-component/SKILL.md @@ -0,0 +1,71 @@ +--- +name: how-to-write-component +description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling. +--- + +# How To Write A Component + +Use this as the decision guide for React/TypeScript component structure. Existing code is reference material, not automatic precedent; when it conflicts with these rules, adapt the approach instead of reproducing the violation. + +## Core Defaults + +- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit. +- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them. +- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature. +- Use Tailwind CSS v4.1+ rules via the `tailwind-css-rules` skill. Prefer v4 utilities, `gap`, `text-size/line-height`, `min-h-dvh`, and avoid deprecated utilities and `@apply`. + +## Ownership + +- Put local state, queries, mutations, handlers, and derived UI data in the lowest component that uses them. Extract a purpose-built owner component only when the logic has no natural home. +- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data. Do not hoist a query only because it is duplicated; TanStack Query handles deduplication and cache sharing. +- Hoist state, queries, or callbacks to a parent only when the parent consumes the data, coordinates shared loading/error/empty UI, needs one consistent snapshot, or owns a workflow spanning children. +- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow. +- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action. +- Prefer uncontrolled DOM state and CSS variables before adding controlled props. + +## Components, Props, And Types + +- Type component signatures directly; do not use `FC` or `React.FC`. +- Prefer `function` for top-level components and module helpers. Use arrow functions for local callbacks, handlers, and lambda-style APIs. +- Prefer named exports. Use default exports only where the framework requires them, such as Next.js route files. +- Type simple one-off props inline. Use a named `Props` type only when reused, exported, complex, or clearer. +- Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers beside the component that needs them. +- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially IDs like `appInstanceId`. Normalize framework or route params at the boundary. +- Keep fallback and invariant checks at the lowest component that already handles that state; callers should pass raw values through instead of duplicating checks. + +## Queries And Mutations + +- Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape. +- Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`. +- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it. +- Keep feature hooks for real orchestration, workflow state, or shared domain behavior. +- For missing required query input, use `input: skipToken`; use `enabled` only for extra business gating after the input is valid. +- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows. +- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules. +- Do not use deprecated `useInvalid` or `useReset`. +- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`. + +## Component Boundaries + +- Use the first level below a page or tab to organize independent page sections when it adds real structure. This layer is layout/semantic first, not automatically the data owner. +- Split deeper components by the data and state each layer actually needs. Each component should access only necessary data, and ownership should stay at the lowest consumer. +- Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary. +- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow. +- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment. +- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible. +- Avoid shallow wrappers and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. + +## You Might Not Need An Effect + +- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration. +- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive. +- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known. +- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render. +- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary. +- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components. +- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow. + +## Navigation And Performance + +- Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission. +- Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason. diff --git a/.agents/skills/tailwind-css-rules/SKILL.md b/.agents/skills/tailwind-css-rules/SKILL.md new file mode 100644 index 0000000000..3528548036 --- /dev/null +++ b/.agents/skills/tailwind-css-rules/SKILL.md @@ -0,0 +1,367 @@ +--- +name: tailwind-css-rules +description: Tailwind CSS v4.1+ rules and best practices. Use when writing, reviewing, refactoring, or upgrading Tailwind CSS classes and styles, especially v4 utility migrations, layout spacing, typography, responsive variants, dark mode, gradients, CSS variables, and component styling. +--- + +# Tailwind CSS Rules and Best Practices + +## Core Principles + +- **Always use Tailwind CSS v4.1+** - Ensure the codebase is using the latest version +- **Do not use deprecated or removed utilities** - ALWAYS use the replacement +- **Never use `@apply`** - Use CSS variables, the `--spacing()` function, or framework components instead +- **Check for redundant classes** - Remove any classes that aren't necessary +- **Group elements logically** to simplify responsive tweaks later + +## Upgrading to Tailwind CSS v4 + +### Before Upgrading + +- **Always read the upgrade documentation first** - Read https://tailwindcss.com/docs/upgrade-guide and https://tailwindcss.com/blog/tailwindcss-v4 before starting an upgrade. +- Ensure the git repository is in a clean state before starting + +### Upgrade Process + +1. Run the upgrade command: `npx @tailwindcss/upgrade@latest` for both major and minor updates +2. The tool will convert JavaScript config files to the new CSS format +3. Review all changes extensively to clean up any false positives +4. Test thoroughly across your application + +## Breaking Changes Reference + +### Removed Utilities (NEVER use these in v4) + +| ❌ Deprecated | ✅ Replacement | +| ----------------------- | ------------------------------------------------- | +| `bg-opacity-*` | Use opacity modifiers like `bg-black/50` | +| `text-opacity-*` | Use opacity modifiers like `text-black/50` | +| `border-opacity-*` | Use opacity modifiers like `border-black/50` | +| `divide-opacity-*` | Use opacity modifiers like `divide-black/50` | +| `ring-opacity-*` | Use opacity modifiers like `ring-black/50` | +| `placeholder-opacity-*` | Use opacity modifiers like `placeholder-black/50` | +| `flex-shrink-*` | `shrink-*` | +| `flex-grow-*` | `grow-*` | +| `overflow-ellipsis` | `text-ellipsis` | +| `decoration-slice` | `box-decoration-slice` | +| `decoration-clone` | `box-decoration-clone` | + +### Renamed Utilities + +Use the v4 name when migrating code that still carries Tailwind v3 semantics. Do not blanket-replace existing v4 classes: classes such as `rounded-sm`, `shadow-sm`, `ring-1`, and `ring-2` are valid in this codebase when they intentionally represent the current design scale. + +| ❌ v3 pattern | ✅ v4 pattern | +| ------------------- | -------------------------------------------------- | +| `bg-gradient-*` | `bg-linear-*` | +| old shadow scale | verify against the current Tailwind/design scale | +| old blur scale | verify against the current Tailwind/design scale | +| old radius scale | use the Dify radius token mapping when applicable | +| `outline-none` | `outline-hidden` | +| bare `ring` utility | use an explicit ring width such as `ring-1`/`ring-2`/`ring-3` | + +For Figma radius tokens, follow `packages/dify-ui/AGENTS.md`. For example, `--radius/xs` maps to `rounded-sm`; do not rewrite it to `rounded-xs`. + +## Layout and Spacing Rules + +### Flexbox and Grid Spacing + +#### Always use gap utilities for internal spacing + +Gap provides consistent spacing without edge cases (no extra space on last items). It's cleaner and more maintainable than margins on children. + +```html + +
+
Item 1
+
Item 2
+
Item 3
+ +
+ + +
+
Item 1
+
Item 2
+
Item 3
+
+``` + +#### Gap vs Space utilities + +- **Never use `space-x-*` or `space-y-*` in flex/grid layouts** - always use gap +- Space utilities add margins to children and have issues with wrapped items +- Gap works correctly with flex-wrap and all flex directions + +```html + +
+ +
+ + +
+ +
+``` + +### General Spacing Guidelines + +- **Prefer top and left margins** over bottom and right margins (unless conditionally rendered) +- **Use padding on parent containers** instead of bottom margins on the last child +- **Always use `min-h-dvh` instead of `min-h-screen`** - `min-h-screen` is buggy on mobile Safari +- **Prefer `size-*` utilities** over separate `w-*` and `h-*` when setting equal dimensions +- For max-widths, prefer the container scale (e.g., `max-w-2xs` over `max-w-72`) + +## Typography Rules + +### Line Heights + +- **Never use `leading-*` classes** - Always use line height modifiers with text size +- **Always use fixed line heights from the spacing scale** - Don't use named values + +```html + +

Text with separate line height

+

Text with named line height

+ + +

Text with line height modifier

+

Text with specific line height

+``` + +### Font Size Reference + +Be precise with font sizes - know the actual pixel values: + +- `text-xs` = 12px +- `text-sm` = 14px +- `text-base` = 16px +- `text-lg` = 18px +- `text-xl` = 20px + +## Color and Opacity + +### Opacity Modifiers + +**Never use `bg-opacity-*`, `text-opacity-*`, etc.** - use the opacity modifier syntax: + +```html + +
Old opacity syntax
+ + +
Modern opacity syntax
+``` + +## Responsive Design + +### Breakpoint Optimization + +- **Check for redundant classes across breakpoints** +- **Only add breakpoint variants when values change** + +```html + +
+ +
+ + +
+ +
+``` + +## Dark Mode + +### Dark Mode Best Practices + +- Use the plain `dark:` variant pattern +- Put light mode styles first, then dark mode styles +- Ensure `dark:` variant comes before other variants + +```html + +
+ +
+``` + +## Gradient Utilities + +- **ALWAYS Use `bg-linear-*` instead of `bg-gradient-*` utilities** - The gradient utilities were renamed in v4 +- Use the new `bg-radial` or `bg-radial-[]` to create radial gradients +- Use the new `bg-conic` or `bg-conic-*` to create conic gradients + +```html + +
+
+
+ + +
+``` + +## Working with CSS Variables + +### Accessing Theme Values + +Tailwind CSS v4 exposes all theme values as CSS variables: + +```css +/* Access colors, and other theme values */ +.custom-element { + background: var(--color-red-500); + border-radius: var(--radius-lg); +} +``` + +### The `--spacing()` Function + +Use the dedicated `--spacing()` function for spacing calculations: + +```css +.custom-class { + margin-top: calc(100vh - --spacing(16)); +} +``` + +### Extending theme values + +Use CSS to extend theme values: + +```css +@import "tailwindcss"; + +@theme { + --color-mint-500: oklch(0.72 0.11 178); +} +``` + +```html +
+ +
+``` + +## New v4 Features + +### Container Queries + +Use the `@container` class and size variants: + +```html +
+
+ +
+ +
+
+
+``` + +### Container Query Units + +Use container-based units like `cqw` for responsive sizing: + +```html +
+

Responsive to container width

+
+``` + +### Text Shadows (v4.1) + +Use text-shadow-\* utilities from text-shadow-2xs to text-shadow-lg: + +```html + +

Large shadow

+

Small shadow with opacity

+``` + +### Masking (v4.1) + +Use the new composable mask utilities for image and gradient masks: + +```html + +
Top fade
+
Bottom gradient
+
+ Fade from white to black +
+ + +
+ Radial mask +
+``` + +## Component Patterns + +### Avoiding Utility Inheritance + +Don't add utilities to parents that you override in children: + +```html + +
+

Centered Heading

+
Left-aligned content
+
+ + +
+

Centered Heading

+
Left-aligned content
+
+``` + +### Component Extraction + +- Extract repeated patterns into framework components, not CSS classes +- Keep utility classes in templates/JSX +- Use data attributes for complex state-based styling + +## CSS Best Practices + +### Nesting Guidelines + +- Use nesting when styling both parent and children +- Avoid empty parent selectors + +```css +/* ✅ Good nesting - parent has styles */ +.card { + padding: --spacing(4); + + > .card-title { + font-weight: bold; + } +} + +/* ❌ Avoid empty parents */ +ul { + > li { + /* Parent has no styles */ + } +} +``` + +## Common Pitfalls to Avoid + +1. **Using old opacity utilities** - Always use `/opacity` syntax like `bg-red-500/60` +2. **Redundant breakpoint classes** - Only specify changes +3. **Space utilities in flex/grid** - Always use gap +4. **Leading utilities** - Use line-height modifiers like `text-sm/6` +5. **Arbitrary values** - Use the design scale +6. **@apply directive** - Use components or CSS variables +7. **min-h-screen on mobile** - Use min-h-dvh +8. **Separate width/height** - Use size utilities when equal +9. **Arbitrary values** - Always use Tailwind's predefined scale whenever possible (e.g., use `ml-4` over `ml-[16px]`) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index f59cc6be48..aefcf1b5ac 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -9,6 +9,6 @@ jobs: pull-requests: write runs-on: depot-ubuntu-24.04 steps: - - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 + - uses: actions/labeler@f27b608878404679385c85cfa523b85ccb86e213 # v6.1.0 with: sync-labels: true diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index 7f82942e7e..8e16baf933 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -77,10 +77,28 @@ jobs: } if (diff.trim()) { - await github.rest.issues.createComment({ + const body = '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
'; + const marker = '### Pyrefly Diff'; + const { data: comments } = await github.rest.issues.listComments({ issue_number: prNumber, owner: context.repo.owner, repo: context.repo.repo, - body: '### Pyrefly Diff\n
\nbase → PR\n\n```diff\n' + diff + '\n```\n
', }); + const existing = comments.find((comment) => comment.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } } diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index 0cf54e3585..386bd25751 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -103,9 +103,26 @@ jobs: ].join('\n') : '### Pyrefly Diff\nNo changes detected.'; - await github.rest.issues.createComment({ + const marker = '### Pyrefly Diff'; + const { data: comments } = await github.rest.issues.listComments({ issue_number: prNumber, owner: context.repo.owner, repo: context.repo.repo, - body, }); + const existing = comments.find((comment) => comment.body.startsWith(marker)); + + if (existing) { + await github.rest.issues.updateComment({ + comment_id: existing.id, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } else { + await github.rest.issues.createComment({ + issue_number: prNumber, + owner: context.repo.owner, + repo: context.repo.repo, + body, + }); + } diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 7bb6fc1bbd..4e738df684 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -158,7 +158,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.context.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111 + uses: anthropics/claude-code-action@476e359e6203e73dad705c8b322e333fabbd7416 # v1.0.119 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/Makefile b/Makefile index ae7589bbd6..3b5024683f 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,10 @@ DOCKER_REGISTRY=langgenius WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web API_IMAGE=$(DOCKER_REGISTRY)/dify-api VERSION=latest +DOCKER_DIR=docker +DOCKER_MIDDLEWARE_ENV=$(DOCKER_DIR)/middleware.env +DOCKER_MIDDLEWARE_ENV_EXAMPLE=$(DOCKER_DIR)/envs/middleware.env.example +DOCKER_MIDDLEWARE_PROJECT=dify-middlewares-dev # Default target - show help .DEFAULT_GOAL := help @@ -17,8 +21,13 @@ dev-setup: prepare-docker prepare-web prepare-api # Step 1: Prepare Docker middleware prepare-docker: @echo "🐳 Setting up Docker middleware..." - @cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists" - @cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d + @if [ ! -f "$(DOCKER_MIDDLEWARE_ENV)" ]; then \ + cp "$(DOCKER_MIDDLEWARE_ENV_EXAMPLE)" "$(DOCKER_MIDDLEWARE_ENV)"; \ + echo "Docker middleware.env created"; \ + else \ + echo "Docker middleware.env already exists"; \ + fi + @cd $(DOCKER_DIR) && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p $(DOCKER_MIDDLEWARE_PROJECT) up -d @echo "✅ Docker middleware started" # Step 2: Prepare web environment @@ -39,12 +48,18 @@ prepare-api: # Clean dev environment dev-clean: @echo "⚠️ Stopping Docker containers..." - @cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down + @if [ -f "$(DOCKER_MIDDLEWARE_ENV)" ]; then \ + cd $(DOCKER_DIR) && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p $(DOCKER_MIDDLEWARE_PROJECT) down; \ + else \ + echo "Docker middleware.env does not exist, skipping compose down"; \ + fi @echo "🗑️ Removing volumes..." @rm -rf docker/volumes/db + @rm -rf docker/volumes/mysql @rm -rf docker/volumes/redis @rm -rf docker/volumes/plugin_daemon @rm -rf docker/volumes/weaviate + @rm -rf docker/volumes/sandbox/dependencies @rm -rf api/storage @echo "✅ Cleanup complete" @@ -132,7 +147,7 @@ help: @echo " make prepare-docker - Set up Docker middleware" @echo " make prepare-web - Set up web environment" @echo " make prepare-api - Set up API environment" - @echo " make dev-clean - Stop Docker middleware containers" + @echo " make dev-clean - Stop Docker middleware containers and remove dev data" @echo "" @echo "Backend Code Quality:" @echo " make format - Format code with ruff" diff --git a/api/.env.example b/api/.env.example index 56ba8a6c5d..40fed7403c 100644 --- a/api/.env.example +++ b/api/.env.example @@ -34,7 +34,7 @@ TRIGGER_URL=http://localhost:5001 FILES_ACCESS_TIMEOUT=300 # Collaboration mode toggle -ENABLE_COLLABORATION_MODE=false +ENABLE_COLLABORATION_MODE=true # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 @@ -88,6 +88,10 @@ REDIS_HEALTH_CHECK_INTERVAL=30 CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BACKEND=redis +# Ops trace retry configuration +OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60 +OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5 + # Database configuration DB_TYPE=postgresql DB_USERNAME=postgres diff --git a/api/AGENTS.md b/api/AGENTS.md index 8e5d9f600d..eb4404509d 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -193,6 +193,10 @@ Before opening a PR / submitting: - Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic. - Services: coordinate repositories, providers, background tasks; keep side effects explicit. - Document non-obvious behaviour with concise docstrings and comments. +- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`. + In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response + DTOs with `register_response_schema_models(...)`, serialize with `ResponseModel.model_validate(...).model_dump(...)`, + and avoid adding new legacy `ns.model(...)`, `@marshal_with(...)`, or GET `@ns.expect(...)` patterns. ### Miscellaneous diff --git a/api/app_factory.py b/api/app_factory.py index 48e50ceae9..5583071980 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -181,7 +181,6 @@ def initialize_extensions(app: DifyApp): ext_import_modules, ext_orjson, ext_forward_refs, - ext_set_secretkey, ext_compress, ext_code_based_extension, ext_database, @@ -189,6 +188,7 @@ def initialize_extensions(app: DifyApp): ext_migrate, ext_redis, ext_storage, + ext_set_secretkey, ext_logstore, # Initialize logstore after storage, before celery ext_celery, ext_login, diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 249272e308..518e4ec4e6 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -23,9 +23,9 @@ class SecurityConfig(BaseSettings): """ SECRET_KEY: str = Field( - description="Secret key for secure session cookie signing." - "Make sure you are changing this key for your deployment with a strong key." - "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.", + description="Secret key for secure session cookie signing. " + "Leave empty to let Dify generate a persistent key in the storage directory, " + "or set a strong value via the `SECRET_KEY` environment variable.", default="", ) @@ -1137,6 +1137,18 @@ class MultiModalTransferConfig(BaseSettings): ) +class OpsTraceConfig(BaseSettings): + OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES: PositiveInt = Field( + description="Maximum retry attempts for transient ops trace provider dispatch failures.", + default=60, + ) + + OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS: PositiveInt = Field( + description="Delay in seconds between transient ops trace provider dispatch retry attempts.", + default=5, + ) + + class CeleryBeatConfig(BaseSettings): CELERY_BEAT_SCHEDULER_TIME: int = Field( description="Interval in days for Celery Beat scheduler execution, default to 1 day", @@ -1298,7 +1310,7 @@ class PositionConfig(BaseSettings): class CollaborationConfig(BaseSettings): ENABLE_COLLABORATION_MODE: bool = Field( description="Whether to enable collaboration mode features across the workspace", - default=False, + default=True, ) @@ -1444,6 +1456,7 @@ class FeatureConfig( ModelLoadBalanceConfig, ModerationConfig, MultiModalTransferConfig, + OpsTraceConfig, PositionConfig, RagEtlConfig, RepositoryConfig, diff --git a/api/configs/secret_key.py b/api/configs/secret_key.py new file mode 100644 index 0000000000..f8c33f6a2c --- /dev/null +++ b/api/configs/secret_key.py @@ -0,0 +1,38 @@ +"""SECRET_KEY persistence helpers for runtime setup.""" + +from __future__ import annotations + +import secrets + +from extensions.ext_storage import storage + +GENERATED_SECRET_KEY_FILENAME = ".dify_secret_key" + + +def resolve_secret_key(secret_key: str) -> str: + """Return an explicit SECRET_KEY or a generated key persisted in storage.""" + if secret_key: + return secret_key + + return _load_or_create_secret_key() + + +def _load_or_create_secret_key() -> str: + try: + persisted_key = storage.load_once(GENERATED_SECRET_KEY_FILENAME).decode("utf-8").strip() + if persisted_key: + return persisted_key + except FileNotFoundError: + pass + + generated_key = secrets.token_urlsafe(48) + + try: + storage.save(GENERATED_SECRET_KEY_FILENAME, f"{generated_key}\n".encode()) + except Exception as exc: + raise ValueError( + f"SECRET_KEY is not set and could not be generated at {GENERATED_SECRET_KEY_FILENAME}. " + "Set SECRET_KEY explicitly or make storage writable." + ) from exc + + return generated_key diff --git a/api/controllers/API_SCHEMA_GUIDE.md b/api/controllers/API_SCHEMA_GUIDE.md new file mode 100644 index 0000000000..5b1b055b09 --- /dev/null +++ b/api/controllers/API_SCHEMA_GUIDE.md @@ -0,0 +1,193 @@ +# API Schema Guide + +This guide describes the expected Flask-RESTX + Pydantic pattern for controller request payloads, query +parameters, response schemas, and Swagger documentation. + +## Principles + +- Use Pydantic `BaseModel` for request bodies and query parameters. +- Use `fields.base.ResponseModel` for response DTOs. +- Keep runtime validation and Swagger documentation wired to the same Pydantic model. +- Prefer explicit validation and serialization in controller methods over Flask-RESTX marshalling. +- Do not add new Flask-RESTX `fields.*` dictionaries, `Namespace.model(...)` exports, or `@marshal_with(...)` for migrated or new endpoints. +- Do not use `@ns.expect(...)` for GET query parameters. Flask-RESTX documents that as a request body. + +## Naming + +- Request body models: use a `Payload` suffix. + - Example: `WorkflowRunPayload`, `DatasourceVariablesPayload`. +- Query parameter models: use a `Query` suffix. + - Example: `WorkflowRunListQuery`, `MessageListQuery`. +- Response models: use a `Response` suffix and inherit from `ResponseModel`. + - Example: `WorkflowRunDetailResponse`, `WorkflowRunNodeExecutionListResponse`. +- Use `ListResponse` or `PaginationResponse` for wrapper responses. + - Example: `WorkflowRunNodeExecutionListResponse`, `WorkflowRunPaginationResponse`. +- Keep these models near the controller when they are endpoint-specific. Move them to `fields/*_fields.py` only when shared by multiple controllers. + +## Registering Models For Swagger + +Use helpers from `controllers.common.schema`. + +```python +from controllers.common.schema import ( + query_params_from_model, + register_response_schema_models, + register_schema_models, +) +``` + +Register request payload and query models with `register_schema_models(...)`: + +```python +register_schema_models( + console_ns, + WorkflowRunPayload, + WorkflowRunListQuery, +) +``` + +Register response models with `register_response_schema_models(...)`: + +```python +register_response_schema_models( + console_ns, + WorkflowRunDetailResponse, + WorkflowRunPaginationResponse, +) +``` + +Response models are registered in Pydantic serialization mode. This matters when a response model uses +`validation_alias` to read internal object attributes but emits public API field names. For example, a response model +can validate from `inputs_dict` while documenting and serializing `inputs`. + +## Request Bodies + +For non-GET request bodies: + +1. Define a Pydantic `Payload` model. +2. Register it with `register_schema_models(...)`. +3. Use `@ns.expect(ns.models[Payload.__name__])` for Swagger documentation. +4. Validate from `ns.payload or {}` inside the controller. + +```python +class DraftWorkflowNodeRunPayload(BaseModel): + inputs: dict[str, Any] + query: str = "" + + +register_schema_models(console_ns, DraftWorkflowNodeRunPayload) + + +@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__]) +def post(self, app_model: App, node_id: str): + payload = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {}) + result = service.run(..., inputs=payload.inputs, query=payload.query) + return WorkflowRunNodeExecutionResponse.model_validate(result, from_attributes=True).model_dump(mode="json") +``` + +## Query Parameters + +For GET query parameters: + +1. Define a Pydantic `Query` model. +2. Register it with `register_schema_models(...)` if it is referenced elsewhere in docs, or only use + `query_params_from_model(...)` if a body schema is not needed. +3. Use `@ns.doc(params=query_params_from_model(QueryModel))`. +4. Validate from `request.args.to_dict(flat=True)` or an explicit dict when type coercion is needed. + +```python +class WorkflowRunListQuery(BaseModel): + last_id: str | None = Field(default=None, description="Last run ID for pagination") + limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)") + + +@console_ns.doc(params=query_params_from_model(WorkflowRunListQuery)) +def get(self, app_model: App): + query = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) + result = service.list(..., limit=query.limit, last_id=query.last_id) + return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json") +``` + +Do not do this for GET query parameters: + +```python +@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__]) +def get(...): + ... +``` + +That documents a GET request body and is not the expected contract. + +## Responses + +Response models should inherit from `ResponseModel`: + +```python +class WorkflowRunNodeExecutionResponse(ResponseModel): + id: str + inputs: Any = Field(default=None, validation_alias="inputs_dict") + process_data: Any = Field(default=None, validation_alias="process_data_dict") + outputs: Any = Field(default=None, validation_alias="outputs_dict") +``` + +Document response models with `@ns.response(...)`: + +```python +@console_ns.response( + 200, + "Node run started successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], +) +def post(...): + ... +``` + +Serialize explicitly: + +```python +return WorkflowRunNodeExecutionResponse.model_validate( + workflow_node_execution, + from_attributes=True, +).model_dump(mode="json") +``` + +If the service can return `None`, translate that into the expected HTTP error before validation: + +```python +workflow_run = service.get_workflow_run(...) +if workflow_run is None: + raise NotFound("Workflow run not found") + +return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json") +``` + +## Legacy Flask-RESTX Patterns + +Avoid adding these patterns to new or migrated endpoints: + +- `ns.model(...)` for new request/response DTOs. +- Module-level exported RESTX model objects such as `workflow_run_detail_model`. +- `fields.Nested({...})` with raw inline dict field maps. +- `@marshal_with(...)` for response serialization. +- `@ns.expect(...)` for GET query params. + +Existing legacy field dictionaries may remain where an endpoint has not yet been migrated. Keep that compatibility local +to the legacy area and avoid importing RESTX model objects from controllers. + +## Verifying Swagger + +For schema and documentation changes, run focused tests and generate Swagger JSON: + +```bash +uv run --project . pytest tests/unit_tests/controllers/common/test_schema.py +uv run --project . pytest tests/unit_tests/commands/test_generate_swagger_specs.py tests/unit_tests/controllers/test_swagger.py +uv run --project . dev/generate_swagger_specs.py --output-dir /tmp/dify-openapi-check +``` + +Inspect affected endpoints with `jq`. Check that: + +- GET parameters are `in: query`. +- Request bodies appear only where the endpoint has a body. +- Responses reference the expected `*Response` schema. +- Response schemas use public serialized names, not internal validation aliases like `inputs_dict`. + diff --git a/api/controllers/common/schema.py b/api/controllers/common/schema.py index 57070f1c80..58140f3ac8 100644 --- a/api/controllers/common/schema.py +++ b/api/controllers/common/schema.py @@ -8,7 +8,7 @@ These helpers keep that translation centralized so models registered through from collections.abc import Mapping from enum import StrEnum -from typing import Any, NotRequired, TypedDict +from typing import Any, Literal, NotRequired, TypedDict from flask_restx import Namespace from pydantic import BaseModel, TypeAdapter @@ -54,16 +54,23 @@ def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None _register_json_schema(namespace, nested_name, nested_schema) -def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None: - """Register a BaseModel and its nested schema definitions for Swagger documentation.""" +JsonSchemaMode = Literal["validation", "serialization"] + +def _register_schema_model(namespace: Namespace, model: type[BaseModel], *, mode: JsonSchemaMode) -> None: _register_json_schema( namespace, model.__name__, - model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), + model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0, mode=mode), ) +def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None: + """Register a BaseModel and its nested schema definitions for Swagger documentation.""" + + _register_schema_model(namespace, model, mode="validation") + + def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None: """Register multiple BaseModels with a namespace.""" @@ -71,6 +78,19 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No register_schema_model(namespace, model) +def register_response_schema_model(namespace: Namespace, model: type[BaseModel]) -> None: + """Register a BaseModel using its serialized response shape.""" + + _register_schema_model(namespace, model, mode="serialization") + + +def register_response_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None: + """Register multiple response BaseModels using their serialized response shape.""" + + for model in models: + register_response_schema_model(namespace, model) + + def get_or_create_model(model_name: str, field_def): # Import lazily to avoid circular imports between console controllers and schema helpers. from controllers.console import console_ns @@ -190,6 +210,8 @@ __all__ = [ "get_or_create_model", "query_params_from_model", "register_enum_models", + "register_response_schema_model", + "register_response_schema_models", "register_schema_model", "register_schema_models", ] diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index a32c3420bb..ae2b1007dd 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -3,6 +3,7 @@ import io from collections.abc import Callable from functools import wraps from typing import cast +from uuid import UUID from flask import request from flask_restx import Resource @@ -21,8 +22,6 @@ from libs.token import extract_access_token from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp from services.billing_service import BillingService, LangContentDict -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class InsertExploreAppPayload(BaseModel): app_id: str = Field(...) @@ -59,15 +58,7 @@ class InsertExploreBannerPayload(BaseModel): model_config = {"populate_by_name": True} -console_ns.schema_model( - InsertExploreAppPayload.__name__, - InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - -console_ns.schema_model( - InsertExploreBannerPayload.__name__, - InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, InsertExploreAppPayload, InsertExploreBannerPayload) def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: @@ -191,7 +182,7 @@ class InsertExploreAppApi(Resource): @console_ns.response(204, "App removed successfully") @only_edition_cloud @admin_required - def delete(self, app_id): + def delete(self, app_id: UUID): with session_factory.create_session() as session: recommended_app = session.execute( select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) @@ -404,11 +395,11 @@ class BatchAddNotificationAccountsApi(Resource): raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.") try: - content = file.read().decode("utf-8") + content = file.stream.read().decode("utf-8") except UnicodeDecodeError: try: - file.seek(0) - content = file.read().decode("gbk") + file.stream.seek(0) + content = file.stream.read().decode("gbk") except UnicodeDecodeError: raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.") diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index ed66da1be5..ad21671176 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -34,7 +34,7 @@ class AdvancedPromptTemplateList(Resource): @login_required @account_initialization_required def get(self): - args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) prompt_args: AdvancedPromptTemplateArgs = { "app_mode": args.app_mode, "model_mode": args.model_mode, diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index cfdb9cf417..c05600ced5 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -2,6 +2,7 @@ from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -10,8 +11,6 @@ from libs.login import login_required from models.model import AppMode from services.agent_service import AgentService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AgentLogQuery(BaseModel): message_id: str = Field(..., description="Message UUID") @@ -23,9 +22,7 @@ class AgentLogQuery(BaseModel): return uuid_value(value) -console_ns.schema_model( - AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, AgentLogQuery) @console_ns.route("/apps//agent/logs") @@ -44,6 +41,6 @@ class AgentLogApi(Resource): @get_app_model(mode=[AppMode.AGENT_CHAT]) def get(self, app_model): """Get agent logs""" - args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 528785931e..cfeaec4af9 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,4 +1,5 @@ from typing import Any, Literal +from uuid import UUID from flask import abort, make_response, request from flask_restx import Resource @@ -33,8 +34,6 @@ from services.annotation_service import ( UpsertAnnotationArgs, ) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AnnotationReplyPayload(BaseModel): score_threshold: float = Field(..., description="Score threshold for annotation matching") @@ -87,17 +86,6 @@ class AnnotationFilePayload(BaseModel): return uuid_value(value) -def reg(model: type[BaseModel]) -> None: - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(AnnotationReplyPayload) -reg(AnnotationSettingUpdatePayload) -reg(AnnotationListQuery) -reg(CreateAnnotationPayload) -reg(UpdateAnnotationPayload) -reg(AnnotationReplyStatusQuery) -reg(AnnotationFilePayload) register_schema_models( console_ns, Annotation, @@ -105,6 +93,13 @@ register_schema_models( AnnotationExportList, AnnotationHitHistory, AnnotationHitHistoryList, + AnnotationReplyPayload, + AnnotationSettingUpdatePayload, + AnnotationListQuery, + CreateAnnotationPayload, + UpdateAnnotationPayload, + AnnotationReplyStatusQuery, + AnnotationFilePayload, ) @@ -121,8 +116,7 @@ class AnnotationReplyActionApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id, action: Literal["enable", "disable"]): - app_id = str(app_id) + def post(self, app_id: UUID, action: Literal["enable", "disable"]): args = AnnotationReplyPayload.model_validate(console_ns.payload) match action: case "enable": @@ -131,9 +125,9 @@ class AnnotationReplyActionApi(Resource): "embedding_provider_name": args.embedding_provider_name, "embedding_model_name": args.embedding_model_name, } - result = AppAnnotationService.enable_app_annotation(enable_args, app_id) + result = AppAnnotationService.enable_app_annotation(enable_args, str(app_id)) case "disable": - result = AppAnnotationService.disable_app_annotation(app_id) + result = AppAnnotationService.disable_app_annotation(str(app_id)) return result, 200 @@ -148,9 +142,8 @@ class AppAnnotationSettingDetailApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): - app_id = str(app_id) - result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) + def get(self, app_id: UUID): + result = AppAnnotationService.get_app_annotation_setting_by_app_id(str(app_id)) return result, 200 @@ -166,14 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, app_id, annotation_setting_id): - app_id = str(app_id) + def post(self, app_id: UUID, annotation_setting_id): annotation_setting_id = str(annotation_setting_id) args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload) setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold} - result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args) + result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args) return result, 200 @@ -189,7 +181,7 @@ class AnnotationReplyActionStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def get(self, app_id, job_id, action): + def get(self, app_id: UUID, job_id, action): job_id = str(job_id) app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" cache_result = redis_client.get(app_annotation_job_key) @@ -217,14 +209,13 @@ class AnnotationApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): - args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + def get(self, app_id: UUID): + args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page limit = args.limit keyword = args.keyword - app_id = str(app_id) - annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) + annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(str(app_id), page, limit, keyword) annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) response = AnnotationList( data=annotation_models, @@ -246,8 +237,7 @@ class AnnotationApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id): - app_id = str(app_id) + def post(self, app_id: UUID): args = CreateAnnotationPayload.model_validate(console_ns.payload) upsert_args: UpsertAnnotationArgs = {} if args.answer is not None: @@ -258,15 +248,14 @@ class AnnotationApi(Resource): upsert_args["message_id"] = args.message_id if args.question is not None: upsert_args["question"] = args.question - annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id) + annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, str(app_id)) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @account_initialization_required @edit_permission_required - def delete(self, app_id): - app_id = str(app_id) + def delete(self, app_id: UUID): # Use request.args.getlist to get annotation_ids array directly annotation_ids = request.args.getlist("annotation_id") @@ -280,11 +269,11 @@ class AnnotationApi(Resource): "message": "annotation_ids are required if the parameter is provided.", }, 400 - result = AppAnnotationService.delete_app_annotations_in_batch(app_id, annotation_ids) + result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids) return result, 204 # If no annotation_ids are provided, handle clearing all annotations else: - AppAnnotationService.clear_all_annotations(app_id) + AppAnnotationService.clear_all_annotations(str(app_id)) return {"result": "success"}, 204 @@ -303,9 +292,8 @@ class AnnotationExportApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id): - app_id = str(app_id) - annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) + def get(self, app_id: UUID): + annotation_list = AppAnnotationService.export_annotation_list_by_app_id(str(app_id)) annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True) response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json") @@ -331,26 +319,22 @@ class AnnotationUpdateDeleteApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def post(self, app_id, annotation_id): - app_id = str(app_id) - annotation_id = str(annotation_id) + def post(self, app_id: UUID, annotation_id: UUID): args = UpdateAnnotationPayload.model_validate(console_ns.payload) update_args: UpdateAnnotationArgs = {} if args.answer is not None: update_args["answer"] = args.answer if args.question is not None: update_args["question"] = args.question - annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id) + annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id)) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @login_required @account_initialization_required @edit_permission_required - def delete(self, app_id, annotation_id): - app_id = str(app_id) - annotation_id = str(annotation_id) - AppAnnotationService.delete_app_annotation(app_id, annotation_id) + def delete(self, app_id: UUID, annotation_id: UUID): + AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id)) return {"result": "success"}, 204 @@ -371,11 +355,9 @@ class AnnotationBatchImportApi(Resource): @annotation_import_rate_limit @annotation_import_concurrency_limit @edit_permission_required - def post(self, app_id): + def post(self, app_id: UUID): from configs import dify_config - app_id = str(app_id) - # check file if "file" not in request.files: raise NoFileUploadedError() @@ -391,9 +373,9 @@ class AnnotationBatchImportApi(Resource): raise ValueError("Invalid file type. Only CSV files are allowed") # Check file size before processing - file.seek(0, 2) # Seek to end of file - file_size = file.tell() - file.seek(0) # Reset to beginning + file.stream.seek(0, 2) # Seek to end of file + file_size = file.stream.tell() + file.stream.seek(0) # Reset to beginning max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024 if file_size > max_size_bytes: @@ -406,7 +388,7 @@ class AnnotationBatchImportApi(Resource): if file_size == 0: raise ValueError("The uploaded file is empty") - return AppAnnotationService.batch_import_app_annotations(app_id, file) + return AppAnnotationService.batch_import_app_annotations(str(app_id), file) @console_ns.route("/apps//annotations/batch-import-status/") @@ -421,8 +403,7 @@ class AnnotationBatchImportStatusApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("annotation") @edit_permission_required - def get(self, app_id, job_id): - job_id = str(job_id) + def get(self, app_id: UUID, job_id: UUID): indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}" cache_result = redis_client.get(indexing_cache_key) if cache_result is None: @@ -456,13 +437,11 @@ class AnnotationHitHistoryListApi(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, app_id, annotation_id): + def get(self, app_id: UUID, annotation_id: UUID): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - app_id = str(app_id) - annotation_id = str(annotation_id) annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( - app_id, annotation_id, page, limit + str(app_id), str(annotation_id), page, limit ) history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python( annotation_hit_history_list, from_attributes=True diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 64bb611cd6..2cf6b7b3bb 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -3,6 +3,7 @@ import re import uuid from datetime import datetime from typing import Any, Literal +from uuid import UUID from flask import request from flask_restx import Resource @@ -40,7 +41,7 @@ from models import App, DatasetPermissionEnum, Workflow from models.model import IconType from models.workflow import resolve_workflow_kind from services.app_dsl_service import AppDslService -from services.app_service import AppService +from services.app_service import AppListParams, AppService, CreateAppParams from services.enterprise.enterprise_service import EnterpriseService from services.enterprise import rbac_service as enterprise_rbac_service from services.entities.dsl_entities import ImportMode, ImportStatus @@ -485,11 +486,18 @@ class AppListApi(Resource): current_user, current_tenant_id = current_account_with_tenant() args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args)) - args_dict = args.model_dump() + params = AppListParams( + page=args.page, + limit=args.limit, + mode=args.mode, + name=args.name, + tag_ids=args.tag_ids, + is_created_by_me=args.is_created_by_me, + ) # get app list app_service = AppService() - app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict) + app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params) if not app_pagination: empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[]) return empty.model_dump(mode="json"), 200 @@ -586,9 +594,17 @@ class AppListApi(Resource): """Create app""" current_user, current_tenant_id = current_account_with_tenant() args = CreateAppPayload.model_validate(console_ns.payload) + params = CreateAppParams( + name=args.name, + description=args.description, + mode=args.mode, + icon_type=args.icon_type, + icon=args.icon, + icon_background=args.icon_background, + ) app_service = AppService() - app = app_service.create_app(current_tenant_id, args.model_dump(), current_user) + app = app_service.create_app(current_tenant_id, params, current_user) app_detail = AppDetail.model_validate(app, from_attributes=True) return app_detail.model_dump(mode="json"), 201 @@ -754,7 +770,7 @@ class AppExportApi(Resource): @edit_permission_required def get(self, app_model): """Export app""" - args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) payload = AppExportResponse( data=AppDslService.export_dsl( @@ -893,10 +909,10 @@ class AppTraceApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): + def get(self, app_id: UUID): """Get app trace""" with session_factory.create_session() as session: - app_trace_config = OpsTraceManager.get_app_tracing_config(app_id, session) + app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session) return app_trace_config @@ -910,12 +926,12 @@ class AppTraceApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, app_id): + def post(self, app_id: UUID): # add app trace args = AppTracePayload.model_validate(console_ns.payload) OpsTraceManager.update_app_tracing_config( - app_id=app_id, + app_id=str(app_id), enabled=args.enabled, tracing_provider=args.tracing_provider, ) diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 91fbe4a85a..5b673f3394 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -173,7 +173,7 @@ class TextModesApi(Resource): @account_initialization_required def get(self, app_model): try: - args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index fe274e4c9a..6a20296cff 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound import services +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -37,7 +38,6 @@ from services.app_task_service import AppTaskService from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class BaseMessagePayload(BaseModel): @@ -65,13 +65,7 @@ class ChatMessagePayload(BaseMessagePayload): return uuid_value(value) -console_ns.schema_model( - CompletionMessagePayload.__name__, - CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload) # define completion message api for user diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index b2b1049f0c..c7347933cb 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -39,8 +39,6 @@ from models.model import AppMode from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class BaseConversationQuery(BaseModel): keyword: str | None = Field(default=None, description="Search keyword") @@ -70,15 +68,6 @@ class ChatConversationQuery(BaseConversationQuery): ) -console_ns.schema_model( - CompletionConversationQuery.__name__, - CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ChatConversationQuery.__name__, - ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) - register_schema_models( console_ns, CompletionConversationQuery, @@ -89,6 +78,8 @@ register_schema_models( ConversationWithSummaryPaginationResponse, ConversationDetailResponse, ResultResponse, + CompletionConversationQuery, + ChatConversationQuery, ) @@ -107,7 +98,7 @@ class CompletionConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) query = sa.select(Conversation).where( Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) @@ -221,7 +212,7 @@ class ChatConversationApi(Resource): @edit_permission_required def get(self, app_model): current_user, _ = current_account_with_tenant() - args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) subquery = ( sa.select(Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id")) diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 9c8b095b9f..60a2bfc799 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -100,7 +100,7 @@ class ConversationVariablesApi(Resource): @account_initialization_required @get_app_model(mode=AppMode.ADVANCED_CHAT) def get(self, app_model): - args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) stmt = ( select(ConversationVariable) diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index cbcf513162..9227d00a21 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -1,18 +1,18 @@ from typing import Any +from uuid import UUID from flask import request from flask_restx import Resource, fields from pydantic import BaseModel, Field from werkzeug.exceptions import BadRequest +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class TraceProviderQuery(BaseModel): tracing_provider: str = Field(..., description="Tracing provider name") @@ -23,13 +23,7 @@ class TraceConfigPayload(BaseModel): tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data") -console_ns.schema_model( - TraceProviderQuery.__name__, - TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, TraceProviderQuery, TraceConfigPayload) @console_ns.route("/apps//trace-config") @@ -49,11 +43,11 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, app_id): - args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + def get(self, app_id: UUID): + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: - trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) + trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider) if not trace_config: return {"has_not_configured": True} return trace_config @@ -71,13 +65,13 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, app_id): + def post(self, app_id: UUID): """Create a new trace app configuration""" args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.create_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigIsExist() @@ -96,13 +90,13 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, app_id): + def patch(self, app_id: UUID): """Update an existing trace app configuration""" args = TraceConfigPayload.model_validate(console_ns.payload) try: result = OpsService.update_tracing_app_config( - app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config + app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config ) if not result: raise TracingConfigNotExist() @@ -119,12 +113,12 @@ class TraceAppConfigApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self, app_id): + def delete(self, app_id: UUID): """Delete an existing trace app configuration""" - args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) try: - result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider) + result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider) if not result: raise TracingConfigNotExist() return {"result": "success"}, 204 diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index ffa28b1c95..d23b2837c9 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -5,6 +5,7 @@ from flask import abort, jsonify, request from flask_restx import Resource, fields from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -15,8 +16,6 @@ from libs.helper import convert_datetime_to_date from libs.login import current_account_with_tenant, login_required from models import AppMode -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class StatisticTimeRangeQuery(BaseModel): start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)") @@ -30,10 +29,7 @@ class StatisticTimeRangeQuery(BaseModel): return value -console_ns.schema_model( - StatisticTimeRangeQuery.__name__, - StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, StatisticTimeRangeQuery) @console_ns.route("/apps//statistics/daily-messages") @@ -54,7 +50,7 @@ class DailyMessageStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -111,7 +107,7 @@ class DailyConversationStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -167,7 +163,7 @@ class DailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -224,7 +220,7 @@ class DailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -284,7 +280,7 @@ class AverageSessionInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("c.created_at") sql_query = f"""SELECT @@ -360,7 +356,7 @@ class UserSatisfactionRateStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("m.created_at") sql_query = f"""SELECT @@ -426,7 +422,7 @@ class AverageResponseTimeStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT @@ -482,7 +478,7 @@ class TokensPerSecondStatistic(Resource): @account_initialization_required def get(self, app_model): account, _ = current_account_with_tenant() - args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) converted_created_at = convert_datetime_to_date("created_at") sql_query = f"""SELECT diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 8d7b3f1d99..e4d9d01485 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -10,10 +10,10 @@ from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services -from controllers.common.controller_schemas import DefaultBlockConfigQuery +from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload +from controllers.common.schema import register_response_schema_model, register_schema_models from controllers.console import console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync -from controllers.console.app.workflow_run import workflow_run_node_execution_model from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError @@ -37,6 +37,7 @@ from factories import file_factory, variable_factory from fields.member_fields import simple_account_fields from fields.online_user_fields import online_user_list_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields +from fields.workflow_run_fields import WorkflowRunNodeExecutionResponse from graphon.enums import NodeType from graphon.file import File from graphon.file import helpers as file_helpers @@ -56,9 +57,10 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) + _file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" + RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000 WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50 @@ -213,7 +215,7 @@ reg(WorkflowFeaturesPayload) reg(WorkflowOnlineUsersPayload) reg(DraftWorkflowTriggerRunPayload) reg(DraftWorkflowTriggerRunAllPayload) - +register_response_schema_model(console_ns, WorkflowRunNodeExecutionResponse) # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing # at the controller level rather than in the workflow logic. This would improve separation @@ -558,9 +560,12 @@ class HumanInputDeliveryTestPayload(BaseModel): ) -reg(HumanInputFormPreviewPayload) -reg(HumanInputFormSubmitPayload) -reg(HumanInputDeliveryTestPayload) +register_schema_models( + console_ns, + HumanInputFormPreviewPayload, + HumanInputFormSubmitPayload, + HumanInputDeliveryTestPayload, +) @console_ns.route("/apps//advanced-chat/workflows/draft/human-input/nodes//form/preview") @@ -778,14 +783,17 @@ class DraftWorkflowNodeRunApi(Resource): @console_ns.doc(description="Run draft workflow node") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__]) - @console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model) + @console_ns.response( + 200, + "Node run started successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @console_ns.response(403, "Permission denied") @console_ns.response(404, "Node not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_node_execution_model) @edit_permission_required def post(self, app_model: App, node_id: str): """ @@ -817,7 +825,9 @@ class DraftWorkflowNodeRunApi(Resource): files=files, ) - return workflow_node_execution + return WorkflowRunNodeExecutionResponse.model_validate( + workflow_node_execution, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/apps//workflows/publish") @@ -968,7 +978,7 @@ class DefaultBlockConfigApi(Resource): """ Get default block config """ - args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) filters = None if args.q: @@ -1061,7 +1071,7 @@ class PublishedAllWorkflowApi(Resource): """ current_user, _ = current_account_with_tenant() - args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page limit = args.limit user_id = args.user_id @@ -1255,14 +1265,17 @@ class DraftWorkflowNodeLastRunApi(Resource): @console_ns.doc("get_draft_workflow_node_last_run") @console_ns.doc(description="Get last run result for draft workflow node") @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) - @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model) + @console_ns.response( + 200, + "Node last run retrieved successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @console_ns.response(404, "Node last run not found") @console_ns.response(403, "Permission denied") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_node_execution_model) def get(self, app_model: App, node_id: str): srv = WorkflowService() workflow = srv.get_draft_workflow(app_model) @@ -1275,7 +1288,7 @@ class DraftWorkflowNodeLastRunApi(Resource): ) if node_exec is None: raise NotFound("last run not found") - return node_exec + return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//workflows/draft/trigger/run") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 1502ba07a3..3a8019941f 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -210,7 +210,7 @@ class WorkflowAppLogApi(Resource): """ Get workflow app logs """ - args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # get paginate workflow app logs workflow_app_service = WorkflowAppService() @@ -253,7 +253,7 @@ class WorkflowArchivedLogApi(Resource): """ Get workflow archived logs """ - args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) workflow_app_service = WorkflowAppService() with sessionmaker(db.engine, expire_on_commit=False).begin() as session: diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py index e7c3e982a6..c003be1303 100644 --- a/api/controllers/console/app/workflow_comment.py +++ b/api/controllers/console/app/workflow_comment.py @@ -23,7 +23,6 @@ from services.account_service import TenantService from services.workflow_comment_service import WorkflowCommentService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkflowCommentCreatePayload(BaseModel): @@ -52,13 +51,14 @@ class WorkflowCommentMentionUsersPayload(BaseModel): users: list[AccountWithRole] -for model in ( +register_schema_models( + console_ns, + AccountWithRole, + WorkflowCommentMentionUsersPayload, WorkflowCommentCreatePayload, WorkflowCommentUpdatePayload, WorkflowCommentReplyPayload, -): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) -register_schema_models(console_ns, AccountWithRole, WorkflowCommentMentionUsersPayload) +) workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields) workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index c688a69074..3c887c33dc 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -8,6 +8,7 @@ from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( DraftWorkflowNotExist, @@ -33,7 +34,6 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) _file_access_controller = DatabaseFileAccessController() -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkflowDraftVariableListQuery(BaseModel): @@ -56,21 +56,12 @@ class EnvironmentVariableUpdatePayload(BaseModel): environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow") -console_ns.schema_model( - WorkflowDraftVariableListQuery.__name__, - WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - WorkflowDraftVariableUpdatePayload.__name__, - WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - ConversationVariableUpdatePayload.__name__, - ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) -console_ns.schema_model( - EnvironmentVariableUpdatePayload.__name__, - EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +register_schema_models( + console_ns, + WorkflowDraftVariableListQuery, + WorkflowDraftVariableUpdatePayload, + ConversationVariableUpdatePayload, + EnvironmentVariableUpdatePayload, ) @@ -260,7 +251,7 @@ class WorkflowVariableCollectionApi(Resource): """ Get draft workflow """ - args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # fetch draft workflow by app_model workflow_service = WorkflowService() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 6748d95d6b..97d2003209 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,30 +1,28 @@ from datetime import UTC, datetime, timedelta -from typing import Literal, TypedDict, cast +from typing import Literal, cast from flask import request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker from configs import dify_config +from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id from extensions.ext_database import db -from fields.end_user_fields import simple_end_user_fields -from fields.member_fields import simple_account_fields +from fields.base import ResponseModel from fields.workflow_run_fields import ( - advanced_chat_workflow_run_for_list_fields, - advanced_chat_workflow_run_pagination_fields, - workflow_run_count_fields, - workflow_run_detail_fields, - workflow_run_for_list_fields, - workflow_run_node_execution_fields, - workflow_run_node_execution_list_fields, - workflow_run_pagination_fields, + AdvancedChatWorkflowRunPaginationResponse, + WorkflowRunCountResponse, + WorkflowRunDetailResponse, + WorkflowRunNodeExecutionListResponse, + WorkflowRunNodeExecutionResponse, + WorkflowRunPaginationResponse, ) from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import WorkflowExecutionStatus @@ -52,82 +50,6 @@ def _build_backstage_input_url(form_token: str | None) -> str | None: WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600 -# Register models for flask_restx to avoid dict type issues in Swagger -# Register in dependency order: base models first, then dependent models - -# Base models -simple_account_model = console_ns.model("SimpleAccount", simple_account_fields) - -simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields) - -# Models that depend on simple_account_fields -workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy() -workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True -) -workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy) - -advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy() -advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True -) -advanced_chat_workflow_run_for_list_model = console_ns.model( - "AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy -) - -workflow_run_detail_fields_copy = workflow_run_detail_fields.copy() -workflow_run_detail_fields_copy["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True -) -workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True -) -workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy) - -workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy() -workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested( - simple_account_model, attribute="created_by_account", allow_null=True -) -workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested( - simple_end_user_model, attribute="created_by_end_user", allow_null=True -) -workflow_run_node_execution_model = console_ns.model( - "WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy -) - -# Simple models without nested dependencies -workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields) - -# Pagination models that depend on list models -advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy() -advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List( - fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data" -) -advanced_chat_workflow_run_pagination_model = console_ns.model( - "AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy -) - -workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy() -workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data") -workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy) - -workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy() -workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model)) -workflow_run_node_execution_list_model = console_ns.model( - "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy -) - -workflow_run_export_fields = console_ns.model( - "WorkflowRunExport", - { - "status": fields.String(description="Export status: success/failed"), - "presigned_url": fields.String(description="Pre-signed URL for download", required=False), - "presigned_url_expires_at": fields.String(description="Pre-signed URL expiration time", required=False), - }, -) - -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class WorkflowRunListQuery(BaseModel): last_id: str | None = Field(default=None, description="Last run ID for pagination") @@ -136,7 +58,7 @@ class WorkflowRunListQuery(BaseModel): default=None, description="Workflow run status filter" ) triggered_from: Literal["debugging", "app-run"] | None = Field( - default=None, description="Filter by trigger source: debugging or app-run" + default=None, description="Filter by trigger source: debugging or app-run. Default: debugging" ) @field_validator("last_id") @@ -151,9 +73,15 @@ class WorkflowRunCountQuery(BaseModel): status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field( default=None, description="Workflow run status filter" ) - time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)") + time_range: str | None = Field( + default=None, + description=( + "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " + "30m (30 minutes), 30s (30 seconds). Filters by created_at field." + ), + ) triggered_from: Literal["debugging", "app-run"] | None = Field( - default=None, description="Filter by trigger source: debugging or app-run" + default=None, description="Filter by trigger source: debugging or app-run. Default: debugging" ) @field_validator("time_range") @@ -164,56 +92,69 @@ class WorkflowRunCountQuery(BaseModel): return time_duration(value) -console_ns.schema_model( - WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - WorkflowRunCountQuery.__name__, - WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +class WorkflowRunExportResponse(ResponseModel): + status: str = Field(description="Export status: success/failed") + presigned_url: str | None = Field(default=None, description="Pre-signed URL for download") + presigned_url_expires_at: str | None = Field(default=None, description="Pre-signed URL expiration time") -class HumanInputPauseTypeResponse(TypedDict): +class HumanInputPauseTypeResponse(ResponseModel): type: Literal["human_input"] form_id: str - backstage_input_url: str | None + backstage_input_url: str | None = None -class PausedNodeResponse(TypedDict): +class PausedNodeResponse(ResponseModel): node_id: str node_title: str pause_type: HumanInputPauseTypeResponse -class WorkflowPauseDetailsResponse(TypedDict): - paused_at: str | None +class WorkflowPauseDetailsResponse(ResponseModel): + paused_at: str | None = None paused_nodes: list[PausedNodeResponse] +register_schema_models( + console_ns, + WorkflowRunListQuery, + WorkflowRunCountQuery, +) +register_response_schema_models( + console_ns, + AdvancedChatWorkflowRunPaginationResponse, + WorkflowRunPaginationResponse, + WorkflowRunCountResponse, + WorkflowRunDetailResponse, + WorkflowRunNodeExecutionResponse, + WorkflowRunNodeExecutionListResponse, + WorkflowRunExportResponse, + HumanInputPauseTypeResponse, + PausedNodeResponse, + WorkflowPauseDetailsResponse, +) + + @console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs") @console_ns.doc(description="Get advanced chat workflow run list") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) - @console_ns.doc( - params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} + @console_ns.doc(params=query_params_from_model(WorkflowRunListQuery)) + @console_ns.response( + 200, + "Workflow runs retrieved successfully", + console_ns.models[AdvancedChatWorkflowRunPaginationResponse.__name__], ) - @console_ns.doc( - params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} - ) - @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__]) - @console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) - @marshal_with(advanced_chat_workflow_run_pagination_model) def get(self, app_model: App): """ Get advanced chat app workflow run list """ - args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) args: WorkflowRunListArgs = {"limit": args_model.limit} if args_model.last_id is not None: args["last_id"] = args_model.last_id @@ -232,7 +173,9 @@ class AdvancedChatAppWorkflowRunListApi(Resource): app_model=app_model, args=args, triggered_from=triggered_from ) - return result + return AdvancedChatWorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump( + mode="json" + ) @console_ns.route("/apps//workflow-runs//export") @@ -240,7 +183,7 @@ class WorkflowRunExportApi(Resource): @console_ns.doc("get_workflow_run_export_url") @console_ns.doc(description="Generate a download URL for an archived workflow run.") @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) - @console_ns.response(200, "Export URL generated", workflow_run_export_fields) + @console_ns.response(200, "Export URL generated", console_ns.models[WorkflowRunExportResponse.__name__]) @setup_required @login_required @account_initialization_required @@ -278,11 +221,14 @@ class WorkflowRunExportApi(Resource): expires_in=EXPORT_SIGNED_URL_EXPIRE_SECONDS, ) expires_at = datetime.now(UTC) + timedelta(seconds=EXPORT_SIGNED_URL_EXPIRE_SECONDS) - return { - "status": "success", - "presigned_url": presigned_url, - "presigned_url_expires_at": expires_at.isoformat(), - }, 200 + response = WorkflowRunExportResponse.model_validate( + { + "status": "success", + "presigned_url": presigned_url, + "presigned_url_expires_at": expires_at.isoformat(), + } + ) + return response.model_dump(mode="json"), 200 @console_ns.route("/apps//advanced-chat/workflow-runs/count") @@ -290,32 +236,21 @@ class AdvancedChatAppWorkflowRunCountApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs_count") @console_ns.doc(description="Get advanced chat workflow runs count statistics") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} + @console_ns.doc(params=query_params_from_model(WorkflowRunCountQuery)) + @console_ns.response( + 200, + "Workflow runs count retrieved successfully", + console_ns.models[WorkflowRunCountResponse.__name__], ) - @console_ns.doc( - params={ - "time_range": ( - "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " - "30m (30 minutes), 30s (30 seconds). Filters by created_at field." - ) - } - ) - @console_ns.doc( - params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} - ) - @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) - @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__]) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) - @marshal_with(workflow_run_count_model) def get(self, app_model: App): """ Get advanced chat workflow runs count statistics """ - args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING if not specified @@ -333,7 +268,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource): triggered_from=triggered_from, ) - return result + return WorkflowRunCountResponse.model_validate(result).model_dump(mode="json") @console_ns.route("/apps//workflow-runs") @@ -341,25 +276,21 @@ class WorkflowRunListApi(Resource): @console_ns.doc("get_workflow_runs") @console_ns.doc(description="Get workflow run list") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) - @console_ns.doc( - params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} + @console_ns.doc(params=query_params_from_model(WorkflowRunListQuery)) + @console_ns.response( + 200, + "Workflow runs retrieved successfully", + console_ns.models[WorkflowRunPaginationResponse.__name__], ) - @console_ns.doc( - params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} - ) - @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model) - @console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__]) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_pagination_model) def get(self, app_model: App): """ Get workflow run list """ - args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) args: WorkflowRunListArgs = {"limit": args_model.limit} if args_model.last_id is not None: args["last_id"] = args_model.last_id @@ -378,7 +309,7 @@ class WorkflowRunListApi(Resource): app_model=app_model, args=args, triggered_from=triggered_from ) - return result + return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//workflow-runs/count") @@ -386,32 +317,21 @@ class WorkflowRunCountApi(Resource): @console_ns.doc("get_workflow_runs_count") @console_ns.doc(description="Get workflow runs count statistics") @console_ns.doc(params={"app_id": "Application ID"}) - @console_ns.doc( - params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} + @console_ns.doc(params=query_params_from_model(WorkflowRunCountQuery)) + @console_ns.response( + 200, + "Workflow runs count retrieved successfully", + console_ns.models[WorkflowRunCountResponse.__name__], ) - @console_ns.doc( - params={ - "time_range": ( - "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " - "30m (30 minutes), 30s (30 seconds). Filters by created_at field." - ) - } - ) - @console_ns.doc( - params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} - ) - @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) - @console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__]) @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_count_model) def get(self, app_model: App): """ Get workflow runs count statistics """ - args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) args = args_model.model_dump(exclude_none=True) # Default to DEBUGGING for workflow if not specified (backward compatibility) @@ -429,7 +349,7 @@ class WorkflowRunCountApi(Resource): triggered_from=triggered_from, ) - return result + return WorkflowRunCountResponse.model_validate(result).model_dump(mode="json") @console_ns.route("/apps//workflow-runs/") @@ -437,13 +357,16 @@ class WorkflowRunDetailApi(Resource): @console_ns.doc("get_workflow_run_detail") @console_ns.doc(description="Get workflow run detail") @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) - @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model) + @console_ns.response( + 200, + "Workflow run detail retrieved successfully", + console_ns.models[WorkflowRunDetailResponse.__name__], + ) @console_ns.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_detail_model) def get(self, app_model: App, run_id): """ Get workflow run detail @@ -452,8 +375,10 @@ class WorkflowRunDetailApi(Resource): workflow_run_service = WorkflowRunService() workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id) + if workflow_run is None: + raise NotFoundError("Workflow run not found") - return workflow_run + return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json") @console_ns.route("/apps//workflow-runs//node-executions") @@ -461,13 +386,16 @@ class WorkflowRunNodeExecutionListApi(Resource): @console_ns.doc("get_workflow_run_node_executions") @console_ns.doc(description="Get workflow run node execution list") @console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) - @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model) + @console_ns.response( + 200, + "Node executions retrieved successfully", + console_ns.models[WorkflowRunNodeExecutionListResponse.__name__], + ) @console_ns.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_run_node_execution_list_model) def get(self, app_model: App, run_id): """ Get workflow run node execution list @@ -482,13 +410,24 @@ class WorkflowRunNodeExecutionListApi(Resource): user=user, ) - return {"data": node_executions} + return WorkflowRunNodeExecutionListResponse.model_validate( + {"data": node_executions}, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/workflow//pause-details") class ConsoleWorkflowPauseDetailsApi(Resource): """Console API for getting workflow pause details.""" + @console_ns.doc("get_workflow_pause_details") + @console_ns.doc(description="Get workflow pause details") + @console_ns.doc(params={"workflow_run_id": "Workflow run ID"}) + @console_ns.response( + 200, + "Workflow pause details retrieved successfully", + console_ns.models[WorkflowPauseDetailsResponse.__name__], + ) + @console_ns.response(404, "Workflow run not found") @setup_required @login_required @account_initialization_required @@ -515,11 +454,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Check if workflow is suspended is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED if not is_paused: - empty_response: WorkflowPauseDetailsResponse = { - "paused_at": None, - "paused_nodes": [], - } - return empty_response, 200 + empty_response = WorkflowPauseDetailsResponse(paused_at=None, paused_nodes=[]) + return empty_response.model_dump(mode="json"), 200 pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] @@ -530,27 +466,25 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Build response paused_at = pause_entity.paused_at if pause_entity else None paused_nodes: list[PausedNodeResponse] = [] - response: WorkflowPauseDetailsResponse = { - "paused_at": paused_at.isoformat() + "Z" if paused_at else None, - "paused_nodes": paused_nodes, - } for reason in pause_reasons: if isinstance(reason, HumanInputRequired): paused_nodes.append( - { - "node_id": reason.node_id, - "node_title": reason.node_title, - "pause_type": { - "type": "human_input", - "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url( - form_tokens_by_form_id.get(reason.form_id) - ), - }, - } + PausedNodeResponse( + node_id=reason.node_id, + node_title=reason.node_title, + pause_type=HumanInputPauseTypeResponse( + type="human_input", + form_id=reason.form_id, + backstage_input_url=_build_backstage_input_url(form_tokens_by_form_id.get(reason.form_id)), + ), + ) ) else: raise AssertionError("unimplemented.") - return response, 200 + response = WorkflowPauseDetailsResponse( + paused_at=paused_at.isoformat() + "Z" if paused_at else None, + paused_nodes=paused_nodes, + ) + return response.model_dump(mode="json"), 200 diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index e48cf42762..ca899d8784 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -3,6 +3,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import sessionmaker +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required @@ -13,8 +14,6 @@ from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class WorkflowStatisticQuery(BaseModel): start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)") @@ -28,10 +27,7 @@ class WorkflowStatisticQuery(BaseModel): return value -console_ns.schema_model( - WorkflowStatisticQuery.__name__, - WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, WorkflowStatisticQuery) @console_ns.route("/apps//workflow/statistics/daily-conversations") @@ -53,7 +49,7 @@ class WorkflowDailyRunsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -93,7 +89,7 @@ class WorkflowDailyTerminalsStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -133,7 +129,7 @@ class WorkflowDailyTokenCostStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None @@ -173,7 +169,7 @@ class WorkflowAverageAppInteractionStatistic(Resource): def get(self, app_model): account, _ = current_account_with_tenant() - args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) assert account.timezone is not None diff --git a/api/controllers/console/app/workflow_trigger.py b/api/controllers/console/app/workflow_trigger.py index a6715fa200..a80b4f5d0c 100644 --- a/api/controllers/console/app/workflow_trigger.py +++ b/api/controllers/console/app/workflow_trigger.py @@ -94,7 +94,7 @@ class WebhookTriggerApi(Resource): @console_ns.response(200, "Success", console_ns.models[WebhookTriggerResponse.__name__]) def get(self, app_model: App): """Get webhook trigger for a node""" - args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = Parser.model_validate(request.args.to_dict(flat=True)) node_id = args.node_id diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index f7061f820f..0c05cf2fe3 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -63,7 +63,7 @@ class ActivateCheckApi(Resource): console_ns.models[ActivationCheckResponse.__name__], ) def get(self): - args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) workspaceId = args.workspace_id token = args.token diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 905d0daef0..db0d36af6e 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -1,6 +1,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field +from controllers.common.schema import register_schema_models from libs.login import current_account_with_tenant, login_required from services.auth.api_key_auth_service import ApiKeyAuthService @@ -8,8 +9,6 @@ from .. import console_ns from ..auth.error import ApiKeyAuthFailedError from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ApiKeyAuthBindingPayload(BaseModel): category: str = Field(...) @@ -17,10 +16,7 @@ class ApiKeyAuthBindingPayload(BaseModel): credentials: dict = Field(...) -console_ns.schema_model( - ApiKeyAuthBindingPayload.__name__, - ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_schema_models(console_ns, ApiKeyAuthBindingPayload) @console_ns.route("/api-key-auth/data-source") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 1fd781b4fc..f6b8aedf22 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field, field_validator from configs import dify_config from constants.languages import languages +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( EmailAlreadyInUseError, @@ -23,8 +24,6 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError from ..error import AccountInFreezeError, EmailSendIpLimitError from ..wraps import email_password_login_enabled, email_register_enabled, setup_required -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class EmailRegisterSendPayload(BaseModel): email: EmailStr = Field(..., description="Email address") @@ -48,8 +47,7 @@ class EmailRegisterResetPayload(BaseModel): return valid_password(value) -for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload): - console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload) @console_ns.route("/email-register/send-email") diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index ed390a5f89..c34dd1ac85 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -28,8 +28,6 @@ from services.entities.auth_entities import ( ) from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ForgotPasswordEmailResponse(BaseModel): result: str = Field(description="Operation result") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 8216b3d0da..19c98f3a1a 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -9,6 +9,7 @@ from werkzeug.exceptions import Unauthorized import services from configs import dify_config from constants.languages import get_valid_language +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.auth.error import ( AuthenticationFailedError, @@ -50,7 +51,6 @@ from services.errors.account import AccountRegisterError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" logger = logging.getLogger(__name__) @@ -71,13 +71,7 @@ class EmailCodeLoginPayload(BaseModel): language: str | None = Field(default=None) -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(LoginPayload) -reg(EmailPayload) -reg(EmailCodeLoginPayload) +register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload) @console_ns.route("/login") diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a883af0d13..054aa7682e 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -644,63 +644,63 @@ class DatasetIndexingEstimateApi(Resource): # validate args DocumentService.estimate_args_validate(args) extract_settings = [] - if args["info_list"]["data_source_type"] == "upload_file": - file_ids = args["info_list"]["file_info_list"]["file_ids"] - file_details = db.session.scalars( - select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids)) - ).all() + match args["info_list"]["data_source_type"]: + case "upload_file": + file_ids = args["info_list"]["file_info_list"]["file_ids"] + file_details = db.session.scalars( + select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids)) + ).all() + if file_details is None: + raise NotFound("File not found.") - if file_details is None: - raise NotFound("File not found.") - - if file_details: - for file_detail in file_details: + if file_details: + for file_detail in file_details: + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, + upload_file=file_detail, + document_model=args["doc_form"], + ) + extract_settings.append(extract_setting) + case "notion_import": + notion_info_list = args["info_list"]["notion_info_list"] + for notion_info in notion_info_list: + workspace_id = notion_info["workspace_id"] + credential_id = notion_info.get("credential_id") + for page in notion_info["pages"]: + extract_setting = ExtractSetting( + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": credential_id, + "notion_workspace_id": workspace_id, + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_tenant_id, + } + ), + document_model=args["doc_form"], + ) + extract_settings.append(extract_setting) + case "website_crawl": + website_info_list = args["info_list"]["website_info_list"] + for url in website_info_list["urls"]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, - upload_file=file_detail, - document_model=args["doc_form"], - ) - extract_settings.append(extract_setting) - elif args["info_list"]["data_source_type"] == "notion_import": - notion_info_list = args["info_list"]["notion_info_list"] - for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] - credential_id = notion_info.get("credential_id") - for page in notion_info["pages"]: - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( { - "credential_id": credential_id, - "notion_workspace_id": workspace_id, - "notion_obj_id": page["page_id"], - "notion_page_type": page["type"], + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], + "url": url, "tenant_id": current_tenant_id, + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], } ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) - elif args["info_list"]["data_source_type"] == "website_crawl": - website_info_list = args["info_list"]["website_info_list"] - for url in website_info_list["urls"]: - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": website_info_list["provider"], - "job_id": website_info_list["job_id"], - "url": url, - "tenant_id": current_tenant_id, - "mode": "crawl", - "only_main_content": website_info_list["only_main_content"], - } - ), - document_model=args["doc_form"], - ) - extract_settings.append(extract_setting) - else: - raise ValueError("Data source type not support") + case _: + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: response = indexing_runner.indexing_estimate( diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index b61e39716f..3ad5cf1990 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -286,28 +286,31 @@ class DatasetDocumentListApi(Resource): else: sort_logic = asc - if sort == "hit_count": - sub_query = ( - sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")) - .where(DocumentSegment.dataset_id == str(dataset_id)) - .group_by(DocumentSegment.document_id) - .subquery() - ) + match sort: + case "hit_count": + sub_query = ( + sa.select( + DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count") + ) + .where(DocumentSegment.dataset_id == str(dataset_id)) + .group_by(DocumentSegment.document_id) + .subquery() + ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( - sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)), - sort_logic(Document.position), - ) - elif sort == "created_at": - query = query.order_by( - sort_logic(Document.created_at), - sort_logic(Document.position), - ) - else: - query = query.order_by( - desc(Document.created_at), - desc(Document.position), - ) + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + case "created_at": + query = query.order_by( + sort_logic(Document.created_at), + sort_logic(Document.position), + ) + case _: + query = query.order_by( + desc(Document.created_at), + desc(Document.position), + ) paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py index 7caf5b52ed..a43caa8f56 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_content_preview.py @@ -4,6 +4,7 @@ from flask_restx import ( # type: ignore from pydantic import BaseModel from werkzeug.exceptions import Forbidden +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required @@ -12,8 +13,6 @@ from models import Account from models.dataset import Pipeline from services.rag_pipeline.rag_pipeline import RagPipelineService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class Parser(BaseModel): inputs: dict @@ -21,7 +20,7 @@ class Parser(BaseModel): credential_id: str | None = None -console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) +register_schema_models(console_ns, Parser) @console_ns.route("/rag/pipelines//workflows/published/datasource/nodes//preview") diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index ee146e8287..8eff32c555 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -10,7 +10,7 @@ from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotF import services from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload -from controllers.common.schema import register_schema_models +from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( ConversationCompletedError, @@ -22,12 +22,6 @@ from controllers.console.app.workflow import ( workflow_model, workflow_pagination_model, ) -from controllers.console.app.workflow_run import ( - workflow_run_detail_model, - workflow_run_node_execution_list_model, - workflow_run_node_execution_model, - workflow_run_pagination_model, -) from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import ( account_initialization_required, @@ -40,6 +34,12 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from factories import variable_factory +from fields.workflow_run_fields import ( + WorkflowRunDetailResponse, + WorkflowRunNodeExecutionListResponse, + WorkflowRunNodeExecutionResponse, + WorkflowRunPaginationResponse, +) from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty @@ -131,6 +131,13 @@ register_schema_models( DatasourceVariablesPayload, RagPipelineRecommendedPluginQuery, ) +register_response_schema_models( + console_ns, + WorkflowRunDetailResponse, + WorkflowRunNodeExecutionListResponse, + WorkflowRunNodeExecutionResponse, + WorkflowRunPaginationResponse, +) @console_ns.route("/rag/pipelines//workflows/draft") @@ -415,12 +422,16 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/nodes//run") class RagPipelineDraftNodeRunApi(Resource): @console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__]) + @console_ns.response( + 200, + "Node run started successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @setup_required @login_required @edit_permission_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_model) def post(self, pipeline: Pipeline, node_id: str): """ Run draft workflow node @@ -439,7 +450,9 @@ class RagPipelineDraftNodeRunApi(Resource): if workflow_node_execution is None: raise ValueError("Workflow node execution not found") - return workflow_node_execution + return WorkflowRunNodeExecutionResponse.model_validate( + workflow_node_execution, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/rag/pipelines//workflow-runs/tasks//stop") @@ -778,11 +791,15 @@ class DraftRagPipelineSecondStepApi(Resource): @console_ns.route("/rag/pipelines//workflow-runs") class RagPipelineWorkflowRunListApi(Resource): + @console_ns.response( + 200, + "Workflow runs retrieved successfully", + console_ns.models[WorkflowRunPaginationResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_pagination_model) def get(self, pipeline: Pipeline): """ Get workflow run list @@ -801,16 +818,20 @@ class RagPipelineWorkflowRunListApi(Resource): rag_pipeline_service = RagPipelineService() result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args) - return result + return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json") @console_ns.route("/rag/pipelines//workflow-runs/") class RagPipelineWorkflowRunDetailApi(Resource): + @console_ns.response( + 200, + "Workflow run detail retrieved successfully", + console_ns.models[WorkflowRunDetailResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_detail_model) def get(self, pipeline: Pipeline, run_id): """ Get workflow run detail @@ -819,17 +840,23 @@ class RagPipelineWorkflowRunDetailApi(Resource): rag_pipeline_service = RagPipelineService() workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id) + if workflow_run is None: + raise NotFound("Workflow run not found") - return workflow_run + return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json") @console_ns.route("/rag/pipelines//workflow-runs//node-executions") class RagPipelineWorkflowRunNodeExecutionListApi(Resource): + @console_ns.response( + 200, + "Node executions retrieved successfully", + console_ns.models[WorkflowRunNodeExecutionListResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_list_model) def get(self, pipeline: Pipeline, run_id: str): """ Get workflow run node execution list @@ -844,7 +871,9 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): user=user, ) - return {"data": node_executions} + return WorkflowRunNodeExecutionListResponse.model_validate( + {"data": node_executions}, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/rag/pipelines/datasource-plugins") @@ -859,11 +888,15 @@ class DatasourceListApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/nodes//last-run") class RagPipelineWorkflowLastRunApi(Resource): + @console_ns.response( + 200, + "Node last run retrieved successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline - @marshal_with(workflow_run_node_execution_model) def get(self, pipeline: Pipeline, node_id: str): rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline) @@ -876,7 +909,7 @@ class RagPipelineWorkflowLastRunApi(Resource): ) if node_exec is None: raise NotFound("last run not found") - return node_exec + return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json") @console_ns.route("/rag/pipelines/transform/datasets/") @@ -899,12 +932,16 @@ class RagPipelineTransformApi(Resource): @console_ns.route("/rag/pipelines//workflows/draft/datasource/variables-inspect") class RagPipelineDatasourceVariableApi(Resource): @console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__]) + @console_ns.response( + 200, + "Datasource variables set successfully", + console_ns.models[WorkflowRunNodeExecutionResponse.__name__], + ) @setup_required @login_required @account_initialization_required @get_rag_pipeline @edit_permission_required - @marshal_with(workflow_run_node_execution_model) def post(self, pipeline: Pipeline): """ Set datasource variables @@ -918,7 +955,9 @@ class RagPipelineDatasourceVariableApi(Resource): args=args, current_user=current_user, ) - return workflow_node_execution + return WorkflowRunNodeExecutionResponse.model_validate( + workflow_node_execution, from_attributes=True + ).model_dump(mode="json") @console_ns.route("/rag/pipelines/recommended-plugins") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 572f9773a1..5821b91489 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,4 +1,5 @@ from typing import Any +from uuid import UUID from flask import request from flask_restx import Resource @@ -80,7 +81,7 @@ class RecommendedAppListApi(Resource): @account_initialization_required def get(self): # language args - args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True)) language = args.language if language and language in languages: language_prefix = language @@ -99,6 +100,5 @@ class RecommendedAppListApi(Resource): class RecommendedAppApi(Resource): @login_required @account_initialization_required - def get(self, app_id): - app_id = str(app_id) - return RecommendedAppService.get_recommend_app_detail(app_id) + def get(self, app_id: UUID): + return RecommendedAppService.get_recommend_app_detail(str(app_id)) diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 1456301a24..26b48ec599 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -10,7 +10,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.common.fields import Parameters as ParametersResponse from controllers.common.fields import Site as SiteResponse -from controllers.common.schema import get_or_create_model +from controllers.common.schema import get_or_create_model, register_schema_models from controllers.console import console_ns from controllers.console.app.error import ( AppUnavailableError, @@ -106,7 +106,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model)) app_detail_fields_with_site_copy["site"] = fields.Nested(site_model) app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy) -simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields) +simple_account_model = get_or_create_model("TrialSimpleAccount", simple_account_fields) conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields) pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields) @@ -120,10 +120,6 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy) -# Pydantic models for request validation -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - class WorkflowRunRequest(BaseModel): inputs: dict files: list | None = None @@ -153,19 +149,7 @@ class CompletionRequest(BaseModel): retriever_from: str = "explore_app" -# Register schemas for Swagger documentation -console_ns.schema_model( - WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -console_ns.schema_model( - CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, WorkflowRunRequest, ChatRequest, TextToSpeechRequest, CompletionRequest) class TrialAppWorkflowRunApi(TrialAppResource): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 7a6356d052..9ffc18e4c2 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -89,7 +89,7 @@ class CodeBasedExtensionAPI(Resource): @login_required @account_initialization_required def get(self): - query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + query = CodeBasedExtensionQuery.model_validate(request.args.to_dict(flat=True)) return CodeBasedExtensionResponse( module=query.module, diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 109a3cd0d3..9fa5b0f5c1 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -82,7 +82,7 @@ class FileApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source=source, diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index d69a59ecb7..68520e540b 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -52,8 +52,6 @@ from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class AccountInitPayload(BaseModel): interface_language: str @@ -161,27 +159,26 @@ class CheckEmailUniquePayload(BaseModel): email: EmailStr -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(AccountInitPayload) -reg(AccountNamePayload) -reg(AccountAvatarPayload) -reg(AccountAvatarQuery) -reg(AccountInterfaceLanguagePayload) -reg(AccountInterfaceThemePayload) -reg(AccountTimezonePayload) -reg(AccountPasswordPayload) -reg(AccountDeletePayload) -reg(AccountDeletionFeedbackPayload) -reg(EducationActivatePayload) -reg(EducationAutocompleteQuery) -reg(ChangeEmailSendPayload) -reg(ChangeEmailValidityPayload) -reg(ChangeEmailResetPayload) -reg(CheckEmailUniquePayload) -register_schema_models(console_ns, AccountResponse) +register_schema_models( + console_ns, + AccountResponse, + AccountInitPayload, + AccountNamePayload, + AccountAvatarPayload, + AccountAvatarQuery, + AccountInterfaceLanguagePayload, + AccountInterfaceThemePayload, + AccountTimezonePayload, + AccountPasswordPayload, + AccountDeletePayload, + AccountDeletionFeedbackPayload, + EducationActivatePayload, + EducationAutocompleteQuery, + ChangeEmailSendPayload, + ChangeEmailValidityPayload, + ChangeEmailResetPayload, + CheckEmailUniquePayload, +) def _serialize_account(account) -> dict[str, Any]: @@ -326,7 +323,7 @@ class AccountAvatarApi(Resource): @account_initialization_required def get(self): current_user, current_tenant_id = current_account_with_tenant() - args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) avatar = args.avatar if avatar.startswith(("http://", "https://")): diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index d4be07382a..925f3e1197 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -20,8 +20,6 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class EndpointCreatePayload(BaseModel): plugin_unique_identifier: str @@ -80,10 +78,6 @@ class EndpointDisableResponse(BaseModel): success: bool = Field(description="Operation success") -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - register_schema_models( console_ns, EndpointCreatePayload, @@ -215,7 +209,7 @@ class EndpointListApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page page_size = args.page_size @@ -248,7 +242,7 @@ class EndpointListForSinglePluginApi(Resource): def get(self): user, tenant_id = current_account_with_tenant() - args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) page = args.page page_size = args.page_size diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index ae1b7965d6..c2d3d93d7b 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -34,8 +34,6 @@ from services.enterprise import rbac_service as enterprise_rbac_service from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class MemberInvitePayload(BaseModel): emails: list[str] = Field(default_factory=list) @@ -60,17 +58,17 @@ class OwnerTransferPayload(BaseModel): token: str -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(MemberInvitePayload) -reg(MemberRoleUpdatePayload) -reg(OwnerTransferEmailPayload) -reg(OwnerTransferCheckPayload) -reg(OwnerTransferPayload) register_enum_models(console_ns, TenantAccountRole) -register_schema_models(console_ns, AccountWithRole, AccountWithRoleList) +register_schema_models( + console_ns, + AccountWithRole, + AccountWithRoleList, + MemberInvitePayload, + MemberRoleUpdatePayload, + OwnerTransferEmailPayload, + OwnerTransferCheckPayload, + OwnerTransferPayload, +) def _serialize_member_roles(current_role: str | None, member_roles: list[enterprise_rbac_service.MemberRoleSummary]) -> list[dict[str, str]]: diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 4b10561fdb..2f75218c0f 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -5,6 +5,7 @@ from flask import request, send_file from flask_restx import Resource from pydantic import BaseModel, Field, field_validator +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from graphon.model_runtime.entities.model_entities import ModelType @@ -15,8 +16,6 @@ from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ParserModelList(BaseModel): model_type: ModelType | None = None @@ -75,18 +74,17 @@ class ParserPreferredProviderType(BaseModel): preferred_provider_type: Literal["system", "custom"] -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(ParserModelList) -reg(ParserCredentialId) -reg(ParserCredentialCreate) -reg(ParserCredentialUpdate) -reg(ParserCredentialDelete) -reg(ParserCredentialSwitch) -reg(ParserCredentialValidate) -reg(ParserPreferredProviderType) +register_schema_models( + console_ns, + ParserModelList, + ParserCredentialId, + ParserCredentialCreate, + ParserCredentialUpdate, + ParserCredentialDelete, + ParserCredentialSwitch, + ParserCredentialValidate, + ParserPreferredProviderType, +) @console_ns.route("/workspaces/current/model-providers") diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index b2d07ff8f9..7f7d6379c3 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -17,7 +17,6 @@ from services.model_load_balancing_service import ModelLoadBalancingService from services.model_provider_service import ModelProviderService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class ParserGetDefault(BaseModel): @@ -107,6 +106,12 @@ class ParserParameter(BaseModel): model: str +class ParserSwitch(BaseModel): + model: str + model_type: ModelType + credential_id: str + + register_schema_models( console_ns, ParserGetDefault, @@ -119,6 +124,7 @@ register_schema_models( ParserDeleteCredential, ParserParameter, Inner, + ParserSwitch, ) register_enum_models(console_ns, ModelType) @@ -133,7 +139,7 @@ class DefaultModelApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( @@ -261,7 +267,7 @@ class ModelProviderModelCredentialApi(Resource): def get(self, provider: str): _, tenant_id = current_account_with_tenant() - args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) model_provider_service = ModelProviderService() current_credential = model_provider_service.get_model_credential( @@ -387,17 +393,6 @@ class ModelProviderModelCredentialApi(Resource): return {"result": "success"}, 204 -class ParserSwitch(BaseModel): - model: str - model_type: ModelType - credential_id: str - - -console_ns.schema_model( - ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) - - @console_ns.route("/workspaces/current/model-providers//models/credentials/switch") class ModelProviderModelCredentialSwitchApi(Resource): @console_ns.expect(console_ns.models[ParserSwitch.__name__]) @@ -468,9 +463,7 @@ class ParserValidate(BaseModel): credentials: dict[str, Any] -console_ns.schema_model( - ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(console_ns, ParserSwitch, ParserValidate) @console_ns.route("/workspaces/current/model-providers//models/credentials/validate") @@ -515,7 +508,7 @@ class ModelProviderModelParameterRuleApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserParameter.model_validate(request.args.to_dict(flat=True)) _, tenant_id = current_account_with_tenant() model_provider_service = ModelProviderService() diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index b3e344ccea..a6d4a60beb 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -177,7 +177,7 @@ def _read_upload_content(file: FileStorage, max_size: int) -> bytes: FileStorage.content_length is not reliable for multipart test uploads and may be zero even when content exists, so the controllers validate against the loaded bytes instead. """ - content = file.read() + content = file.stream.read() if len(content) > max_size: raise ValueError("File size exceeds the maximum allowed size") @@ -211,7 +211,7 @@ class PluginListApi(Resource): @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserList.model_validate(request.args.to_dict(flat=True)) try: plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) except PluginDaemonClientSideError as e: @@ -261,7 +261,7 @@ class PluginIconApi(Resource): @console_ns.expect(console_ns.models[ParserIcon.__name__]) @setup_required def get(self): - args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserIcon.model_validate(request.args.to_dict(flat=True)) try: icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename) @@ -279,7 +279,7 @@ class PluginAssetApi(Resource): @login_required @account_initialization_required def get(self): - args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserAsset.model_validate(request.args.to_dict(flat=True)) _, tenant_id = current_account_with_tenant() try: @@ -421,7 +421,7 @@ class PluginFetchMarketplacePkgApi(Resource): @plugin_permission_required(install_required=True) def get(self): _, tenant_id = current_account_with_tenant() - args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder( @@ -446,7 +446,7 @@ class PluginFetchManifestApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder( @@ -466,7 +466,7 @@ class PluginFetchInstallTasksApi(Resource): def get(self): _, tenant_id = current_account_with_tenant() - args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserTasks.model_validate(request.args.to_dict(flat=True)) try: return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)}) @@ -660,7 +660,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource): current_user, tenant_id = current_account_with_tenant() user_id = current_user.id - args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) try: options = PluginParameterService.get_dynamic_select_options( @@ -822,7 +822,7 @@ class PluginReadmeApi(Resource): @account_initialization_required def get(self): _, tenant_id = current_account_with_tenant() - args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = ParserReadme.model_validate(request.args.to_dict(flat=True)) return jsonable_encoder( {"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)} ) diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 565099db61..84890f0443 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -16,6 +16,7 @@ from controllers.common.errors import ( TooManyFilesError, UnsupportedFileTypeError, ) +from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.admin import admin_required from controllers.console.error import AccountNotLinkTenantError @@ -39,7 +40,6 @@ from services.file_service import FileService from services.workspace_service import WorkspaceService logger = logging.getLogger(__name__) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" class WorkspaceListQuery(BaseModel): @@ -91,15 +91,14 @@ class TenantInfoResponse(ResponseModel): return value -def reg(cls: type[BaseModel]): - console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) - - -reg(WorkspaceListQuery) -reg(SwitchWorkspacePayload) -reg(WorkspaceCustomConfigPayload) -reg(WorkspaceInfoPayload) -reg(TenantInfoResponse) +register_schema_models( + console_ns, + WorkspaceListQuery, + SwitchWorkspacePayload, + WorkspaceCustomConfigPayload, + WorkspaceInfoPayload, + TenantInfoResponse, +) provider_fields = { "provider_name": fields.String, @@ -322,7 +321,7 @@ class WebappLogoWorkspaceApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, ) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index a91e745f80..be7886e831 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -8,13 +8,12 @@ from werkzeug.exceptions import NotFound import services from controllers.common.errors import UnsupportedFileTypeError from controllers.common.file_response import enforce_download_for_html +from controllers.common.schema import register_schema_models from controllers.files import files_ns from extensions.ext_database import db from services.account_service import TenantService from services.file_service import FileService -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class FileSignatureQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp used in the signature") @@ -26,12 +25,7 @@ class FilePreviewQuery(FileSignatureQuery): as_attachment: bool = Field(default=False, description="Whether to download as attachment") -files_ns.schema_model( - FileSignatureQuery.__name__, FileSignatureQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) -files_ns.schema_model( - FilePreviewQuery.__name__, FilePreviewQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, FileSignatureQuery, FilePreviewQuery) @files_ns.route("//image-preview") @@ -58,7 +52,7 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True)) timestamp = args.timestamp nonce = args.nonce sign = args.sign @@ -100,7 +94,7 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True)) try: generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 2f1e2f28bd..8ae16ce7f4 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -7,12 +7,11 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.common.errors import UnsupportedFileTypeError from controllers.common.file_response import enforce_download_for_html +from controllers.common.schema import register_schema_models from controllers.files import files_ns from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class ToolFileQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp") @@ -21,9 +20,7 @@ class ToolFileQuery(BaseModel): as_attachment: bool = Field(default=False, description="Download as attachment") -files_ns.schema_model( - ToolFileQuery.__name__, ToolFileQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, ToolFileQuery) @files_ns.route("/tools/.") diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index ed3278a28b..7d588b95dd 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -20,8 +20,6 @@ from ..console.wraps import setup_required from ..files import files_ns from ..inner_api.plugin.wraps import get_user -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - class PluginUploadQuery(BaseModel): timestamp: str = Field(..., description="Unix timestamp for signature verification") @@ -31,9 +29,8 @@ class PluginUploadQuery(BaseModel): user_id: str | None = Field(default=None, description="User identifier") -files_ns.schema_model( - PluginUploadQuery.__name__, PluginUploadQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) -) +register_schema_models(files_ns, PluginUploadQuery) + register_schema_models(files_ns, FileResponse) @@ -69,7 +66,7 @@ class PluginUploadFileApi(Resource): FileTooLargeError: File exceeds size limit UnsupportedFileTypeError: File type not supported """ - args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore + args = PluginUploadQuery.model_validate(request.args.to_dict(flat=True)) file = request.files.get("file") if file is None: @@ -103,7 +100,7 @@ class PluginUploadFileApi(Resource): tool_file = ToolFileManager().create_file_by_raw( user_id=user.id, tenant_id=tenant_id, - file_binary=file.read(), + file_binary=file.stream.read(), mimetype=mimetype, filename=filename, conversation_id=None, diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 6f6dadf768..687d34076d 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -58,7 +58,7 @@ class FileApi(Resource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=end_user, ) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 3eb773fa7c..9af66f1960 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_valid from werkzeug.exceptions import Forbidden, NotFound import services -from controllers.common.schema import register_schema_models +from controllers.common.schema import register_enum_models, register_schema_models from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError @@ -34,13 +34,7 @@ from services.tag_service import ( UpdateTagPayload, ) -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - -service_api_ns.schema_model( - DatasetPermissionEnum.__name__, - TypeAdapter(DatasetPermissionEnum).json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), -) +register_enum_models(service_api_ns, DatasetPermissionEnum) class DatasetCreatePayload(BaseModel): diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 0b09facf58..e68eeeca25 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -77,9 +77,6 @@ class DocumentTextCreatePayload(BaseModel): return value -DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" - - class DocumentTextUpdate(BaseModel): name: str | None = None text: str | None = None @@ -139,7 +136,7 @@ def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[ if not dataset: raise ValueError("Dataset does not exist.") - if not dataset.indexing_technique and not args["indexing_technique"]: + if not dataset.indexing_technique and not args.get("indexing_technique"): raise ValueError("indexing_technique is required.") embedding_model_provider = payload.embedding_model_provider @@ -435,7 +432,7 @@ class DocumentAddByFileApi(DatasetApiResource): raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source="datasets", @@ -509,7 +506,7 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, source="datasets", diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index 2dc98bfbf7..8bc43bccd5 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -241,7 +241,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=current_user, ) diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 0036c90800..6128490104 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -73,7 +73,7 @@ class FileApi(WebApiResource): try: upload_file = FileService(db.engine).upload_file( filename=file.filename, - content=file.read(), + content=file.stream.read(), mimetype=file.mimetype, user=end_user, source="datasets" if source == "datasets" else None, diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index c22102c2ba..cba4659483 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -532,7 +532,6 @@ class BaseAgentRunner(AppRunner): file_objs = file_factory.build_from_message_files( message_files=files, tenant_id=self.tenant_id, - config=file_extra_config, access_controller=_file_access_controller, ) if not file_objs: diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 4c07445df3..f4bbbe5d8b 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -75,7 +75,7 @@ class PromptTemplateConfigManager: if not config.get("prompt_type"): config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE - prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] + prompt_type_vals = list(PromptTemplateEntity.PromptType) if config["prompt_type"] not in prompt_type_vals: raise ValueError(f"prompt_type must be in {prompt_type_vals}") diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 3282c9a334..13d362a0bd 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -32,7 +32,7 @@ from core.app.entities.task_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory -from core.helper.trace_id_helper import extract_external_trace_id_from_args +from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository @@ -185,6 +185,7 @@ class WorkflowAppGenerator(BaseAppGenerator): extras = { **extract_external_trace_id_from_args(args), + **extract_parent_trace_context_from_args(args), } workflow_run_id = str(workflow_run_id or uuid.uuid4()) # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args diff --git a/api/core/app/llm/__init__.py b/api/core/app/llm/__init__.py index f069bede74..d20a5b2344 100644 --- a/api/core/app/llm/__init__.py +++ b/api/core/app/llm/__init__.py @@ -1,5 +1,15 @@ """LLM-related application services.""" -from .quota import deduct_llm_quota, ensure_llm_quota_available +from .quota import ( + deduct_llm_quota, + deduct_llm_quota_for_model, + ensure_llm_quota_available, + ensure_llm_quota_available_for_model, +) -__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"] +__all__ = [ + "deduct_llm_quota", + "deduct_llm_quota_for_model", + "ensure_llm_quota_available", + "ensure_llm_quota_available_for_model", +] diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index b6039e1e4e..5bf3334a7b 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,4 +1,14 @@ -from sqlalchemy import update +"""Tenant-scoped helpers for checking and deducting LLM provider quota. + +System-hosted quota accounting is currently defined only for LLM models. Keep +the public helpers LLM-specific so callers do not carry unused model-type +plumbing, and fail loudly if the deprecated ``ModelInstance`` wrappers are used +with a non-LLM model. +""" + +import warnings + +from sqlalchemy import select from sqlalchemy.orm import sessionmaker from configs import dify_config @@ -6,44 +16,47 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from extensions.ext_database import db from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID -def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None: - provider_model_bundle = model_instance.provider_model_bundle - provider_configuration = provider_model_bundle.configuration +def _get_provider_configuration(*, tenant_id: str, provider: str): + """Resolve the tenant-bound provider configuration for quota decisions.""" + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) + provider_configuration = provider_manager.get_configurations(tenant_id).get(provider) + if provider_configuration is None: + raise ValueError(f"Provider {provider} does not exist.") + return provider_configuration + +def ensure_llm_quota_available_for_model(*, tenant_id: str, provider: str, model: str) -> None: + """Raise when a tenant-bound LLM model is already out of quota.""" + provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider) if provider_configuration.using_provider_type != ProviderType.SYSTEM: return provider_model = provider_configuration.get_provider_model( - model_type=model_instance.model_type_instance.model_type, - model=model_instance.model_name, + model_type=ModelType.LLM, + model=model, ) if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.") + raise QuotaExceededError(f"Model provider {provider} quota exceeded.") -def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: - provider_model_bundle = model_instance.provider_model_bundle - provider_configuration = provider_model_bundle.configuration - - if provider_configuration.using_provider_type != ProviderType.SYSTEM: - return - - system_configuration = provider_configuration.system_configuration - +def _resolve_llm_used_quota(*, system_configuration, model: str, usage: LLMUsage) -> int | None: + """Compute the quota impact for an LLM invocation under the current quota mode.""" quota_unit = None for quota_configuration in system_configuration.quota_configurations: if quota_configuration.quota_type == system_configuration.current_quota_type: quota_unit = quota_configuration.quota_unit if quota_configuration.quota_limit == -1: - return + return None break @@ -52,42 +65,136 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL if quota_unit == QuotaUnit.TOKENS: used_quota = usage.total_tokens elif quota_unit == QuotaUnit.CREDITS: - used_quota = dify_config.get_model_credits(model_instance.model_name) + used_quota = dify_config.get_model_credits(model) else: used_quota = 1 + return used_quota + + +def _deduct_free_llm_quota( + *, + tenant_id: str, + provider: str, + quota_type: ProviderQuotaType, + used_quota: int, +) -> None: + """Deduct FREE provider quota, capping at the limit before reporting exhaustion.""" + quota_exceeded = False + with sessionmaker(bind=db.engine).begin() as session: + provider_record = session.scalar( + select(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. + Provider.provider_name == ModelProviderID(provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == quota_type, + ) + .with_for_update() + ) + if ( + provider_record is None + or provider_record.quota_limit is None + or provider_record.quota_used is None + or provider_record.quota_limit <= provider_record.quota_used + ): + quota_exceeded = True + else: + available_quota = provider_record.quota_limit - provider_record.quota_used + deducted_quota = min(used_quota, available_quota) + provider_record.quota_used += deducted_quota + provider_record.last_used = naive_utc_now() + quota_exceeded = deducted_quota < used_quota + + if quota_exceeded: + raise QuotaExceededError(f"Model provider {provider} quota exceeded.") + + +def _deduct_used_llm_quota(*, tenant_id: str, provider: str, provider_configuration, used_quota: int | None) -> None: + """Apply a resolved LLM quota charge against the current provider quota bucket.""" + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration if used_quota is not None and system_configuration.current_quota_type is not None: match system_configuration.current_quota_type: case ProviderQuotaType.TRIAL: from services.credit_pool_service import CreditPoolService - CreditPoolService.check_and_deduct_credits( + CreditPoolService.deduct_credits_capped( tenant_id=tenant_id, credits_required=used_quota, ) case ProviderQuotaType.PAID: from services.credit_pool_service import CreditPoolService - CreditPoolService.check_and_deduct_credits( + CreditPoolService.deduct_credits_capped( tenant_id=tenant_id, credits_required=used_quota, pool_type="paid", ) case ProviderQuotaType.FREE: - with sessionmaker(bind=db.engine).begin() as session: - stmt = ( - update(Provider) - .where( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type, - Provider.quota_limit > Provider.quota_used, - ) - .values( - quota_used=Provider.quota_used + used_quota, - last_used=naive_utc_now(), - ) - ) - session.execute(stmt) + _deduct_free_llm_quota( + tenant_id=tenant_id, + provider=provider, + quota_type=system_configuration.current_quota_type, + used_quota=used_quota, + ) + case _: + return + + +def deduct_llm_quota_for_model(*, tenant_id: str, provider: str, model: str, usage: LLMUsage) -> None: + """Deduct tenant-bound quota for the resolved LLM model identity.""" + provider_configuration = _get_provider_configuration(tenant_id=tenant_id, provider=provider) + used_quota = _resolve_llm_used_quota( + system_configuration=provider_configuration.system_configuration, + model=model, + usage=usage, + ) + _deduct_used_llm_quota( + tenant_id=tenant_id, + provider=provider, + provider_configuration=provider_configuration, + used_quota=used_quota, + ) + + +def _require_llm_model_instance(model_instance: ModelInstance) -> None: + """Reject deprecated wrapper calls that pass a non-LLM model instance.""" + if model_instance.model_type_instance.model_type != ModelType.LLM: + raise ValueError("LLM quota helpers only support LLM model instances.") + + +def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None: + """Deprecated compatibility wrapper for callers that still pass ModelInstance.""" + warnings.warn( + "ensure_llm_quota_available(model_instance=...) is deprecated; " + "use ensure_llm_quota_available_for_model(...) instead.", + DeprecationWarning, + stacklevel=2, + ) + _require_llm_model_instance(model_instance) + ensure_llm_quota_available_for_model( + tenant_id=model_instance.provider_model_bundle.configuration.tenant_id, + provider=model_instance.provider, + model=model_instance.model_name, + ) + + +def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: + """Deprecated compatibility wrapper for callers that still pass ModelInstance.""" + warnings.warn( + "deduct_llm_quota(tenant_id=..., model_instance=..., usage=...) is deprecated; " + "use deduct_llm_quota_for_model(...) instead.", + DeprecationWarning, + stacklevel=2, + ) + _require_llm_model_instance(model_instance) + deduct_llm_quota_for_model( + tenant_id=tenant_id, + provider=model_instance.provider, + model=model_instance.model_name, + usage=usage, + ) diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 3a6f9d575a..587f700286 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -128,7 +128,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): @staticmethod def _secret_key() -> bytes: - return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + return dify_config.SECRET_KEY.encode() def _sign_query(self, *, payload: str) -> dict[str, str]: timestamp = str(int(time.time())) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 4a7918032e..2422eed5a7 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -1,36 +1,48 @@ """ LLM quota deduction layer for GraphEngine. -This layer centralizes model-quota deduction outside node implementations. +This layer centralizes model-quota handling outside node implementations. + +Graphon LLM-backed nodes expose provider/model identity through public node +configuration and, after execution, through ``node_run_result.inputs``. Resolve +quota billing from that public identity instead of depending on +``ModelInstance`` reconstruction inside the workflow layer. Missing identity on +quota-tracked nodes is treated as a workflow bug and aborts execution so quota +handling is never silently skipped. """ import logging -from typing import TYPE_CHECKING, cast, final, override +from typing import final, override -from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext -from core.app.llm import deduct_llm_quota, ensure_llm_quota_available +from core.app.llm import deduct_llm_quota_for_model, ensure_llm_quota_available_for_model from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance -from graphon.enums import BuiltinNodeTypes +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.entities.commands import AbortCommand, CommandType from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult from graphon.nodes.base.node import Node -if TYPE_CHECKING: - from graphon.nodes.llm.node import LLMNode - from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode - from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode - logger = logging.getLogger(__name__) +_QUOTA_NODE_TYPES = frozenset( + [ + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + ] +) @final class LLMQuotaLayer(GraphEngineLayer): - """Graph layer that applies LLM quota deduction after node execution.""" + """Graph layer that applies tenant-scoped quota checks to LLM-backed nodes.""" - def __init__(self) -> None: + tenant_id: str + _abort_sent: bool + + def __init__(self, tenant_id: str) -> None: super().__init__() + self.tenant_id = tenant_id self._abort_sent = False @override @@ -50,33 +62,49 @@ class LLMQuotaLayer(GraphEngineLayer): if self._abort_sent: return - model_instance = self._extract_model_instance(node) - if model_instance is None: + if not self._supports_quota(node): return + model_identity = self._extract_model_identity_from_node(node) + if model_identity is None: + reason = "LLM quota check requires public node model identity before execution." + self._abort_before_node_run(node=node, reason=reason, error_type="LLMQuotaIdentityError") + logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason) + return + + provider, model_name = model_identity try: - ensure_llm_quota_available(model_instance=model_instance) + ensure_llm_quota_available_for_model( + tenant_id=self.tenant_id, + provider=provider, + model=model_name, + ) except QuotaExceededError as exc: - self._set_stop_event(node) - self._send_abort_command(reason=str(exc)) + self._abort_before_node_run(node=node, reason=str(exc), error_type=QuotaExceededError.__name__) logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc) @override def on_node_run_end( self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None ) -> None: - if error is not None or not isinstance(result_event, NodeRunSucceededEvent): + if error is not None or not isinstance(result_event, NodeRunSucceededEvent) or not self._supports_quota(node): return - model_instance = self._extract_model_instance(node) - if model_instance is None: + model_identity = self._extract_model_identity_from_result_event(result_event) + if model_identity is None: + self._abort_for_missing_model_identity( + node=node, + reason="LLM quota deduction requires model identity in the node result event.", + ) return + provider, model_name = model_identity + try: - dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) - deduct_llm_quota( - tenant_id=dify_ctx.tenant_id, - model_instance=model_instance, + deduct_llm_quota_for_model( + tenant_id=self.tenant_id, + provider=provider, + model=model_name, usage=result_event.node_run_result.llm_usage, ) except QuotaExceededError as exc: @@ -92,6 +120,27 @@ class LLMQuotaLayer(GraphEngineLayer): if stop_event is not None: stop_event.set() + def _abort_before_node_run(self, *, node: Node, reason: str, error_type: str) -> None: + self._set_stop_event(node) + node.node_data.error_strategy = None + node.node_data.retry_config.retry_enabled = False + + def quota_aborted_run() -> NodeRunResult: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=reason, + error_type=error_type, + ) + + # TODO: Push Graphon to expose a public pre-run failure/skip hook, then replace this private _run override. + node._run = quota_aborted_run # type: ignore[method-assign] + self._send_abort_command(reason=reason) + + def _abort_for_missing_model_identity(self, *, node: Node, reason: str) -> None: + self._set_stop_event(node) + self._send_abort_command(reason=reason) + logger.error("LLM quota handling aborted, node_id=%s, reason=%s", node.id, reason) + def _send_abort_command(self, *, reason: str) -> None: if not self.command_channel or self._abort_sent: return @@ -108,29 +157,38 @@ class LLMQuotaLayer(GraphEngineLayer): logger.exception("Failed to send quota abort command") @staticmethod - def _extract_model_instance(node: Node) -> ModelInstance | None: - try: - match node.node_type: - case BuiltinNodeTypes.LLM: - model_instance = cast("LLMNode", node).model_instance - case BuiltinNodeTypes.PARAMETER_EXTRACTOR: - model_instance = cast("ParameterExtractorNode", node).model_instance - case BuiltinNodeTypes.QUESTION_CLASSIFIER: - model_instance = cast("QuestionClassifierNode", node).model_instance - case _: - return None - except AttributeError: + def _supports_quota(node: Node) -> bool: + return node.node_type in _QUOTA_NODE_TYPES + + @staticmethod + def _extract_model_identity_from_result_event(result_event: NodeRunSucceededEvent) -> tuple[str, str] | None: + provider = result_event.node_run_result.inputs.get("model_provider") + model_name = result_event.node_run_result.inputs.get("model_name") + if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name: + return provider, model_name + return None + + @staticmethod + def _extract_model_identity_from_node(node: Node) -> tuple[str, str] | None: + node_data = getattr(node, "node_data", None) + if node_data is None: + node_data = getattr(node, "data", None) + + model_config = getattr(node_data, "model", None) + if model_config is None: logger.warning( - "LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s", + "LLMQuotaLayer skipped quota handling because node model config is missing, node_id=%s", node.id, ) return None - if isinstance(model_instance, ModelInstance): - return model_instance - - raw_model_instance = getattr(model_instance, "_model_instance", None) - if isinstance(raw_model_instance, ModelInstance): - return raw_model_instance + provider = getattr(model_config, "provider", None) + model_name = getattr(model_config, "name", None) + if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name: + return provider, model_name + logger.warning( + "LLMQuotaLayer skipped quota handling because node model identity is invalid, node_id=%s", + node.id, + ) return None diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index d521304615..19152cebae 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -15,6 +15,7 @@ from datetime import datetime from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.helper.trace_id_helper import ParentTraceContext from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository @@ -403,8 +404,13 @@ class WorkflowPersistenceLayer(GraphEngineLayer): conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value) external_trace_id = None + parent_trace_context = None if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)): - external_trace_id = self._application_generate_entity.extras.get("external_trace_id") + extras = self._application_generate_entity.extras + external_trace_id = extras.get("external_trace_id") + parent_trace_context = extras.get("parent_trace_context") + if isinstance(parent_trace_context, ParentTraceContext): + parent_trace_context = parent_trace_context.model_dump(exclude_none=True) trace_task = TraceTask( TraceTaskName.WORKFLOW_TRACE, @@ -412,6 +418,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): conversation_id=conversation_id, user_id=self._trace_manager.user_id, external_trace_id=external_trace_id, + parent_trace_context=parent_trace_context, ) self._trace_manager.add_trace_task(trace_task) diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 492b507aa9..79b84a28be 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -35,8 +35,11 @@ class DatasourceFileManager: timestamp = str(int(time.time())) nonce = os.urandom(16).hex() data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + sign = hmac.new( + dify_config.SECRET_KEY.encode(), + data_to_sign.encode(), + hashlib.sha256, + ).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @@ -47,8 +50,11 @@ class DatasourceFileManager: verify signature """ data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_sign = hmac.new( + dify_config.SECRET_KEY.encode(), + data_to_sign.encode(), + hashlib.sha256, + ).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() # verify signature diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 38b87e2cd1..495fd1d898 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -23,7 +23,7 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_assembly from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( ConfigurateMethod, @@ -33,7 +33,7 @@ from graphon.model_runtime.entities.provider_entities import ( ) from graphon.model_runtime.model_providers.base.ai_model import AIModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.protocols.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -106,11 +106,18 @@ class ProviderConfiguration(BaseModel): """Attach the already-composed runtime for request-bound call chains.""" self._bound_model_runtime = model_runtime + def _get_runtime_and_provider_factory(self) -> tuple[ModelRuntime, ModelProviderFactory]: + """Resolve a provider factory that stays aligned with the runtime used by the caller.""" + if self._bound_model_runtime is not None: + return self._bound_model_runtime, ModelProviderFactory(runtime=self._bound_model_runtime) + + model_assembly = create_plugin_model_assembly(tenant_id=self.tenant_id) + return model_assembly.model_runtime, model_assembly.model_provider_factory + def get_model_provider_factory(self) -> ModelProviderFactory: """Return a provider factory that preserves any request-bound runtime.""" - if self._bound_model_runtime is not None: - return ModelProviderFactory(model_runtime=self._bound_model_runtime) - return create_plugin_model_provider_factory(tenant_id=self.tenant_id) + _, model_provider_factory = self._get_runtime_and_provider_factory() + return model_provider_factory def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None: """ @@ -1392,10 +1399,13 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = self.get_model_provider_factory() - - # Get model instance of LLM - return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) + model_runtime, model_provider_factory = self._get_runtime_and_provider_factory() + provider_schema = model_provider_factory.get_provider_schema(provider=self.provider.provider) + return create_model_type_instance( + runtime=model_runtime, + provider_schema=provider_schema, + model_type=model_type, + ) def get_model_schema( self, model_type: ModelType, model: str, credentials: dict[str, Any] | None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index f169f247cf..18b9b72e9d 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,7 +4,7 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from extensions.ext_hosting_provider import hosting_configuration from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeBadRequestError @@ -41,10 +41,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt text_chunk = secrets.choice(text_chunks) try: - model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) - - # Get model instance of LLM - model_type_instance = model_provider_factory.get_model_type_instance( + model_assembly = create_plugin_model_assembly(tenant_id=tenant_id) + model_type_instance = model_assembly.create_model_type_instance( provider=openai_provider_name, model_type=ModelType.MODERATION ) model_type_instance = cast(ModerationModel, model_type_instance) diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py index e827859109..e4890c8d4d 100644 --- a/api/core/helper/trace_id_helper.py +++ b/api/core/helper/trace_id_helper.py @@ -3,6 +3,17 @@ import re from collections.abc import Mapping from typing import Any +from pydantic import BaseModel, ConfigDict, StrictStr, ValidationError + + +class ParentTraceContext(BaseModel): + """Typed parent trace context propagated from an outer workflow tool node.""" + + parent_workflow_run_id: StrictStr + parent_node_execution_id: StrictStr | None = None + + model_config = ConfigDict(extra="forbid") + def is_valid_trace_id(trace_id: str) -> bool: """ @@ -61,6 +72,30 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]): return {} +def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, ParentTraceContext]: + """ + Extract 'parent_trace_context' from args. + + Returns a dict suitable for use in extras when both parent identifiers exist. + Returns an empty dict if the context is missing or incomplete. + """ + parent_trace_context = args.get("parent_trace_context") + if isinstance(parent_trace_context, ParentTraceContext): + context = parent_trace_context + elif isinstance(parent_trace_context, Mapping): + try: + context = ParentTraceContext.model_validate(parent_trace_context) + except ValidationError: + return {} + else: + return {} + + if context.parent_node_execution_id is None: + return {} + + return {"parent_trace_context": context} + + def get_trace_id_from_otel_context() -> str | None: """ Retrieve the current trace ID from the active OpenTelemetry trace context. diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b6e33396d1..537b14388e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -324,9 +324,10 @@ class IndexingRunner: # one extract_setting is one source document for extract_setting in extract_settings: # extract - processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) - ) + processing_rule = { + "mode": tmp_processing_rule["mode"], + "rules": tmp_processing_rule.get("rules"), + } # Extract document content text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) # Cleaning and segmentation @@ -334,7 +335,7 @@ class IndexingRunner: text_docs, current_user=None, embedding_model_instance=embedding_model_instance, - process_rule=processing_rule.to_dict(), + process_rule=processing_rule, tenant_id=tenant_id, doc_language=doc_language, preview=True, diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index d840ee213c..c41c175cca 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -86,12 +86,10 @@ class TokenBufferMemory: detail = ImagePromptMessageContent.DETAIL.HIGH if file_extra_config and app_record: - # Build files directly without filtering by belongs_to file_objs = [ file_factory.build_from_message_file( message_file=message_file, tenant_id=app_record.tenant_id, - config=file_extra_config, access_controller=_file_access_controller, ) for message_file in message_files diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 45b2f635ba..98e87a0ceb 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -5,6 +5,8 @@ from typing import Any, Union from pydantic import BaseModel, ConfigDict, field_serializer, field_validator +from core.helper.trace_id_helper import ParentTraceContext + class BaseTraceInfo(BaseModel): message_id: str | None = None @@ -51,8 +53,8 @@ class BaseTraceInfo(BaseModel): def resolved_parent_context(self) -> tuple[str | None, str | None]: """Resolve cross-workflow parent linking from metadata. - Extracts typed parent IDs from the untyped ``parent_trace_context`` - metadata dict (set by tool_node when invoking nested workflows). + Extracts typed parent IDs from the ``parent_trace_context`` metadata + payload (set by tool_node when invoking nested workflows). Returns: (trace_correlation_override, parent_span_id_source) where @@ -60,13 +62,18 @@ class BaseTraceInfo(BaseModel): parent_span_id_source is the outer node_execution_id. """ parent_ctx = self.metadata.get("parent_trace_context") - if not isinstance(parent_ctx, dict): + if isinstance(parent_ctx, ParentTraceContext): + context = parent_ctx + elif isinstance(parent_ctx, Mapping): + try: + context = ParentTraceContext.model_validate(parent_ctx) + except ValueError: + return None, None + else: return None, None - trace_override = parent_ctx.get("parent_workflow_run_id") - parent_span = parent_ctx.get("parent_node_execution_id") return ( - trace_override if isinstance(trace_override, str) else None, - parent_span if isinstance(parent_span, str) else None, + context.parent_workflow_run_id, + context.parent_node_execution_id, ) @field_serializer("start_time", "end_time") diff --git a/api/core/ops/exceptions.py b/api/core/ops/exceptions.py new file mode 100644 index 0000000000..4551704687 --- /dev/null +++ b/api/core/ops/exceptions.py @@ -0,0 +1,22 @@ +"""Core exceptions shared by ops trace dispatchers and trace providers. + +Provider packages may raise these types to request generic task behavior, but +generic Celery tasks should not import provider-specific exception classes. +""" + + +class RetryableTraceDispatchError(RuntimeError): + """Base class for transient trace dispatch failures that Celery may retry.""" + + +class PendingTraceParentContextError(RetryableTraceDispatchError): + """Raised when a nested trace arrives before its parent span context is available.""" + + parent_node_execution_id: str + + def __init__(self, parent_node_execution_id: str) -> None: + self.parent_node_execution_id = parent_node_execution_id + super().__init__( + "Pending trace parent context for parent_node_execution_id=" + f"{parent_node_execution_id}. Retry after the parent span context is published." + ) diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index bae0016744..61fd0e5c1f 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -16,6 +16,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token +from core.helper.trace_id_helper import ParentTraceContext from core.ops.entities.config_entity import ( OPS_FILE_PATH, BaseTracingConfig, @@ -52,6 +53,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _dump_parent_trace_context(parent_trace_context: Any) -> dict[str, str] | None: + if isinstance(parent_trace_context, ParentTraceContext): + return parent_trace_context.model_dump(exclude_none=True) + if isinstance(parent_trace_context, dict): + try: + return ParentTraceContext.model_validate(parent_trace_context).model_dump(exclude_none=True) + except ValueError: + return None + return None + + class _AppTracingConfig(TypedDict, total=False): enabled: bool tracing_provider: str | None @@ -857,8 +869,9 @@ class TraceTask: } parent_trace_context = self.kwargs.get("parent_trace_context") - if parent_trace_context: - metadata["parent_trace_context"] = parent_trace_context + dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context) + if dumped_parent_trace_context: + metadata["parent_trace_context"] = dumped_parent_trace_context workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, @@ -1371,13 +1384,14 @@ class TraceTask: } parent_trace_context = node_data.get("parent_trace_context") - if parent_trace_context: - metadata["parent_trace_context"] = parent_trace_context + dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context) + if dumped_parent_trace_context: + metadata["parent_trace_context"] = dumped_parent_trace_context message_id: str | None = None conversation_id = node_data.get("conversation_id") workflow_execution_id = node_data.get("workflow_execution_id") - if conversation_id and workflow_execution_id and not parent_trace_context: + if conversation_id and workflow_execution_id and not dumped_parent_trace_context: with Session(db.engine) as session: msg_id = session.scalar( select(Message.id).where( diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 4e66d58b5e..62573ba2f5 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -4,23 +4,32 @@ import hashlib import logging from collections.abc import Generator, Iterable, Sequence from threading import Lock -from typing import IO, Any, Union +from typing import IO, Any, Literal, cast, overload from pydantic import ValidationError from redis import RedisError from configs import dify_config +from core.llm_generator.output_parser.structured_output import ( + invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper, +) from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.model import PluginModelClient from extensions.ext_redis import redis_client -from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType from graphon.model_runtime.entities.provider_entities import ProviderEntity from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.model_providers.base.large_language_model import normalize_non_stream_runtime_result +from graphon.model_runtime.protocols.runtime import ModelRuntime from models.provider_ids import ModelProviderID logger = logging.getLogger(__name__) @@ -29,6 +38,68 @@ logger = logging.getLogger(__name__) TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__" +# TODO(-LAN-): Move native structured-output invocation into Graphon's LLM node. +# TODO(-LAN-): Remove this Dify-side adapter once Graphon owns structured output end-to-end. +class _PluginStructuredOutputModelInstance: + """Bind plugin model identity to the shared structured-output helper. + + The structured-output parser is shared with legacy ``ModelInstance`` flows + and only needs an object exposing ``invoke_llm(...)``. ``PluginModelRuntime`` + intentionally exposes a lower-level API where provider, model, and + credentials are passed per call. This adapter supplies the small bound + ``invoke_llm`` surface the helper needs without constructing a full + ``ModelInstance`` or reintroducing model-manager dependencies into the + plugin runtime path. + """ + + def __init__( + self, + *, + runtime: PluginModelRuntime, + provider: str, + model: str, + credentials: dict[str, Any], + ) -> None: + self._runtime = runtime + self._provider = provider + self._model = model + self._credentials = credentials + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: dict[str, Any] | None = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, + stream: bool = True, + callbacks: object | None = None, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: + del callbacks + if stream: + return self._runtime.invoke_llm( + provider=self._provider, + model=self._model, + credentials=self._credentials, + model_parameters=model_parameters or {}, + prompt_messages=prompt_messages, + tools=list(tools) if tools else None, + stop=stop, + stream=True, + ) + + return self._runtime.invoke_llm( + provider=self._provider, + model=self._model, + credentials=self._credentials, + model_parameters=model_parameters or {}, + prompt_messages=prompt_messages, + tools=list(tools) if tools else None, + stop=stop, + stream=False, + ) + + class PluginModelRuntime(ModelRuntime): """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" @@ -195,6 +266,34 @@ class PluginModelRuntime(ModelRuntime): return schema + @overload + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + def invoke_llm( self, *, @@ -206,9 +305,9 @@ class PluginModelRuntime(ModelRuntime): tools: list[PromptMessageTool] | None, stop: Sequence[str] | None, stream: bool, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + ) -> LLMResult | Generator[LLMResultChunk, None, None]: plugin_id, provider_name = self._split_provider(provider) - return self.client.invoke_llm( + result = self.client.invoke_llm( tenant_id=self.tenant_id, user_id=self.user_id, plugin_id=plugin_id, @@ -221,6 +320,81 @@ class PluginModelRuntime(ModelRuntime): stop=list(stop) if stop else None, stream=stream, ) + if stream: + return result + + return normalize_non_stream_runtime_result( + model=model, + prompt_messages=prompt_messages, + result=result, + ) + + @overload + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: + model_schema = self.get_model_schema( + provider=provider, + model_type=ModelType.LLM, + model=model, + credentials=credentials, + ) + if model_schema is None: + raise ValueError(f"Model schema not found for {model}") + + adapter = _PluginStructuredOutputModelInstance( + runtime=self, + provider=provider, + model=model, + credentials=credentials, + ) + return invoke_llm_with_structured_output_helper( + provider=provider, + model_schema=model_schema, + model_instance=cast(Any, adapter), + prompt_messages=prompt_messages, + json_schema=json_schema, + model_parameters=model_parameters, + tools=None, + stop=list(stop) if stop else None, + stream=stream, + ) def get_llm_num_tokens( self, diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py index 35abd2ae8c..fbe307ea60 100644 --- a/api/core/plugin/impl/model_runtime_factory.py +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -3,13 +3,46 @@ from __future__ import annotations from typing import TYPE_CHECKING from core.plugin.impl.model import PluginModelClient +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.base.tts_model import TTSModel from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.protocols.runtime import ModelRuntime if TYPE_CHECKING: from core.model_manager import ModelManager from core.plugin.impl.model_runtime import PluginModelRuntime from core.provider_manager import ProviderManager +_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = { + ModelType.LLM: LargeLanguageModel, + ModelType.TEXT_EMBEDDING: TextEmbeddingModel, + ModelType.RERANK: RerankModel, + ModelType.SPEECH2TEXT: Speech2TextModel, + ModelType.MODERATION: ModerationModel, + ModelType.TTS: TTSModel, +} + + +def create_model_type_instance( + *, + runtime: ModelRuntime, + provider_schema: ProviderEntity, + model_type: ModelType, +) -> AIModel: + """Build the graphon model wrapper explicitly against the request runtime.""" + model_class = _MODEL_CLASS_BY_TYPE.get(model_type) + if model_class is None: + raise ValueError(f"Unsupported model type: {model_type}") + + return model_class(provider_schema=provider_schema, model_runtime=runtime) + class PluginModelAssembly: """Compose request-scoped model views on top of a single plugin runtime.""" @@ -38,9 +71,22 @@ class PluginModelAssembly: @property def model_provider_factory(self) -> ModelProviderFactory: if self._model_provider_factory is None: - self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime) + self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime) return self._model_provider_factory + def create_model_type_instance( + self, + *, + provider: str, + model_type: ModelType, + ) -> AIModel: + provider_schema = self.model_provider_factory.get_provider_schema(provider=provider) + return create_model_type_instance( + runtime=self.model_runtime, + provider_schema=provider_schema, + model_type=model_type, + ) + @property def provider_manager(self) -> ProviderManager: if self._provider_manager is None: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 1665bdeb52..e836554ca0 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -123,12 +123,15 @@ class SimplePromptTransform(PromptTransform): for v in special_variable_keys: # support #context#, #query# and #histories# - if v == "#context#": - variables["#context#"] = context or "" - elif v == "#query#": - variables["#query#"] = query or "" - elif v == "#histories#": - variables["#histories#"] = histories or "" + match v: + case "#context#": + variables["#context#"] = context or "" + case "#query#": + variables["#query#"] = query or "" + case "#histories#": + variables["#histories#"] = histories or "" + case _: + pass prompt_template = prompt_template_config["prompt_template"] if not isinstance(prompt_template, PromptTemplateParser): diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index b290ae456e..9faa70a0b8 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID from services.feature_service import FeatureService if TYPE_CHECKING: - from graphon.model_runtime.runtime import ModelRuntime + from graphon.model_runtime.protocols.runtime import ModelRuntime _credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any]) @@ -165,7 +165,7 @@ class ProviderManager: ) # Get all provider entities - model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) + model_provider_factory = ModelProviderFactory(runtime=self._model_runtime) provider_entities = model_provider_factory.get_providers() # Get All preferred provider types of the workspace @@ -362,7 +362,7 @@ class ProviderManager: if not default_model: return None - model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) + model_provider_factory = ModelProviderFactory(runtime=self._model_runtime) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) return DefaultModelEntity( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index d2e2a16252..b18905fe8c 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -21,7 +21,7 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.tools.signature import sign_upload_file +from core.tools.signature import sign_upload_file_preview_url from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelType from models.dataset import ( @@ -865,7 +865,7 @@ class RetrievalService: "name": upload_file.name, "extension": "." + upload_file.extension, "mime_type": upload_file.mime_type, - "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "source_url": sign_upload_file_preview_url(upload_file.id, upload_file.extension), "size": upload_file.size, } return {"attachment_info": attachment_info, "segment_id": attachment_binding.segment_id} @@ -892,7 +892,7 @@ class RetrievalService: "name": upload_file.name, "extension": "." + upload_file.extension, "mime_type": upload_file.mime_type, - "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "source_url": sign_upload_file_preview_url(upload_file.id, upload_file.extension), "size": upload_file.size, } if attachment_binding: diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 02f0efc908..25f6fe3e2a 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -115,7 +115,7 @@ class PdfExtractor(BaseExtractor): """ image_content = [] upload_files = [] - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + base_url = dify_config.FILES_URL try: image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,)) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 0330a43b28..60f8906181 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -110,7 +110,7 @@ class WordExtractor(BaseExtractor): def _extract_images_from_docx(self, doc): image_count = 0 image_map = {} - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + base_url = dify_config.FILES_URL for r_id, rel in doc.part.rels.items(): if "image" in rel.target_ref: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index ba277d5018..a26a900512 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -29,6 +29,7 @@ from libs import helper from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import ProcessRuleMode from services.account_service import AccountService from services.summary_index_service import SummaryIndexService @@ -325,7 +326,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): # update document parent mode dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode="hierarchical", + mode=ProcessRuleMode.HIERARCHICAL, rules=json.dumps( { "parent_mode": parent_childs.parent_mode, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 5631b3a921..010566d203 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -52,7 +52,7 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) -from core.tools.signature import sign_upload_file +from core.tools.signature import sign_upload_file_preview_url from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.workflow.file_reference import build_file_reference from core.workflow.nodes.knowledge_retrieval import exc @@ -529,7 +529,7 @@ class DatasetRetrieval: ), size=upload_file.size, storage_key=upload_file.key, - url=sign_upload_file(upload_file.id, upload_file.extension), + url=sign_upload_file_preview_url(upload_file.id, upload_file.extension), ) context_files.append(attachment_info) if show_retrieve_source: diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 1807226924..ca4756f2a4 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -8,6 +8,10 @@ import urllib.parse from configs import dify_config +def _secret_key() -> bytes: + return dify_config.SECRET_KEY.encode() + + def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str: """ sign file to get a temporary url for plugin access @@ -19,26 +23,26 @@ def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) timestamp = str(int(time.time())) nonce = os.urandom(16).hex() data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" -def sign_upload_file(upload_file_id: str, extension: str) -> str: +def sign_upload_file_preview_url(upload_file_id: str, extension: str) -> str: """ - sign file to get a temporary url for plugin access + Sign an upload file to get a temporary image preview URL. + + The URL generated by this function is only for external preview and download, + not for internal communication. """ - # Use internal URL for plugin/tool file access in Docker environments - base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + base_url = dify_config.FILES_URL file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @@ -49,8 +53,7 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s verify signature """ data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() # verify signature @@ -69,8 +72,7 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, timestamp = str(int(time.time())) nonce = os.urandom(16).hex() data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() query = urllib.parse.urlencode( { @@ -90,8 +92,7 @@ def verify_plugin_file_signature( """Verify the signature used by the plugin-facing file upload endpoint.""" data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() if sign != recalculated_encoded_sign: diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index c87e8a3ae0..f2552e7cbd 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -51,8 +51,11 @@ class ToolFileManager: timestamp = str(int(time.time())) nonce = os.urandom(16).hex() data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + sign = hmac.new( + dify_config.SECRET_KEY.encode(), + data_to_sign.encode(), + hashlib.sha256, + ).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @@ -63,8 +66,11 @@ class ToolFileManager: verify signature """ data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_sign = hmac.new( + dify_config.SECRET_KEY.encode(), + data_to_sign.encode(), + hashlib.sha256, + ).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() # verify signature diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index cd8c6352b5..3fbd456fe5 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -9,6 +9,7 @@ from sqlalchemy import select from core.app.file_access import DatabaseFileAccessController from core.db.session_factory import session_factory +from core.helper.trace_id_helper import ParentTraceContext, extract_parent_trace_context_from_args from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( @@ -36,6 +37,8 @@ class WorkflowTool(Tool): Workflow tool. """ + _parent_trace_context: ParentTraceContext | None + def __init__( self, workflow_app_id: str, @@ -54,6 +57,7 @@ class WorkflowTool(Tool): self.workflow_call_depth = workflow_call_depth self.label = label self._latest_usage = LLMUsage.empty_usage() + self._parent_trace_context = None super().__init__(entity=entity, runtime=runtime) @@ -94,11 +98,17 @@ class WorkflowTool(Tool): self._latest_usage = LLMUsage.empty_usage() + generator_args: dict[str, Any] = {"inputs": tool_parameters, "files": files} + if self._parent_trace_context: + generator_args.update( + extract_parent_trace_context_from_args({"parent_trace_context": self._parent_trace_context}) + ) + result = generator.generate( app_model=app, workflow=workflow, user=user, - args={"inputs": tool_parameters, "files": files}, + args=generator_args, invoke_from=self.runtime.invoke_from, streaming=False, call_depth=self.workflow_call_depth + 1, @@ -194,7 +204,7 @@ class WorkflowTool(Tool): :return: the new tool """ - return self.__class__( + forked = self.__class__( entity=self.entity.model_copy(), runtime=runtime, workflow_app_id=self.workflow_app_id, @@ -204,6 +214,24 @@ class WorkflowTool(Tool): version=self.version, label=self.label, ) + forked._parent_trace_context = self._parent_trace_context.model_copy() if self._parent_trace_context else None + return forked + + def set_parent_trace_context( + self, + *, + parent_workflow_run_id: str, + parent_node_execution_id: str, + ) -> None: + """Attach outer workflow trace context without exposing it as tool input.""" + self._parent_trace_context = ParentTraceContext( + parent_workflow_run_id=parent_workflow_run_id, + parent_node_execution_id=parent_node_execution_id, + ) + + def clear_parent_trace_context(self) -> None: + """Remove parent trace context before invoking this tool outside a nested workflow.""" + self._parent_trace_context = None def _resolve_user(self, user_id: str) -> Account | EndUser | None: """ diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 895953a3c1..a306b1c9ac 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -374,11 +374,6 @@ class DifyNodeFactory(NodeFactory): # Re-validate using the resolved node class so workflow-local node schemas # stay explicit and constructors receive the concrete typed payload. resolved_node_data = self._validate_resolved_node_data(node_class, node_data) - config_for_node_init: BaseNodeData | dict[str, Any] - if isinstance(resolved_node_data, BaseNodeData): - config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True) - else: - config_for_node_init = resolved_node_data node_type = node_data.type node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { BuiltinNodeTypes.CODE: lambda: { @@ -446,9 +441,10 @@ class DifyNodeFactory(NodeFactory): }, } node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() + constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True) return node_class( node_id=node_id, - config=config_for_node_init, + data=constructor_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, **node_init_kwargs, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py index 08125ff733..5e5bbc60cb 100644 --- a/api/core/workflow/node_runtime.py +++ b/api/core/workflow/node_runtime.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.file_access import DatabaseFileAccessController from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.helper.trace_id_helper import ParentTraceContext from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelInstance @@ -362,6 +363,7 @@ class _WorkflowToolRuntimeBinding: tool: Tool conversation_id: str | None = None + parent_trace_context: ParentTraceContext | None = None class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): @@ -382,6 +384,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): node_id: str, node_data: ToolNodeData, variable_pool, + node_execution_id: str | None = None, ) -> ToolRuntimeHandle: try: tool_runtime = ToolManager.get_workflow_tool_runtime( @@ -401,7 +404,25 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): conversation_id = ( None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) ) - return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id)) + parent_trace_context: ParentTraceContext | None = None + if self._is_workflow_tool_provider(node_data): + outer_workflow_run_id = ( + None + if variable_pool is None + else get_system_text(variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID) + ) + if isinstance(outer_workflow_run_id, str) and isinstance(node_execution_id, str): + parent_trace_context = ParentTraceContext( + parent_workflow_run_id=outer_workflow_run_id, + parent_node_execution_id=node_execution_id, + ) + return ToolRuntimeHandle( + raw=_WorkflowToolRuntimeBinding( + tool=tool_runtime, + conversation_id=conversation_id, + parent_trace_context=parent_trace_context, + ) + ) def get_runtime_parameters( self, @@ -425,6 +446,13 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): runtime_binding = self._binding_from_handle(tool_runtime) tool = runtime_binding.tool callback = DifyWorkflowCallbackHandler() + if runtime_binding.parent_trace_context and hasattr(tool, "set_parent_trace_context"): + tool.set_parent_trace_context( + parent_workflow_run_id=runtime_binding.parent_trace_context.parent_workflow_run_id, + parent_node_execution_id=runtime_binding.parent_trace_context.parent_node_execution_id, + ) + elif hasattr(tool, "clear_parent_trace_context"): + tool.clear_parent_trace_context() try: messages = ToolEngine.generic_invoke( @@ -517,6 +545,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): credential_id=node_data.credential_id, ) + @staticmethod + def _is_workflow_tool_provider(node_data: ToolNodeData) -> bool: + return node_data.provider_type.value == CoreToolProviderType.WORKFLOW.value + def _adapt_messages( self, messages: Generator[CoreToolInvokeMessage, None, None], diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 68a24e86b1..17d71668cb 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -35,7 +35,7 @@ class AgentNode(Node[AgentNodeData]): def __init__( self, node_id: str, - config: AgentNodeData, + data: AgentNodeData, *, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, @@ -46,7 +46,7 @@ class AgentNode(Node[AgentNodeData]): ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index f3006c4242..a4ef3d1ea7 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -36,14 +36,14 @@ class DatasourceNode(Node[DatasourceNodeData]): def __init__( self, node_id: str, - config: DatasourceNodeData, + data: DatasourceNodeData, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 9c1b7ab2c4..1d60f530a1 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -32,14 +32,14 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def __init__( self, node_id: str, - config: KnowledgeIndexNodeData, + data: KnowledgeIndexNodeData, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 25f73e446d..1aba2737b0 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -71,14 +71,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def __init__( self, node_id: str, - config: KnowledgeRetrievalNodeData, + data: KnowledgeRetrievalNodeData, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py index 9d15a3fcea..77ef3826e9 100644 --- a/api/core/workflow/system_variables.py +++ b/api/core/workflow/system_variables.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any, Protocol, cast +from typing import Any, Protocol from uuid import uuid4 from graphon.enums import BuiltinNodeTypes @@ -82,13 +82,10 @@ def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: normalized = _normalize_system_variable_values(values, **kwargs) return [ - cast( - Variable, - segment_to_variable( - segment=build_segment(value), - selector=system_variable_selector(key), - name=key, - ), + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, ) for key, value in normalized.items() ] @@ -130,13 +127,10 @@ def build_bootstrap_variables( for node_id, value in rag_pipeline_variables_map.items(): variables.append( - cast( - Variable, - segment_to_variable( - segment=build_segment(value), - selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), - name=node_id, - ), + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, ) ) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 4e2f603e5b..3019704dac 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -46,6 +46,11 @@ _file_access_controller = DatabaseFileAccessController() class _WorkflowChildEngineBuilder: + tenant_id: str + + def __init__(self, *, tenant_id: str) -> None: + self.tenant_id = tenant_id + @staticmethod def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None: """ @@ -107,7 +112,7 @@ class _WorkflowChildEngineBuilder: config=config, child_engine_builder=self, ) - child_engine.layer(LLMQuotaLayer()) + child_engine.layer(LLMQuotaLayer(tenant_id=self.tenant_id)) return child_engine @@ -176,7 +181,7 @@ class WorkflowEntry: self.command_channel = command_channel execution_context = capture_current_context() graph_runtime_state.execution_context = execution_context - self._child_engine_builder = _WorkflowChildEngineBuilder() + self._child_engine_builder = _WorkflowChildEngineBuilder(tenant_id=tenant_id) self.graph_engine = GraphEngine( workflow_id=workflow_id, graph=graph, @@ -208,7 +213,7 @@ class WorkflowEntry: max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME ) self.graph_engine.layer(limits_layer) - self.graph_engine.layer(LLMQuotaLayer()) + self.graph_engine.layer(LLMQuotaLayer(tenant_id=tenant_id)) # Add observability layer when OTel is enabled if dify_config.ENABLE_OTEL or is_instrument_flag_enabled(): diff --git a/api/dev/generate_swagger_markdown_docs.py b/api/dev/generate_swagger_markdown_docs.py index 0900d08331..e0028c63f6 100644 --- a/api/dev/generate_swagger_markdown_docs.py +++ b/api/dev/generate_swagger_markdown_docs.py @@ -29,18 +29,39 @@ STALE_COMBINED_MARKDOWN_FILENAME = "api-reference.md" def _convert_spec_to_markdown(spec_path: Path, markdown_path: Path) -> None: - subprocess.run( - [ - "npx", - "--yes", - SWAGGER_MARKDOWN_PACKAGE, - "-i", - str(spec_path), - "-o", - str(markdown_path), - ], - check=True, - ) + markdown_path.parent.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory(prefix=f"{markdown_path.stem}-", dir=markdown_path.parent) as temp_dir: + temp_markdown_path = Path(temp_dir) / markdown_path.name + result = subprocess.run( + [ + "npx", + "--yes", + SWAGGER_MARKDOWN_PACKAGE, + "-i", + str(spec_path), + "-o", + str(temp_markdown_path), + ], + check=False, + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise subprocess.CalledProcessError( + result.returncode, + result.args, + output=result.stdout, + stderr=result.stderr, + ) + if not temp_markdown_path.exists(): + converter_output = "\n".join(item for item in (result.stdout, result.stderr) if item).strip() + raise RuntimeError(f"swagger-markdown did not write {markdown_path}: {converter_output}") + + converted_markdown = temp_markdown_path.read_text(encoding="utf-8") + if not converted_markdown.strip(): + raise RuntimeError(f"swagger-markdown wrote an empty document for {markdown_path}") + + markdown_path.write_text(converted_markdown, encoding="utf-8") def _demote_markdown_headings(markdown: str, *, levels: int = 1) -> str: diff --git a/api/dev/generate_swagger_specs.py b/api/dev/generate_swagger_specs.py index 9122f3ab24..254310cd2a 100644 --- a/api/dev/generate_swagger_specs.py +++ b/api/dev/generate_swagger_specs.py @@ -20,7 +20,6 @@ from pathlib import Path from typing import Protocol, TypeGuard from flask import Flask -from flask_restx.swagger import Swagger logger = logging.getLogger(__name__) @@ -48,9 +47,6 @@ SPEC_TARGETS: tuple[SpecTarget, ...] = ( SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"), ) -_ORIGINAL_REGISTER_MODEL = Swagger.register_model -_ORIGINAL_REGISTER_FIELD = Swagger.register_field - def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]: """Return whether a nested field map is an anonymous inline mapping.""" @@ -152,56 +148,14 @@ def apply_runtime_defaults() -> None: dify_config.SWAGGER_UI_ENABLED = os.environ["SWAGGER_UI_ENABLED"].lower() == "true" -def _patch_swagger_for_inline_nested_dicts() -> None: - """Teach Flask-RESTX Swagger generation to tolerate inline nested field maps. - - Some existing controllers use `fields.Nested({...})` with a raw field mapping - instead of a named `api.model(...)`. Flask-RESTX crashes on those anonymous - dicts during schema registration, so this helper upgrades them into temporary - named models at export time. - """ - - if getattr(Swagger, "_dify_inline_nested_dict_patch", False): - return - - def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object: - anonymous_models = getattr(self, "_anonymous_inline_models", None) - if anonymous_models is None: - anonymous_models = {} - self.__dict__["_anonymous_inline_models"] = anonymous_models - - anonymous_name = anonymous_models.get(id(nested_fields)) - if anonymous_name is None: - anonymous_name = _inline_model_name(nested_fields) - anonymous_models[id(nested_fields)] = anonymous_name - if anonymous_name not in self.api.models: - self.api.model(anonymous_name, nested_fields) - - return self.api.models[anonymous_name] - - def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]: - if _is_inline_field_map(model): - model = get_or_create_inline_model(self, model) - - return _ORIGINAL_REGISTER_MODEL(self, model) - - def register_field_with_inline_dict_support(self: Swagger, field: object) -> None: - nested = getattr(field, "nested", None) - if _is_inline_field_map(nested): - field.model = get_or_create_inline_model(self, nested) # type: ignore - - _ORIGINAL_REGISTER_FIELD(self, field) - - Swagger.register_model = register_model_with_inline_dict_support - Swagger.register_field = register_field_with_inline_dict_support - Swagger._dify_inline_nested_dict_patch = True - - def create_spec_app() -> Flask: """Build a minimal Flask app that only mounts the Swagger-producing blueprints.""" apply_runtime_defaults() - _patch_swagger_for_inline_nested_dicts() + + from libs.flask_restx_compat import patch_swagger_for_inline_nested_dicts + + patch_swagger_for_inline_nested_dicts() app = Flask(__name__) diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 1d615f0f87..8dec5876a9 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -137,17 +137,13 @@ def handle(sender: Message, **kwargs): if used_quota is not None: match provider_configuration.system_configuration.current_quota_type: case ProviderQuotaType.TRIAL: - from services.credit_pool_service import CreditPoolService - - CreditPoolService.check_and_deduct_credits( + _deduct_credit_pool_quota_capped( tenant_id=tenant_id, credits_required=used_quota, pool_type="trial", ) case ProviderQuotaType.PAID: - from services.credit_pool_service import CreditPoolService - - CreditPoolService.check_and_deduct_credits( + _deduct_credit_pool_quota_capped( tenant_id=tenant_id, credits_required=used_quota, pool_type="paid", @@ -200,6 +196,26 @@ def handle(sender: Message, **kwargs): raise +def _deduct_credit_pool_quota_capped(*, tenant_id: str, credits_required: int, pool_type: str) -> None: + """Apply post-generation credit accounting without failing message persistence on quota exhaustion.""" + from services.credit_pool_service import CreditPoolService + + deducted_credits = CreditPoolService.deduct_credits_capped( + tenant_id=tenant_id, + credits_required=credits_required, + pool_type=pool_type, + ) + if deducted_credits < credits_required: + logger.warning( + "Credit pool exhausted during message-created accounting, " + "tenant_id=%s, pool_type=%s, credits_required=%s, credits_deducted=%s", + tenant_id, + pool_type, + credits_required, + deducted_credits, + ) + + def _calculate_quota_usage( *, message: Message, system_configuration: SystemConfiguration, model_name: str ) -> int | None: diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py index 4a6490b9f0..914baaadaf 100644 --- a/api/extensions/ext_app_metrics.py +++ b/api/extensions/ext_app_metrics.py @@ -5,6 +5,7 @@ import threading from flask import Response from configs import dify_config +from controllers.console.admin import admin_required from dify_app import DifyApp @@ -25,6 +26,7 @@ def init_app(app: DifyApp): ) @app.route("/threads") + @admin_required def threads(): # pyright: ignore[reportUnusedFunction] num_threads = threading.active_count() threads = threading.enumerate() @@ -50,6 +52,7 @@ def init_app(app: DifyApp): } @app.route("/db-pool-stat") + @admin_required def pool_stat(): # pyright: ignore[reportUnusedFunction] from extensions.ext_database import db diff --git a/api/extensions/ext_set_secretkey.py b/api/extensions/ext_set_secretkey.py index dfb87c0167..ca59a2de4d 100644 --- a/api/extensions/ext_set_secretkey.py +++ b/api/extensions/ext_set_secretkey.py @@ -1,6 +1,13 @@ from configs import dify_config +from configs.secret_key import resolve_secret_key from dify_app import DifyApp -def init_app(app: DifyApp): - app.secret_key = dify_config.SECRET_KEY +def init_app(app: DifyApp) -> None: + """Resolve SECRET_KEY after config loading and before session/login setup.""" + secret_key = dify_config.SECRET_KEY + if not secret_key: + secret_key = resolve_secret_key(secret_key) + dify_config.SECRET_KEY = secret_key + app.config["SECRET_KEY"] = secret_key + app.secret_key = secret_key diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py index 4b3d514238..27441bdcc1 100644 --- a/api/factories/file_factory/message_files.py +++ b/api/factories/file_factory/message_files.py @@ -1,11 +1,18 @@ -"""Adapters from persisted message files to graph-layer file values.""" +"""Adapters from persisted message files to graph-layer file values. + +Replay paths only: files in conversation history were validated at upload time, +so these helpers deliberately do not accept (or forward) a ``FileUploadConfig`` — +re-validation here would break replays whenever workflow ``file_upload`` config +drifts between rounds. Mirrors ``build_file_from_stored_mapping`` in +``models/utils/file_input_compat.py``. +""" from __future__ import annotations from collections.abc import Sequence from core.app.file_access import FileAccessControllerProtocol -from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig +from graphon.file import File, FileBelongsTo, FileTransferMethod from models import MessageFile from .builders import build_from_mapping @@ -15,14 +22,12 @@ def build_from_message_files( *, message_files: Sequence[MessageFile], tenant_id: str, - config: FileUploadConfig | None = None, access_controller: FileAccessControllerProtocol, ) -> Sequence[File]: return [ build_from_message_file( message_file=message_file, tenant_id=tenant_id, - config=config, access_controller=access_controller, ) for message_file in message_files @@ -34,7 +39,6 @@ def build_from_message_file( *, message_file: MessageFile, tenant_id: str, - config: FileUploadConfig | None, access_controller: FileAccessControllerProtocol, ) -> File: mapping = { @@ -54,6 +58,5 @@ def build_from_message_file( return build_from_mapping( mapping=mapping, tenant_id=tenant_id, - config=config, access_controller=access_controller, ) diff --git a/api/factories/file_factory/validation.py b/api/factories/file_factory/validation.py index 4c4f6150e4..8c4e7ef1d4 100644 --- a/api/factories/file_factory/validation.py +++ b/api/factories/file_factory/validation.py @@ -2,9 +2,25 @@ from __future__ import annotations +from collections.abc import Iterable + from graphon.file import FileTransferMethod, FileType, FileUploadConfig +def _normalize_extension(extension: str) -> str: + s = extension.strip().lower() + if not s: + return "" + return s if s.startswith(".") else "." + s + + +def _extension_matches(extension: str, whitelist: Iterable[str]) -> bool: + normalized = _normalize_extension(extension) + if not normalized: + return False + return normalized in {_normalize_extension(e) for e in whitelist} + + def is_file_valid_with_config( *, input_file_type: str, @@ -12,22 +28,31 @@ def is_file_valid_with_config( file_transfer_method: FileTransferMethod, config: FileUploadConfig, ) -> bool: - # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) - # These are internally generated and should bypass user upload restrictions + """Return whether the file is allowed by the upload config. + + ``allowed_file_types`` lists the buckets a file may fall into; ``CUSTOM`` is + a fallback bucket gated by ``allowed_file_extensions`` (case- and + dot-insensitive). Tool-generated files bypass user-facing config. + """ if file_transfer_method == FileTransferMethod.TOOL_FILE: return True - if ( - config.allowed_file_types - and input_file_type not in config.allowed_file_types - and input_file_type != FileType.CUSTOM - ): + allowed_types = config.allowed_file_types or [] + custom_allowed = FileType.CUSTOM in allowed_types + type_allowed = not allowed_types or input_file_type in allowed_types + + if not type_allowed and not custom_allowed: return False + # When the file is in the CUSTOM bucket, the extension whitelist is authoritative. + # An explicitly set whitelist (including the empty list) is enforced; empty == deny — + # the UI never submits an empty list, so this guards against DSL/API paths that + # bypass the UI from accidentally widening the allowlist. + in_custom_bucket = input_file_type == FileType.CUSTOM or not type_allowed if ( - input_file_type == FileType.CUSTOM + in_custom_bucket and config.allowed_file_extensions is not None - and file_extension not in config.allowed_file_extensions + and not _extension_matches(file_extension, config.allowed_file_extensions) ): return False diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 8c659086ed..a852f21bb2 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -1,14 +1,21 @@ +"""Workflow run response schemas for console APIs. + +Most workflow-run endpoints should document and serialize responses with the +Pydantic models in this module. The remaining Flask-RESTX field dictionaries are +kept only for workflow app-log endpoints that still build legacy log models. +""" + from __future__ import annotations from datetime import datetime from typing import Any from flask_restx import Namespace, fields -from pydantic import Field, field_validator +from pydantic import AliasChoices, Field, field_validator from fields.base import ResponseModel -from fields.end_user_fields import SimpleEndUser, simple_end_user_fields -from fields.member_fields import SimpleAccount, simple_account_fields +from fields.end_user_fields import SimpleEndUser +from fields.member_fields import SimpleAccount from libs.helper import TimestampField workflow_run_for_log_fields = { @@ -43,119 +50,6 @@ def build_workflow_run_for_archived_log_model(api_or_ns: Namespace): return api_or_ns.model("WorkflowRunForArchivedLog", workflow_run_for_archived_log_fields) -workflow_run_for_list_fields = { - "id": fields.String, - "version": fields.String, - "status": fields.String, - "elapsed_time": fields.Float, - "total_tokens": fields.Integer, - "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_at": TimestampField, - "finished_at": TimestampField, - "exceptions_count": fields.Integer, - "retry_index": fields.Integer, -} - -advanced_chat_workflow_run_for_list_fields = { - "id": fields.String, - "conversation_id": fields.String, - "message_id": fields.String, - "version": fields.String, - "status": fields.String, - "elapsed_time": fields.Float, - "total_tokens": fields.Integer, - "total_steps": fields.Integer, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_at": TimestampField, - "finished_at": TimestampField, - "exceptions_count": fields.Integer, - "retry_index": fields.Integer, -} - -advanced_chat_workflow_run_pagination_fields = { - "limit": fields.Integer(attribute="limit"), - "has_more": fields.Boolean(attribute="has_more"), - "data": fields.List(fields.Nested(advanced_chat_workflow_run_for_list_fields), attribute="data"), -} - -workflow_run_pagination_fields = { - "limit": fields.Integer(attribute="limit"), - "has_more": fields.Boolean(attribute="has_more"), - "data": fields.List(fields.Nested(workflow_run_for_list_fields), attribute="data"), -} - -workflow_run_count_fields = { - "total": fields.Integer, - "running": fields.Integer, - "succeeded": fields.Integer, - "failed": fields.Integer, - "stopped": fields.Integer, - "partial_succeeded": fields.Integer(attribute="partial-succeeded"), -} - -workflow_run_detail_fields = { - "id": fields.String, - "version": fields.String, - "graph": fields.Raw(attribute="graph_dict"), - "inputs": fields.Raw(attribute="inputs_dict"), - "status": fields.String, - "outputs": fields.Raw(attribute="outputs_dict"), - "error": fields.String, - "elapsed_time": fields.Float, - "total_tokens": fields.Integer, - "total_steps": fields.Integer, - "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), - "created_at": TimestampField, - "finished_at": TimestampField, - "exceptions_count": fields.Integer, -} - -retry_event_field = { - "elapsed_time": fields.Float, - "status": fields.String, - "inputs": fields.Raw(attribute="inputs"), - "process_data": fields.Raw(attribute="process_data"), - "outputs": fields.Raw(attribute="outputs"), - "metadata": fields.Raw(attribute="metadata"), - "llm_usage": fields.Raw(attribute="llm_usage"), - "error": fields.String, - "retry_index": fields.Integer, -} - - -workflow_run_node_execution_fields = { - "id": fields.String, - "index": fields.Integer, - "predecessor_node_id": fields.String, - "node_id": fields.String, - "node_type": fields.String, - "title": fields.String, - "inputs": fields.Raw(attribute="inputs_dict"), - "process_data": fields.Raw(attribute="process_data_dict"), - "outputs": fields.Raw(attribute="outputs_dict"), - "status": fields.String, - "error": fields.String, - "elapsed_time": fields.Float, - "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), - "extras": fields.Raw, - "created_at": TimestampField, - "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), - "finished_at": TimestampField, - "inputs_truncated": fields.Boolean, - "outputs_truncated": fields.Boolean, - "process_data_truncated": fields.Boolean, -} - -workflow_run_node_execution_list_fields = { - "data": fields.List(fields.Nested(workflow_run_node_execution_fields)), -} - - def _to_timestamp(value: datetime | int | None) -> int | None: if isinstance(value, datetime): return int(value.timestamp()) @@ -252,7 +146,10 @@ class WorkflowRunCountResponse(ResponseModel): succeeded: int failed: int stopped: int - partial_succeeded: int = Field(validation_alias="partial-succeeded") + partial_succeeded: int = Field( + alias="partial_succeeded", + validation_alias=AliasChoices("partial_succeeded", "partial-succeeded"), + ) class WorkflowRunDetailResponse(ResponseModel): diff --git a/api/libs/external_api.py b/api/libs/external_api.py index f907d17750..64eb99a42b 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -9,6 +9,7 @@ from werkzeug.http import HTTP_STATUS_CODES from configs import dify_config from core.errors.error import AppInvokeQuotaExceededError +from libs.flask_restx_compat import patch_swagger_for_inline_nested_dicts from libs.token import build_force_logout_cookie_headers @@ -120,6 +121,7 @@ class ExternalApi(Api): } def __init__(self, app: Blueprint | Flask, *args, **kwargs): + patch_swagger_for_inline_nested_dicts() kwargs.setdefault("authorizations", self._authorizations) kwargs.setdefault("security", "Bearer") kwargs["add_specs"] = dify_config.SWAGGER_UI_ENABLED diff --git a/api/libs/flask_restx_compat.py b/api/libs/flask_restx_compat.py new file mode 100644 index 0000000000..34e0d586a0 --- /dev/null +++ b/api/libs/flask_restx_compat.py @@ -0,0 +1,149 @@ +"""Compatibility helpers for Dify's Flask-RESTX Swagger integration. + +These helpers are temporary bridges for legacy Flask-RESTX field contracts +while controllers migrate their request and response documentation to Pydantic +models. Keep the behavior centralized so live Swagger endpoints and offline +spec export fail or succeed in the same way. +""" + +import hashlib +import json +from typing import TypeGuard + +from flask import current_app +from flask_restx import fields +from flask_restx.model import Model, OrderedModel, instance +from flask_restx.swagger import Swagger + + +def _is_inline_field_map(value: object) -> TypeGuard[dict[object, object]]: + """Return whether a nested field map is an anonymous inline mapping.""" + + return isinstance(value, dict) and not isinstance(value, (Model, OrderedModel)) + + +def _jsonable_schema_value(value: object) -> object: + """Return a deterministic JSON-serializable representation for schema fingerprints.""" + + if value is None or isinstance(value, str | int | float | bool): + return value + if isinstance(value, list | tuple): + return [_jsonable_schema_value(item) for item in value] + if isinstance(value, dict): + return {str(key): _jsonable_schema_value(item) for key, item in value.items()} + value_type = type(value) + return f"<{value_type.__module__}.{value_type.__qualname__}>" + + +def _field_signature(field: object) -> object: + """Build a stable signature for a Flask-RESTX field object.""" + + field_instance = instance(field) + signature: dict[str, object] = { + "class": f"{field_instance.__class__.__module__}.{field_instance.__class__.__qualname__}" + } + + if isinstance(field_instance, fields.Nested): + nested = getattr(field_instance, "nested", None) + if _is_inline_field_map(nested): + signature["nested"] = _inline_model_signature(nested) + else: + signature["nested"] = getattr( + nested, + "name", + f"<{type(nested).__module__}.{type(nested).__qualname__}>", + ) + elif hasattr(field_instance, "container"): + signature["container"] = _field_signature(field_instance.container) + else: + schema = getattr(field_instance, "__schema__", None) + if isinstance(schema, dict): + signature["schema"] = _jsonable_schema_value(schema) + + for attr_name in ( + "attribute", + "default", + "description", + "example", + "max", + "max_items", + "min", + "min_items", + "nullable", + "readonly", + "required", + "title", + "unique", + ): + if hasattr(field_instance, attr_name): + signature[attr_name] = _jsonable_schema_value(getattr(field_instance, attr_name)) + + return signature + + +def _inline_model_signature(nested_fields: dict[object, object]) -> object: + """Build a stable signature for an anonymous inline model.""" + + return [ + (str(field_name), _field_signature(field)) + for field_name, field in sorted(nested_fields.items(), key=lambda item: str(item[0])) + ] + + +def _inline_model_name(nested_fields: dict[object, object]) -> str: + """Return a stable Swagger model name for an anonymous inline field map.""" + + signature = json.dumps(_inline_model_signature(nested_fields), sort_keys=True, separators=(",", ":")) + digest = hashlib.sha1(signature.encode("utf-8")).hexdigest()[:12] + return f"_AnonymousInlineModel_{digest}" + + +def patch_swagger_for_inline_nested_dicts() -> None: + """Allow Swagger generation to handle legacy inline Flask-RESTX field dicts. + + Some existing controllers use raw field mappings in `fields.Nested({...})` + or directly in `@namespace.response(...)`. Runtime marshalling accepts that, + but Flask-RESTX Swagger registration expects a named model. Convert those + anonymous mappings into temporary named models during docs generation. + """ + + if getattr(Swagger, "_dify_inline_nested_dict_patch", False): + return + + original_register_model = Swagger.register_model + original_register_field = Swagger.register_field + original_as_dict = Swagger.as_dict + + def get_or_create_inline_model(self: Swagger, nested_fields: dict[object, object]) -> object: + anonymous_name = _inline_model_name(nested_fields) + if anonymous_name not in self.api.models: + self.api.model(anonymous_name, nested_fields) + + return self.api.models[anonymous_name] + + def register_model_with_inline_dict_support(self: Swagger, model: object) -> dict[str, str]: + if _is_inline_field_map(model): + model = get_or_create_inline_model(self, model) + + return original_register_model(self, model) + + def register_field_with_inline_dict_support(self: Swagger, field: object) -> None: + nested = getattr(field, "nested", None) + if _is_inline_field_map(nested): + field.model = get_or_create_inline_model(self, nested) # type: ignore[attr-defined] + + original_register_field(self, field) + + def as_dict_with_inline_dict_support(self: Swagger): + # Temporary set RESTX_INCLUDE_ALL_MODELS = false to prevent "length changed while iterating" error + include_all_models = current_app.config.get("RESTX_INCLUDE_ALL_MODELS", False) + current_app.config["RESTX_INCLUDE_ALL_MODELS"] = False + try: + return original_as_dict(self) + finally: + current_app.config["RESTX_INCLUDE_ALL_MODELS"] = include_all_models + + Swagger.register_model = register_model_with_inline_dict_support + Swagger.register_field = register_field_with_inline_dict_support + Swagger.as_dict = as_dict_with_inline_dict_support + Swagger._dify_inline_nested_dict_patch = True diff --git a/api/models/comment.py b/api/models/comment.py index 277287132d..c2086c95b5 100644 --- a/api/models/comment.py +++ b/api/models/comment.py @@ -1,19 +1,22 @@ """Workflow comment models.""" +from __future__ import annotations + from datetime import datetime -from typing import Optional import sqlalchemy as sa from sqlalchemy import Index, func from sqlalchemy.orm import Mapped, mapped_column, relationship +from models.base import TypeBase + from .account import Account -from .base import Base, gen_uuidv7_string +from .base import gen_uuidv7_string from .engine import db from .types import StringUUID -class WorkflowComment(Base): +class WorkflowComment(TypeBase): """Workflow comment model for canvas commenting functionality. Comments are associated with apps rather than specific workflow versions, @@ -42,7 +45,7 @@ class WorkflowComment(Base): Index("workflow_comments_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position_x: Mapped[float] = mapped_column(sa.Float) @@ -50,21 +53,27 @@ class WorkflowComment(Base): content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp()) + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) resolved: Mapped[bool] = mapped_column( sa.Boolean, nullable=False, server_default=sa.text("false")) - resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime) - resolved_by: Mapped[str | None] = mapped_column(StringUUID) + resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime, default=None) + resolved_by: Mapped[str | None] = mapped_column(StringUUID, default=None) + resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) # Relationships - replies: Mapped[list["WorkflowCommentReply"]] = relationship( - "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan" + replies: Mapped[list[WorkflowCommentReply]] = relationship( + lambda: WorkflowCommentReply, back_populates="comment", cascade="all, delete-orphan", init=False ) - mentions: Mapped[list["WorkflowCommentMention"]] = relationship( - "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan" + mentions: Mapped[list[WorkflowCommentMention]] = relationship( + lambda: WorkflowCommentMention, back_populates="comment", cascade="all, delete-orphan", init=False ) @property @@ -133,7 +142,7 @@ class WorkflowComment(Base): return participants -class WorkflowCommentReply(Base): +class WorkflowCommentReply(TypeBase): """Workflow comment reply model. Attributes: @@ -151,20 +160,24 @@ class WorkflowCommentReply(Base): Index("comment_replies_created_at_idx", "created_at"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False) comment_id: Mapped[str] = mapped_column( StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) content: Mapped[str] = mapped_column(sa.Text, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp()) + sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + sa.DateTime, + nullable=False, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + init=False, ) # Relationships - comment: Mapped["WorkflowComment"] = relationship( - "WorkflowComment", back_populates="replies") + comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="replies", init=False) @property def created_by_account(self): @@ -178,7 +191,7 @@ class WorkflowCommentReply(Base): self._created_by_account_cache = account -class WorkflowCommentMention(Base): +class WorkflowCommentMention(TypeBase): """Workflow comment mention model. Mentions are only for internal accounts since end users @@ -198,19 +211,18 @@ class WorkflowCommentMention(Base): Index("comment_mentions_user_idx", "mentioned_user_id"), ) - id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string) + id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False) comment_id: Mapped[str] = mapped_column( StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False ) - reply_id: Mapped[str | None] = mapped_column( - StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True - ) mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + reply_id: Mapped[str | None] = mapped_column( + StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True, default=None + ) + # Relationships - comment: Mapped["WorkflowComment"] = relationship( - "WorkflowComment", back_populates="mentions") - reply: Mapped[Optional["WorkflowCommentReply"] - ] = relationship("WorkflowCommentReply") + comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="mentions", init=False) + reply: Mapped[WorkflowCommentReply | None] = relationship(lambda: WorkflowCommentReply, init=False) @property def mentioned_user_account(self): diff --git a/api/models/dataset.py b/api/models/dataset.py index 2b1a8ab228..6d87f8f03c 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -11,7 +11,7 @@ import time from collections.abc import Sequence from datetime import datetime from json import JSONDecodeError -from typing import Any, TypedDict, cast +from typing import Any, ClassVar, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -24,7 +24,7 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField, Metad from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.tools.signature import sign_upload_file +from core.tools.signature import sign_upload_file_preview_url from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 @@ -441,23 +441,27 @@ class Dataset(Base): return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node" -class DatasetProcessRule(Base): # bug +class DatasetProcessRule(TypeBase): __tablename__ = "dataset_process_rules" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) - dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'")) - rules = mapped_column(LongText, nullable=True) - created_by = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + mode: Mapped[ProcessRuleMode] = mapped_column( + EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'") + ) + rules: Mapped[str | None] = mapped_column(LongText, nullable=True) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, server_default=func.current_timestamp(), init=False + ) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES: AutomaticRulesConfig = { + AUTOMATIC_RULES: ClassVar[AutomaticRulesConfig] = { "pre_processing_rules": [ {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, @@ -941,7 +945,7 @@ class DocumentSegment(Base): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + secret_key = dify_config.SECRET_KEY.encode() sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -958,7 +962,7 @@ class DocumentSegment(Base): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + secret_key = dify_config.SECRET_KEY.encode() sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -977,7 +981,7 @@ class DocumentSegment(Base): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + secret_key = dify_config.SECRET_KEY.encode() sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -1015,12 +1019,12 @@ class DocumentSegment(Base): nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + secret_key = dify_config.SECRET_KEY.encode() sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - reference_url = dify_config.CONSOLE_API_URL or "" + reference_url = dify_config.FILES_URL or dify_config.CONSOLE_API_URL or "" base_url = f"{reference_url}/files/{upload_file_id}/image-preview" source_url = f"{base_url}?{params}" attachment_list.append( @@ -1162,7 +1166,7 @@ class DatasetQuery(TypeBase): "size": file_info.size, "extension": file_info.extension, "mime_type": file_info.mime_type, - "source_url": sign_upload_file(file_info.id, file_info.extension), + "source_url": sign_upload_file_preview_url(file_info.id, file_info.extension), } else: query["file_info"] = None diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index f4897e93c5..e56d5f6fe5 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -805,18 +805,17 @@ Get advanced chat workflow run list | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| payload | body | | Yes | [WorkflowRunListQuery](#workflowrunlistquery) | | app_id | path | Application ID | Yes | string | | last_id | query | Last run ID for pagination | No | string | -| limit | query | Number of items per page (1-100) | No | string | -| status | query | Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded | No | string | -| triggered_from | query | Filter by trigger source (optional): debugging or app-run. Default: debugging | No | string | +| limit | query | Number of items per page (1-100) | No | integer | +| status | query | Workflow run status filter | No | string | +| triggered_from | query | Filter by trigger source: debugging or app-run. Default: debugging | No | string | ##### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow runs retrieved successfully | [AdvancedChatWorkflowRunPagination](#advancedchatworkflowrunpagination) | +| 200 | Workflow runs retrieved successfully | [AdvancedChatWorkflowRunPaginationResponse](#advancedchatworkflowrunpaginationresponse) | ### /apps/{app_id}/advanced-chat/workflow-runs/count @@ -833,17 +832,16 @@ Get advanced chat workflow runs count statistics | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| payload | body | | Yes | [WorkflowRunCountQuery](#workflowruncountquery) | | app_id | path | Application ID | Yes | string | -| status | query | Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded | No | string | +| status | query | Workflow run status filter | No | string | | time_range | query | Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), 30m (30 minutes), 30s (30 seconds). Filters by created_at field. | No | string | -| triggered_from | query | Filter by trigger source (optional): debugging or app-run. Default: debugging | No | string | +| triggered_from | query | Filter by trigger source: debugging or app-run. Default: debugging | No | string | ##### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow runs count retrieved successfully | [WorkflowRunCount](#workflowruncount) | +| 200 | Workflow runs count retrieved successfully | [WorkflowRunCountResponse](#workflowruncountresponse) | ### /apps/{app_id}/advanced-chat/workflows/draft/human-input/nodes/{node_id}/form/preview @@ -2361,18 +2359,17 @@ Get workflow run list | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| payload | body | | Yes | [WorkflowRunListQuery](#workflowrunlistquery) | | app_id | path | Application ID | Yes | string | | last_id | query | Last run ID for pagination | No | string | -| limit | query | Number of items per page (1-100) | No | string | -| status | query | Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded | No | string | -| triggered_from | query | Filter by trigger source (optional): debugging or app-run. Default: debugging | No | string | +| limit | query | Number of items per page (1-100) | No | integer | +| status | query | Workflow run status filter | No | string | +| triggered_from | query | Filter by trigger source: debugging or app-run. Default: debugging | No | string | ##### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow runs retrieved successfully | [WorkflowRunPagination](#workflowrunpagination) | +| 200 | Workflow runs retrieved successfully | [WorkflowRunPaginationResponse](#workflowrunpaginationresponse) | ### /apps/{app_id}/workflow-runs/count @@ -2389,17 +2386,16 @@ Get workflow runs count statistics | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| payload | body | | Yes | [WorkflowRunCountQuery](#workflowruncountquery) | | app_id | path | Application ID | Yes | string | -| status | query | Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded | No | string | +| status | query | Workflow run status filter | No | string | | time_range | query | Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), 30m (30 minutes), 30s (30 seconds). Filters by created_at field. | No | string | -| triggered_from | query | Filter by trigger source (optional): debugging or app-run. Default: debugging | No | string | +| triggered_from | query | Filter by trigger source: debugging or app-run. Default: debugging | No | string | ##### Responses | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow runs count retrieved successfully | [WorkflowRunCount](#workflowruncount) | +| 200 | Workflow runs count retrieved successfully | [WorkflowRunCountResponse](#workflowruncountresponse) | ### /apps/{app_id}/workflow-runs/tasks/{task_id}/stop @@ -2449,7 +2445,7 @@ Get workflow run detail | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Workflow run detail retrieved successfully | [WorkflowRunDetail](#workflowrundetail) | +| 200 | Workflow run detail retrieved successfully | [WorkflowRunDetailResponse](#workflowrundetailresponse) | | 404 | Workflow run not found | | ### /apps/{app_id}/workflow-runs/{run_id}/export @@ -2470,7 +2466,7 @@ Generate a download URL for an archived workflow run. | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Export URL generated | [WorkflowRunExport](#workflowrunexport) | +| 200 | Export URL generated | [WorkflowRunExportResponse](#workflowrunexportresponse) | ### /apps/{app_id}/workflow-runs/{run_id}/node-executions @@ -2494,7 +2490,7 @@ Get workflow run node execution list | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Node executions retrieved successfully | [WorkflowRunNodeExecutionList](#workflowrunnodeexecutionlist) | +| 200 | Node executions retrieved successfully | [WorkflowRunNodeExecutionListResponse](#workflowrunnodeexecutionlistresponse) | | 404 | Workflow run not found | | ### /apps/{app_id}/workflow/comments @@ -3180,7 +3176,7 @@ Get last run result for draft workflow node | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Node last run retrieved successfully | [WorkflowRunNodeExecution](#workflowrunnodeexecution) | +| 200 | Node last run retrieved successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | | 403 | Permission denied | | | 404 | Node last run not found | | @@ -3207,7 +3203,7 @@ Run draft workflow node | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Node run started successfully | [WorkflowRunNodeExecution](#workflowrunnodeexecution) | +| 200 | Node run started successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | | 403 | Permission denied | | | 404 | Node not found | | @@ -6720,9 +6716,9 @@ Get workflow run list ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workflow runs retrieved successfully | [WorkflowRunPaginationResponse](#workflowrunpaginationresponse) | ### /rag/pipelines/{pipeline_id}/workflow-runs/tasks/{task_id}/stop @@ -6760,9 +6756,9 @@ Get workflow run detail ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workflow run detail retrieved successfully | [WorkflowRunDetailResponse](#workflowrundetailresponse) | ### /rag/pipelines/{pipeline_id}/workflow-runs/{run_id}/node-executions @@ -6780,9 +6776,9 @@ Get workflow run node execution list ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Node executions retrieved successfully | [WorkflowRunNodeExecutionListResponse](#workflowrunnodeexecutionlistresponse) | ### /rag/pipelines/{pipeline_id}/workflows @@ -6915,9 +6911,9 @@ Set datasource variables ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Datasource variables set successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | ### /rag/pipelines/{pipeline_id}/workflows/draft/environment-variables @@ -6988,9 +6984,9 @@ Run draft workflow loop node ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Node last run retrieved successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | ### /rag/pipelines/{pipeline_id}/workflows/draft/nodes/{node_id}/run @@ -7009,9 +7005,9 @@ Run draft workflow node ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Node run started successfully | [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) | ### /rag/pipelines/{pipeline_id}/workflows/draft/nodes/{node_id}/variables @@ -7947,6 +7943,7 @@ Get workflow pause details ##### Description +Get workflow pause details GET /console/api/workflow//pause-details Returns information about why and where the workflow is paused. @@ -7955,13 +7952,14 @@ Returns information about why and where the workflow is paused. | Name | Located in | Description | Required | Schema | | ---- | ---------- | ----------- | -------- | ------ | -| workflow_run_id | path | | Yes | string | +| workflow_run_id | path | Workflow run ID | Yes | string | ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workflow pause details retrieved successfully | [WorkflowPauseDetailsResponse](#workflowpausedetailsresponse) | +| 404 | Workflow run not found | | ### /workspaces @@ -10256,31 +10254,31 @@ Get banner list | ---- | ---- | ----------- | -------- | | result | string | Operation result | Yes | -#### AdvancedChatWorkflowRunForList +#### AdvancedChatWorkflowRunForListResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| conversation_id | string | | No | -| created_at | object | | No | -| created_by_account | [SimpleAccount](#simpleaccount) | | No | -| elapsed_time | number | | No | -| exceptions_count | integer | | No | -| finished_at | object | | No | -| id | string | | No | -| message_id | string | | No | -| retry_index | integer | | No | -| status | string | | No | -| total_steps | integer | | No | -| total_tokens | integer | | No | -| version | string | | No | +| conversation_id | | | No | +| created_at | | | No | +| created_by_account | | | No | +| elapsed_time | | | No | +| exceptions_count | | | No | +| finished_at | | | No | +| id | string | | Yes | +| message_id | | | No | +| retry_index | | | No | +| status | | | No | +| total_steps | | | No | +| total_tokens | | | No | +| version | | | No | -#### AdvancedChatWorkflowRunPagination +#### AdvancedChatWorkflowRunPaginationResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| data | [ [AdvancedChatWorkflowRunForList](#advancedchatworkflowrunforlist) ] | | No | -| has_more | boolean | | No | -| limit | integer | | No | +| data | [ [AdvancedChatWorkflowRunForListResponse](#advancedchatworkflowrunforlistresponse) ] | | Yes | +| has_more | boolean | | Yes | +| limit | integer | | Yes | #### AdvancedChatWorkflowRunPayload @@ -12169,6 +12167,14 @@ Form input types. | form_inputs | object | Values the user provides for the form's own fields | Yes | | inputs | object | Values used to fill missing upstream variables referenced in form_content | Yes | +#### HumanInputPauseTypeResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| backstage_input_url | | | No | +| form_id | string | | Yes | +| type | string | | Yes | + #### IconType | Name | Type | Description | Required | @@ -13101,6 +13107,14 @@ Enum class for model type. | ---- | ---- | ----------- | -------- | | click_id | string | Click Id from partner referral link | Yes | +#### PausedNodeResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| node_id | string | | Yes | +| node_title | string | | Yes | +| pause_type | [HumanInputPauseTypeResponse](#humaninputpausetyperesponse) | | Yes | + #### Payload | Name | Type | Description | Required | @@ -13772,6 +13786,14 @@ Tag type | unit | string | | No | | variable | string | | No | +#### TrialSimpleAccount + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| email | string | | No | +| id | string | | No | +| name | string | | No | + #### TrialSite | Name | Type | Description | Required | @@ -13815,7 +13837,7 @@ Tag type | ---- | ---- | ----------- | -------- | | conversation_variables | [ [TrialConversationVariable](#trialconversationvariable) ] | | No | | created_at | object | | No | -| created_by | [SimpleAccount](#simpleaccount) | | No | +| created_by | [TrialSimpleAccount](#trialsimpleaccount) | | No | | environment_variables | [ object ] | | No | | features | object | | No | | graph | object | | No | @@ -13826,7 +13848,7 @@ Tag type | rag_pipeline_variables | [ [TrialPipelineVariable](#trialpipelinevariable) ] | | No | | tool_published | boolean | | No | | updated_at | object | | No | -| updated_by | [SimpleAccount](#simpleaccount) | | No | +| updated_by | [TrialSimpleAccount](#trialsimpleaccount) | | No | | version | string | | No | #### TrialWorkflowPartial @@ -14306,53 +14328,60 @@ User action configuration. | updated_at | | | No | | updated_by | | | No | -#### WorkflowRunCount +#### WorkflowPauseDetailsResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| failed | integer | | No | -| partial_succeeded | integer | | No | -| running | integer | | No | -| stopped | integer | | No | -| succeeded | integer | | No | -| total | integer | | No | +| paused_at | | | No | +| paused_nodes | [ [PausedNodeResponse](#pausednoderesponse) ] | | Yes | #### WorkflowRunCountQuery | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | | status | | Workflow run status filter | No | -| time_range | | Time range filter (e.g., 7d, 4h, 30m, 30s) | No | -| triggered_from | | Filter by trigger source: debugging or app-run | No | +| time_range | | Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), 30m (30 minutes), 30s (30 seconds). Filters by created_at field. | No | +| triggered_from | | Filter by trigger source: debugging or app-run. Default: debugging | No | -#### WorkflowRunDetail +#### WorkflowRunCountResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| created_at | object | | No | -| created_by_account | [SimpleAccount](#simpleaccount) | | No | -| created_by_end_user | [SimpleEndUser](#simpleenduser) | | No | -| created_by_role | string | | No | -| elapsed_time | number | | No | -| error | string | | No | -| exceptions_count | integer | | No | -| finished_at | object | | No | -| graph | object | | No | -| id | string | | No | -| inputs | object | | No | -| outputs | object | | No | -| status | string | | No | -| total_steps | integer | | No | -| total_tokens | integer | | No | -| version | string | | No | +| failed | integer | | Yes | +| partial_succeeded | integer | | Yes | +| running | integer | | Yes | +| stopped | integer | | Yes | +| succeeded | integer | | Yes | +| total | integer | | Yes | -#### WorkflowRunExport +#### WorkflowRunDetailResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| presigned_url | string | Pre-signed URL for download | No | -| presigned_url_expires_at | string | Pre-signed URL expiration time | No | -| status | string | Export status: success/failed | No | +| created_at | | | No | +| created_by_account | | | No | +| created_by_end_user | | | No | +| created_by_role | | | No | +| elapsed_time | | | No | +| error | | | No | +| exceptions_count | | | No | +| finished_at | | | No | +| graph | | | Yes | +| id | string | | Yes | +| inputs | | | Yes | +| outputs | | | Yes | +| status | | | No | +| total_steps | | | No | +| total_tokens | | | No | +| version | | | No | + +#### WorkflowRunExportResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| presigned_url | | Pre-signed URL for download | No | +| presigned_url_expires_at | | Pre-signed URL expiration time | No | +| status | string | Export status: success/failed | Yes | #### WorkflowRunForArchivedLogResponse @@ -14364,21 +14393,21 @@ User action configuration. | total_tokens | | | No | | triggered_from | | | No | -#### WorkflowRunForList +#### WorkflowRunForListResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| created_at | object | | No | -| created_by_account | [SimpleAccount](#simpleaccount) | | No | -| elapsed_time | number | | No | -| exceptions_count | integer | | No | -| finished_at | object | | No | -| id | string | | No | -| retry_index | integer | | No | -| status | string | | No | -| total_steps | integer | | No | -| total_tokens | integer | | No | -| version | string | | No | +| created_at | | | No | +| created_by_account | | | No | +| elapsed_time | | | No | +| exceptions_count | | | No | +| finished_at | | | No | +| id | string | | Yes | +| retry_index | | | No | +| status | | | No | +| total_steps | | | No | +| total_tokens | | | No | +| version | | | No | #### WorkflowRunForLogResponse @@ -14403,48 +14432,48 @@ User action configuration. | last_id | | Last run ID for pagination | No | | limit | integer | Number of items per page (1-100) | No | | status | | Workflow run status filter | No | -| triggered_from | | Filter by trigger source: debugging or app-run | No | +| triggered_from | | Filter by trigger source: debugging or app-run. Default: debugging | No | -#### WorkflowRunNodeExecution +#### WorkflowRunNodeExecutionListResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| created_at | object | | No | -| created_by_account | [SimpleAccount](#simpleaccount) | | No | -| created_by_end_user | [SimpleEndUser](#simpleenduser) | | No | -| created_by_role | string | | No | -| elapsed_time | number | | No | -| error | string | | No | -| execution_metadata | object | | No | -| extras | object | | No | -| finished_at | object | | No | -| id | string | | No | -| index | integer | | No | -| inputs | object | | No | -| inputs_truncated | boolean | | No | -| node_id | string | | No | -| node_type | string | | No | -| outputs | object | | No | -| outputs_truncated | boolean | | No | -| predecessor_node_id | string | | No | -| process_data | object | | No | -| process_data_truncated | boolean | | No | -| status | string | | No | -| title | string | | No | +| data | [ [WorkflowRunNodeExecutionResponse](#workflowrunnodeexecutionresponse) ] | | Yes | -#### WorkflowRunNodeExecutionList +#### WorkflowRunNodeExecutionResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| data | [ [WorkflowRunNodeExecution](#workflowrunnodeexecution) ] | | No | +| created_at | | | No | +| created_by_account | | | No | +| created_by_end_user | | | No | +| created_by_role | | | No | +| elapsed_time | | | No | +| error | | | No | +| execution_metadata | | | No | +| extras | | | No | +| finished_at | | | No | +| id | string | | Yes | +| index | | | No | +| inputs | | | No | +| inputs_truncated | | | No | +| node_id | | | No | +| node_type | | | No | +| outputs | | | No | +| outputs_truncated | | | No | +| predecessor_node_id | | | No | +| process_data | | | No | +| process_data_truncated | | | No | +| status | | | No | +| title | | | No | -#### WorkflowRunPagination +#### WorkflowRunPaginationResponse | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| data | [ [WorkflowRunForList](#workflowrunforlist) ] | | No | -| has_more | boolean | | No | -| limit | integer | | No | +| data | [ [WorkflowRunForListResponse](#workflowrunforlistresponse) ] | | Yes | +| has_more | boolean | | Yes | +| limit | integer | | Yes | #### WorkflowRunPayload diff --git a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py index 96df49ed0e..a0d150e1b6 100644 --- a/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/src/dify_trace_arize_phoenix/arize_phoenix_trace.py @@ -1,9 +1,11 @@ import json import logging import os +import re import traceback +from collections.abc import Mapping, Sequence from datetime import datetime, timedelta -from typing import Any, Union, cast +from typing import Any, Protocol, Union, cast from urllib.parse import urlparse from openinference.semconv.trace import ( @@ -19,7 +21,7 @@ from opentelemetry.sdk import trace as trace_sdk from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.semconv.attributes import exception_attributes -from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span +from opentelemetry.trace import Span, Status, StatusCode, get_current_span, set_span_in_context, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from opentelemetry.util.types import AttributeValue from sqlalchemy.orm import sessionmaker @@ -36,16 +38,106 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.ops.exceptions import PendingTraceParentContextError from core.ops.utils import JSON_DICT_ADAPTER from core.repositories import DifyCoreRepositoryFactory from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig from extensions.ext_database import db +from extensions.ext_redis import redis_client from graphon.enums import WorkflowNodeExecutionStatus from models.model import EndUser, MessageFile from models.workflow import WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) +# This parent-span carrier store is intentionally Phoenix-local for the current +# nested workflow tracing feature. If other trace providers need the same +# cross-task parent restoration behavior, move the storage and retry signaling +# behind a core trace coordination interface instead of duplicating it here. +_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS = 300 +_TRACEPARENT_PATTERN = re.compile( + r"^(?P[0-9a-f]{2})-(?P[0-9a-f]{32})-(?P[0-9a-f]{16})-(?P[0-9a-f]{2})$" +) + + +def _phoenix_parent_span_redis_key(parent_node_execution_id: str) -> str: + """Build the Redis key that stores a restorable Phoenix parent span carrier.""" + return f"trace:phoenix:parent_span:{parent_node_execution_id}" + + +def _publish_parent_span_context(parent_node_execution_id: str, carrier: Mapping[str, str]) -> None: + """Persist a tracecontext carrier so nested workflow spans can restore the tool span parent.""" + redis_client.setex( + _phoenix_parent_span_redis_key(parent_node_execution_id), + _PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS, + safe_json_dumps(dict(carrier)), + ) + + +def _resolve_published_parent_span_context(parent_node_execution_id: str) -> dict[str, str]: + """Load a previously published tool-span carrier for nested workflow parenting.""" + raw_carrier = redis_client.get(_phoenix_parent_span_redis_key(parent_node_execution_id)) + if raw_carrier is None: + raise PendingTraceParentContextError(parent_node_execution_id) + + if isinstance(raw_carrier, bytes): + raw_carrier = raw_carrier.decode("utf-8") + + carrier = json.loads(raw_carrier) + if not isinstance(carrier, dict): + raise ValueError( + "Phoenix parent span context must be stored as a JSON object: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + normalized_carrier = {str(key): str(value) for key, value in carrier.items()} + if not normalized_carrier: + raise ValueError( + f"Phoenix parent span context payload is empty: parent_node_execution_id={parent_node_execution_id}" + ) + + traceparent = normalized_carrier.get("traceparent") + if not isinstance(traceparent, str): + raise ValueError( + "Phoenix parent span context payload is missing traceparent: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + traceparent_match = _TRACEPARENT_PATTERN.fullmatch(traceparent) + if traceparent_match is None: + raise ValueError( + "Phoenix parent span context payload has invalid traceparent format: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + if traceparent_match.group("version") == "ff": + raise ValueError( + "Phoenix parent span context payload has unsupported traceparent version: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + if traceparent_match.group("trace_id") == "0" * 32: + raise ValueError( + "Phoenix parent span context payload has zero trace_id in traceparent: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + if traceparent_match.group("span_id") == "0" * 16: + raise ValueError( + "Phoenix parent span context payload has zero span_id in traceparent: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + extracted_context = TraceContextTextMapPropagator().extract(carrier=normalized_carrier) + extracted_span_context = get_current_span(extracted_context).get_span_context() + if not extracted_span_context.is_valid or not extracted_span_context.is_remote: + raise ValueError( + "Phoenix parent span context payload could not be restored into a valid parent span: " + f"parent_node_execution_id={parent_node_execution_id}" + ) + + return normalized_carrier + def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]: """Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix.""" @@ -177,6 +269,246 @@ def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues: return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN) +def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) -> str: + """Resolve the workflow session ID for Phoenix workflow spans.""" + if trace_info.conversation_id: + return trace_info.conversation_id + + parent_workflow_run_id, _ = _resolve_workflow_parent_context(trace_info) + if parent_workflow_run_id: + return parent_workflow_run_id + + return trace_info.workflow_run_id + + +def _resolve_workflow_parent_context(trace_info: BaseTraceInfo) -> tuple[str | None, str | None]: + """Expose the typed parent context already resolved on the trace info.""" + return trace_info.resolved_parent_context + + +def _resolve_workflow_root_trace_id(trace_info: WorkflowTraceInfo) -> str: + """Resolve the canonical root trace ID for Phoenix workflow spans.""" + trace_correlation_override, _ = _resolve_workflow_parent_context(trace_info) + return trace_correlation_override or trace_info.resolved_trace_id or trace_info.workflow_run_id + + +class _NodeExecutionIdentityLike(Protocol): + @property + def node_execution_id(self) -> str | None: ... + + @property + def node_id(self) -> str: ... + + @property + def predecessor_node_id(self) -> str | None: ... + + +class _NodeExecutionLike(_NodeExecutionIdentityLike, Protocol): + @property + def id(self) -> str: ... + + @property + def node_type(self) -> str: ... + + @property + def title(self) -> str | None: ... + + @property + def inputs(self) -> Mapping[str, Any] | None: ... + + @property + def process_data(self) -> Mapping[str, Any] | None: ... + + @property + def outputs(self) -> Mapping[str, Any] | None: ... + + @property + def status(self) -> WorkflowNodeExecutionStatus: ... + + @property + def error(self) -> str | None: ... + + @property + def elapsed_time(self) -> float | None: ... + + @property + def metadata(self) -> Mapping[Any, Any] | None: ... + + @property + def created_at(self) -> datetime | None: ... + + +_PHOENIX_STRUCTURED_NODE_TYPES = frozenset({"start", "end", "loop", "iteration"}) + + +def _resolve_workflow_span_name(trace_info: WorkflowTraceInfo) -> str: + """Resolve the Phoenix workflow span display name.""" + workflow_run_id = trace_info.workflow_run_id.strip() if trace_info.workflow_run_id else "" + if workflow_run_id: + return f"{TraceTaskName.WORKFLOW_TRACE.value}_{workflow_run_id}" + return TraceTaskName.WORKFLOW_TRACE.value + + +def _build_node_title_by_id(trace_info: WorkflowTraceInfo) -> dict[str, str]: + """Build an authoritative node-title index from the persisted workflow graph.""" + workflow_data = trace_info.workflow_data + workflow_graph = getattr(workflow_data, "graph_dict", None) + if not isinstance(workflow_graph, Mapping): + workflow_graph = workflow_data.get("graph") if isinstance(workflow_data, Mapping) else None + if not isinstance(workflow_graph, Mapping): + return {} + + graph_nodes = workflow_graph.get("nodes") + if not isinstance(graph_nodes, Sequence): + return {} + + node_title_by_id: dict[str, str] = {} + for graph_node in graph_nodes: + if not isinstance(graph_node, Mapping): + continue + node_id = graph_node.get("id") + node_data = graph_node.get("data") + if not isinstance(node_id, str) or not isinstance(node_data, Mapping): + continue + node_title = node_data.get("title") + if isinstance(node_title, str) and node_title.strip(): + node_title_by_id[node_id] = node_title.strip() + + return node_title_by_id + + +def _resolve_workflow_node_span_name( + node_execution: _NodeExecutionLike, + node_title_by_id: Mapping[str, str] | None = None, +) -> str: + """Resolve the Phoenix workflow node span display name.""" + node_type = str(node_execution.node_type or "") + graph_node_title = None + if node_title_by_id is not None and isinstance(node_execution.node_id, str): + graph_node_title = node_title_by_id.get(node_execution.node_id) + + node_title = graph_node_title or (node_execution.title.strip() if isinstance(node_execution.title, str) else "") + if node_title: + return f"{node_type}_{node_title}" + return node_type + + +def _get_node_execution_id(node_execution: _NodeExecutionIdentityLike) -> str: + """Return the stable execution identifier for a workflow node execution.""" + return str(getattr(node_execution, "id", None) or node_execution.node_execution_id) + + +def _build_execution_id_by_node_id(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]: + """Index unique workflow graph node ids by execution id. + + This Phoenix-local hierarchy reconstruction intentionally drops ambiguous + node ids instead of guessing based on repository order. That keeps parent + selection deterministic until upstream tracing exposes explicit parent span + data for repeated executions. + """ + execution_id_by_node_id: dict[str, str] = {} + ambiguous_node_ids: set[str] = set() + + for node_execution in node_executions: + node_id = node_execution.node_id + if not isinstance(node_id, str): + continue + execution_id = _get_node_execution_id(node_execution) + + if node_id in ambiguous_node_ids: + continue + + existing_execution_id = execution_id_by_node_id.get(node_id) + if existing_execution_id is None: + execution_id_by_node_id[node_id] = execution_id + continue + + if existing_execution_id != execution_id: + ambiguous_node_ids.add(node_id) + execution_id_by_node_id.pop(node_id, None) + + return execution_id_by_node_id + + +def _build_graph_parent_index(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]: + """Build an execution-id parent index from predecessor node ids.""" + execution_id_by_node_id = _build_execution_id_by_node_id(node_executions) + graph_parent_index: dict[str, str] = {} + + for node_execution in node_executions: + predecessor_node_id = node_execution.predecessor_node_id + if not isinstance(predecessor_node_id, str): + continue + + predecessor_execution_id = execution_id_by_node_id.get(predecessor_node_id) + if predecessor_execution_id is not None: + execution_id = _get_node_execution_id(node_execution) + graph_parent_index[execution_id] = predecessor_execution_id + + return graph_parent_index + + +def _resolve_structured_parent_execution_id( + node_execution: object, execution_id_by_node_id: Mapping[str, str] +) -> str | None: + """Resolve Phoenix-local structured parents from loop/iteration node ids. + + Any execution carrying ``iteration_id`` or ``loop_id`` belongs to an + enclosing structured node. When predecessor node ids are ambiguous because + the graph node repeats inside that structure, Phoenix can still keep the + child span under the enclosing loop/iteration span without relying on + execution-order heuristics. + """ + execution_metadata = getattr(node_execution, "execution_metadata_dict", None) + if not isinstance(execution_metadata, Mapping): + execution_metadata = getattr(node_execution, "metadata", None) + if not isinstance(execution_metadata, Mapping): + execution_metadata = {} + + for enclosing_node_id in ( + getattr(node_execution, "iteration_id", None), + getattr(node_execution, "loop_id", None), + execution_metadata.get("iteration_id"), + execution_metadata.get("loop_id"), + ): + if not isinstance(enclosing_node_id, str): + continue + + enclosing_execution_id = execution_id_by_node_id.get(enclosing_node_id) + if enclosing_execution_id is not None: + return enclosing_execution_id + + return None + + +def _resolve_node_parent( + execution_id: str, + predecessor_execution_id: str | None, + structured_parent_execution_id: str | None, + span_by_execution_id: Mapping[str, Span], + graph_parent_index: Mapping[str, str], + workflow_span: Span, +) -> Span: + """Resolve the parent span for a workflow node execution.""" + if predecessor_execution_id is not None: + predecessor_span = span_by_execution_id.get(predecessor_execution_id) + if predecessor_span is not None: + return predecessor_span + + graph_parent_execution_id = graph_parent_index.get(execution_id) + if graph_parent_execution_id is not None: + graph_parent_span = span_by_execution_id.get(graph_parent_execution_id) + if graph_parent_span is not None: + return graph_parent_span + + if structured_parent_execution_id is not None: + structured_parent_span = span_by_execution_id.get(structured_parent_execution_id) + if structured_parent_span is not None: + return structured_parent_span + + return workflow_span + + class ArizePhoenixDataTrace(BaseTraceInstance): def __init__( self, @@ -189,6 +521,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.propagator = TraceContextTextMapPropagator() self.dify_trace_ids: set[str] = set() + self.root_span_carriers: dict[str, dict[str, str]] = {} + self.carrier: dict[str, str] = {} def trace(self, trace_info: BaseTraceInfo): logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info) @@ -235,13 +569,41 @@ class ArizePhoenixDataTrace(BaseTraceInstance): file_list=safe_json_dumps(file_list), query=trace_info.query or "", ) + workflow_session_id = _resolve_workflow_session_id(trace_info) + parent_workflow_run_id, parent_node_execution_id = _resolve_workflow_parent_context(trace_info) + logger.info( + "[Arize/Phoenix] Workflow session resolution: workflow_run_id=%s conversation_id=%s " + "parent_workflow_run_id=%s parent_node_execution_id=%s resolved_session_id=%s", + trace_info.workflow_run_id, + trace_info.conversation_id, + parent_workflow_run_id, + parent_node_execution_id, + workflow_session_id, + ) - dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id - self.ensure_root_span(dify_trace_id) - root_span_context = self.propagator.extract(carrier=self.carrier) + if parent_node_execution_id: + workflow_parent_carrier = _resolve_published_parent_span_context(parent_node_execution_id) + else: + root_trace_id = _resolve_workflow_root_trace_id(trace_info) + workflow_root_span_name: str | None = trace_info.workflow_run_id + if not isinstance(workflow_root_span_name, str) or not workflow_root_span_name.strip(): + workflow_root_span_name = None + + workflow_parent_carrier = self.ensure_root_span( + root_trace_id, + root_span_name=workflow_root_span_name, + root_span_attributes={ + SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + }, + ) + + workflow_span_context = self.propagator.extract(carrier=workflow_parent_carrier) workflow_span = self.tracer.start_span( - name=TraceTaskName.WORKFLOW_TRACE.value, + name=_resolve_workflow_span_name(trace_info), attributes={ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs), @@ -249,10 +611,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.METADATA: safe_json_dumps(metadata), - SpanAttributes.SESSION_ID: trace_info.conversation_id or "", + SpanAttributes.SESSION_ID: workflow_session_id or "", }, start_time=datetime_to_nanos(trace_info.start_time), - context=root_span_context, + context=workflow_span_context, ) # Through workflow_run_id, get all_nodes_execution using repository @@ -276,16 +638,50 @@ class ArizePhoenixDataTrace(BaseTraceInstance): workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( workflow_execution_id=trace_info.workflow_run_id ) + node_title_by_id = _build_node_title_by_id(trace_info) + execution_id_by_node_id = _build_execution_id_by_node_id(workflow_node_executions) + graph_parent_index = _build_graph_parent_index(workflow_node_executions) + node_execution_by_execution_id = { + _get_node_execution_id(node_execution): node_execution for node_execution in workflow_node_executions + } + span_by_execution_id: dict[str, Span] = {} + emitting_execution_ids: set[str] = set() + workflow_span_error: Exception | str | None = trace_info.error try: - for node_execution in workflow_node_executions: + + def emit_node_span(node_execution: _NodeExecutionLike) -> Span: + execution_id = _get_node_execution_id(node_execution) + existing_span = span_by_execution_id.get(execution_id) + if existing_span is not None: + return existing_span + + graph_parent_execution_id = graph_parent_index.get(execution_id) + structured_parent_execution_id = _resolve_structured_parent_execution_id( + node_execution, execution_id_by_node_id + ) + + if execution_id not in emitting_execution_ids: + emitting_execution_ids.add(execution_id) + try: + for parent_execution_id in (graph_parent_execution_id, structured_parent_execution_id): + if parent_execution_id is None or parent_execution_id == execution_id: + continue + if parent_execution_id in span_by_execution_id: + continue + parent_node_execution = node_execution_by_execution_id.get(parent_execution_id) + if parent_node_execution is not None: + emit_node_span(parent_node_execution) + finally: + emitting_execution_ids.discard(execution_id) + tenant_id = trace_info.tenant_id # Use from trace_info instead app_id = trace_info.metadata.get("app_id") # Use from trace_info instead inputs_value = node_execution.inputs or {} outputs_value = node_execution.outputs or {} created_at = node_execution.created_at or datetime.now() - elapsed_time = node_execution.elapsed_time + elapsed_time = node_execution.elapsed_time or 0 finished_at = created_at + timedelta(seconds=elapsed_time) process_data = node_execution.process_data or {} @@ -324,9 +720,17 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0) node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0) - workflow_span_context = set_span_in_context(workflow_span) + parent_span = _resolve_node_parent( + execution_id=execution_id, + predecessor_execution_id=None, + structured_parent_execution_id=structured_parent_execution_id, + span_by_execution_id=span_by_execution_id, + graph_parent_index=graph_parent_index, + workflow_span=workflow_span, + ) + workflow_span_context = set_span_in_context(parent_span) node_span = self.tracer.start_span( - name=node_execution.node_type, + name=_resolve_workflow_node_span_name(node_execution, node_title_by_id), attributes={ SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value, SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value), @@ -334,13 +738,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, SpanAttributes.METADATA: safe_json_dumps(node_metadata), - SpanAttributes.SESSION_ID: trace_info.conversation_id or "", + SpanAttributes.SESSION_ID: workflow_session_id or "", }, start_time=datetime_to_nanos(created_at), context=workflow_span_context, ) - + span_by_execution_id[execution_id] = node_span + node_span_error: Exception | str | None = None try: + if node_execution.node_type == "tool": + parent_span_carrier: dict[str, str] = {} + with use_span(node_span, end_on_exit=False): + self.propagator.inject(carrier=parent_span_carrier) + _publish_parent_span_context(execution_id, parent_span_carrier) + if node_execution.node_type == "llm": llm_attributes: dict[str, Any] = { SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), @@ -362,17 +773,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) node_span.set_attributes(llm_attributes) + except Exception as e: + node_span_error = e + raise finally: - if node_execution.status == WorkflowNodeExecutionStatus.FAILED: + if node_span_error is not None: + set_span_status(node_span, node_span_error) + elif node_execution.status == WorkflowNodeExecutionStatus.FAILED: set_span_status(node_span, node_execution.error) else: set_span_status(node_span) node_span.end(end_time=datetime_to_nanos(finished_at)) + return node_span + + for node_execution in workflow_node_executions: + emit_node_span(node_execution) + except Exception as e: + workflow_span_error = e + raise finally: - if trace_info.error: - set_span_status(workflow_span, trace_info.error) - else: - set_span_status(workflow_span) + set_span_status(workflow_span, workflow_span_error) workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time)) def message_trace(self, trace_info: MessageTraceInfo): @@ -735,22 +1155,39 @@ class ArizePhoenixDataTrace(BaseTraceInstance): finally: span.end(end_time=datetime_to_nanos(trace_info.end_time)) - def ensure_root_span(self, dify_trace_id: str | None): + def ensure_root_span( + self, + dify_trace_id: str | None, + *, + root_span_name: str | None = None, + root_span_attributes: Mapping[str, AttributeValue] | None = None, + ): """Ensure a unique root span exists for the given Dify trace ID.""" - if str(dify_trace_id) not in self.dify_trace_ids: - self.carrier: dict[str, str] = {} + trace_key = str(dify_trace_id) + if trace_key not in self.dify_trace_ids: + carrier: dict[str, str] = {} - root_span = self.tracer.start_span(name="Dify") - root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value) - root_span.set_attribute("dify_project_name", str(self.project)) - root_span.set_attribute("dify_trace_id", str(dify_trace_id)) + span_name = root_span_name.strip() if isinstance(root_span_name, str) and root_span_name.strip() else "Dify" + root_span_attributes_dict: dict[str, AttributeValue] = { + SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value, + "dify_project_name": str(self.project), + "dify_trace_id": trace_key, + } + if root_span_attributes: + root_span_attributes_dict.update(root_span_attributes) + + root_span = self.tracer.start_span(name=span_name, attributes=root_span_attributes_dict) with use_span(root_span, end_on_exit=False): - self.propagator.inject(carrier=self.carrier) + self.propagator.inject(carrier=carrier) set_span_status(root_span) root_span.end() - self.dify_trace_ids.add(str(dify_trace_id)) + self.dify_trace_ids.add(trace_key) + self.root_span_carriers[trace_key] = carrier + + self.carrier = self.root_span_carriers[trace_key] + return self.carrier def api_check(self): try: diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py index e9ecc2e083..dd260aeee5 100644 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -1,10 +1,21 @@ from datetime import UTC, datetime, timedelta -from typing import cast +from types import SimpleNamespace +from typing import Any, cast from unittest.mock import MagicMock, patch +import dify_trace_arize_phoenix.arize_phoenix_trace as arize_phoenix_trace_module import pytest from dify_trace_arize_phoenix.arize_phoenix_trace import ( + _NODE_TYPE_TO_SPAN_KIND, ArizePhoenixDataTrace, + _build_graph_parent_index, + _get_node_span_kind, + _phoenix_parent_span_redis_key, + _resolve_node_parent, + _resolve_published_parent_span_context, + _resolve_structured_parent_execution_id, + _resolve_workflow_parent_context, + _resolve_workflow_session_id, datetime_to_nanos, error_to_string, safe_json_dumps, @@ -13,6 +24,7 @@ from dify_trace_arize_phoenix.arize_phoenix_trace import ( wrap_span_metadata, ) from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig +from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry.sdk.trace import Tracer from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes from opentelemetry.trace import StatusCode @@ -24,8 +36,12 @@ from core.ops.entities.trace_entity import ( ModerationTraceInfo, SuggestedQuestionTraceInfo, ToolTraceInfo, + TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) +from core.ops.exceptions import PendingTraceParentContextError +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes # --- Helpers --- @@ -73,6 +89,80 @@ def _make_message_info(**kwargs): return MessageTraceInfo(**defaults) +def _get_start_span_call(start_span_mock, *, span_name: str): + for call in start_span_mock.call_args_list: + if call.kwargs.get("name") == span_name: + return call + raise AssertionError(f"Could not find start_span call with name={span_name!r}") + + +def _make_node_execution(**kwargs): + defaults = { + "node_type": "tool", + "status": "succeeded", + "inputs": {}, + "outputs": {}, + "created_at": _dt(), + "elapsed_time": 1.0, + "process_data": {}, + "metadata": {}, + "title": "Node", + "id": "node-execution-1", + "node_execution_id": "node-execution-1", + "node_id": "node-1", + "predecessor_node_id": None, + "iteration_id": None, + "loop_id": None, + "error": None, + } + defaults.update(kwargs) + node_execution = MagicMock() + for key, value in defaults.items(): + setattr(node_execution, key, value) + return node_execution + + +def _make_workflow_trace_info(**kwargs) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "workflow-1", + "tenant_id": "tenant-1", + "workflow_run_id": "workflow-run-1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"input": "value"}, + "workflow_run_outputs": {"output": "value"}, + "workflow_run_version": "1.0", + "total_tokens": 10, + "file_list": ["file-1"], + "query": "hello", + "metadata": {"app_id": "app-1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(kwargs) + return WorkflowTraceInfo(**defaults) + + +def _make_workflow_node_trace_info(**kwargs) -> WorkflowNodeTraceInfo: + defaults = { + "workflow_id": "workflow-1", + "workflow_run_id": "workflow-run-1", + "tenant_id": "tenant-1", + "node_execution_id": "node-execution-1", + "node_id": "node-1", + "node_type": "tool", + "title": "Node 1", + "status": "succeeded", + "elapsed_time": 1.0, + "index": 1, + "metadata": {"app_id": "app-1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(kwargs) + return WorkflowNodeTraceInfo(**defaults) + + # --- Utility Function Tests --- @@ -143,6 +233,258 @@ def test_wrap_span_metadata(): assert res == {"a": 1, "b": 2, "created_from": "Dify"} +class TestGetNodeSpanKind: + def test_all_node_types_are_mapped_correctly(self): + special_mappings = { + BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER, + BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL, + BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT, + } + + for node_type in BUILT_IN_NODE_TYPES: + expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN) + actual_span_kind = _get_node_span_kind(node_type) + assert actual_span_kind == expected_span_kind, ( + f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected." + ) + + def test_unknown_string_defaults_to_chain(self): + assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN + + def test_stale_dataset_retrieval_not_in_mapping(self): + assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND + + +class TestWorkflowSessionResolution: + def test_prefers_conversation_id(self): + info = _make_workflow_trace_info(conversation_id="conversation-1") + + assert _resolve_workflow_session_id(info) == "conversation-1" + + def test_nested_workflow_keeps_own_conversation_id_when_parent_context_exists(self): + info = _make_workflow_trace_info( + conversation_id="conversation-1", + metadata={ + "app_id": "app-1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + ) + + assert _resolve_workflow_session_id(info) == "conversation-1" + + def test_uses_parent_workflow_run_id_for_nested_parent_trace_context(self): + info = _make_workflow_trace_info( + conversation_id=None, + metadata={ + "app_id": "app-1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + ) + + assert _resolve_workflow_session_id(info) == "outer-workflow-run-1" + + def test_falls_back_to_workflow_run_id(self): + info = _make_workflow_trace_info(conversation_id=None) + + assert _resolve_workflow_session_id(info) == "workflow-run-1" + + def test_parent_context_helper_delegates_to_resolved_parent_context(self): + info = MagicMock() + info.resolved_parent_context = ("outer-workflow-run-1", "outer-node-execution-1") + + assert _resolve_workflow_parent_context(info) == info.resolved_parent_context + + +class TestPhoenixParentSpanBridgeHelpers: + def test_parent_span_redis_key_is_stable(self): + assert _phoenix_parent_span_redis_key("outer-node-execution-1") == ( + "trace:phoenix:parent_span:outer-node-execution-1" + ) + + def test_pending_parent_exception_exposes_execution_id(self): + error = PendingTraceParentContextError("outer-node-execution-1") + + assert error.parent_node_execution_id == "outer-node-execution-1" + assert "outer-node-execution-1" in str(error) + + def test_resolve_parent_span_context_rejects_payload_without_traceparent(self, monkeypatch): + mock_redis = MagicMock() + mock_redis.get.return_value = '{"tracestate": "vendor=value"}' + monkeypatch.setattr(arize_phoenix_trace_module, "redis_client", mock_redis) + + with pytest.raises(ValueError, match="traceparent"): + _resolve_published_parent_span_context("outer-node-execution-1") + + @pytest.mark.parametrize( + "stored_payload", + [ + '{"traceparent": ""}', + '{"traceparent": "not-a-traceparent"}', + '{"traceparent": "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb"}', + ], + ) + def test_resolve_parent_span_context_rejects_malformed_traceparent(self, monkeypatch, stored_payload): + mock_redis = MagicMock() + mock_redis.get.return_value = stored_payload + monkeypatch.setattr(arize_phoenix_trace_module, "redis_client", mock_redis) + + with pytest.raises(ValueError, match="traceparent"): + _resolve_published_parent_span_context("outer-node-execution-1") + + +class TestWorkflowHierarchyHelpers: + def test_build_graph_parent_index_uses_predecessor_nodes_without_order_heuristics(self): + later_node = _make_workflow_node_trace_info( + node_execution_id="node-execution-3", + node_id="node-3", + predecessor_node_id="node-2", + index=3, + ) + root_node = _make_workflow_node_trace_info( + node_execution_id="node-execution-1", + node_id="node-1", + predecessor_node_id=None, + index=1, + ) + middle_node = _make_workflow_node_trace_info( + node_execution_id="node-execution-2", + node_id="node-2", + predecessor_node_id="node-1", + index=2, + ) + + graph_parent_index = _build_graph_parent_index([later_node, root_node, middle_node]) + + assert graph_parent_index == { + "node-execution-2": "node-execution-1", + "node-execution-3": "node-execution-2", + } + + def test_build_graph_parent_index_drops_ambiguous_parallel_like_predecessors(self): + first_parallel_node = _make_workflow_node_trace_info( + node_execution_id="parallel-node-execution-1", + node_id="parallel-node-1", + predecessor_node_id=None, + index=1, + parallel_id="parallel-1", + ) + second_parallel_node = _make_workflow_node_trace_info( + node_execution_id="parallel-node-execution-2", + node_id="parallel-node-1", + predecessor_node_id=None, + index=2, + parallel_id="parallel-2", + ) + child_node = _make_workflow_node_trace_info( + node_execution_id="child-node-execution-1", + node_id="child-node-1", + predecessor_node_id="parallel-node-1", + index=3, + ) + + graph_parent_index = _build_graph_parent_index([child_node, first_parallel_node, second_parallel_node]) + + assert graph_parent_index == {} + + def test_resolve_node_parent_prefers_predecessor_span(self): + workflow_span = MagicMock(name="workflow-span") + predecessor_span = MagicMock(name="predecessor-span") + graph_parent_span = MagicMock(name="graph-parent-span") + + parent = _resolve_node_parent( + execution_id="node-execution-2", + predecessor_execution_id="node-execution-1", + structured_parent_execution_id=None, + span_by_execution_id={ + "node-execution-1": predecessor_span, + "node-execution-0": graph_parent_span, + }, + graph_parent_index={ + "node-execution-2": "node-execution-0", + }, + workflow_span=workflow_span, + ) + + assert parent is predecessor_span + + def test_resolve_node_parent_falls_back_to_graph_parent_span(self): + workflow_span = MagicMock(name="workflow-span") + graph_parent_span = MagicMock(name="graph-parent-span") + + parent = _resolve_node_parent( + execution_id="node-execution-2", + predecessor_execution_id="missing-predecessor", + structured_parent_execution_id=None, + span_by_execution_id={ + "node-execution-0": graph_parent_span, + }, + graph_parent_index={ + "node-execution-2": "node-execution-0", + }, + workflow_span=workflow_span, + ) + + assert parent is graph_parent_span + + def test_resolve_node_parent_falls_back_to_workflow_span(self): + workflow_span = MagicMock(name="workflow-span") + + parent = _resolve_node_parent( + execution_id="node-execution-2", + predecessor_execution_id=None, + structured_parent_execution_id=None, + span_by_execution_id={}, + graph_parent_index={}, + workflow_span=workflow_span, + ) + + assert parent is workflow_span + + def test_resolve_structured_parent_execution_id_allows_body_nodes_to_use_enclosing_structure(self): + body_node = _make_workflow_node_trace_info( + node_execution_id="body-execution-1", + node_id="body-node-1", + node_type="tool", + loop_id="loop-node-1", + ) + + structured_parent_execution_id = _resolve_structured_parent_execution_id( + body_node, + execution_id_by_node_id={ + "loop-node-1": "loop-execution-1", + }, + ) + + assert structured_parent_execution_id == "loop-execution-1" + + def test_resolve_structured_parent_execution_id_reads_execution_metadata_dict_for_models(self): + body_node = SimpleNamespace( + node_execution_id="body-execution-1", + node_id="body-node-1", + execution_metadata_dict={ + "iteration_id": "iteration-node-1", + "loop_id": "loop-node-1", + }, + ) + + structured_parent_execution_id = _resolve_structured_parent_execution_id( + body_node, + execution_id_by_node_id={ + "iteration-node-1": "iteration-execution-1", + "loop-node-1": "loop-execution-1", + }, + ) + + assert structured_parent_execution_id == "iteration-execution-1" + + @patch("dify_trace_arize_phoenix.arize_phoenix_trace.GrpcOTLPSpanExporter") @patch("dify_trace_arize_phoenix.arize_phoenix_trace.trace_sdk.TracerProvider") def test_setup_tracer_arize(mock_provider, mock_exporter): @@ -173,12 +515,17 @@ def test_setup_tracer_exception(): @pytest.fixture def trace_instance(): - with patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup: + with ( + patch("dify_trace_arize_phoenix.arize_phoenix_trace.setup_tracer") as mock_setup, + patch("dify_trace_arize_phoenix.arize_phoenix_trace.redis_client", new=MagicMock()) as mock_redis, + ): mock_tracer = MagicMock(spec=Tracer) mock_processor = MagicMock() mock_setup.return_value = (mock_tracer, mock_processor) config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") - return ArizePhoenixDataTrace(config) + instance = ArizePhoenixDataTrace(config) + cast(Any, instance)._mock_redis_client = mock_redis + yield instance def test_trace_dispatch(trace_instance): @@ -273,23 +620,821 @@ def test_workflow_trace_no_app_id(mock_db, trace_instance): @patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") -def test_message_trace_success(mock_db, trace_instance): +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_uses_canonical_root_context_for_top_level_workflow( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info(message_id="message-1", workflow_run_id="workflow-run-1") + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + root_carrier = {} + root_context = object() + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value=root_carrier) as mock_ensure_root_span, + patch.object(trace_instance.propagator, "extract", return_value=root_context) as mock_extract, + ): + trace_instance.workflow_trace(info) + + mock_ensure_root_span.assert_called_once_with( + info.resolved_trace_id, + root_span_name="workflow-run-1", + root_span_attributes={ + SpanAttributes.INPUT_VALUE: safe_json_dumps(info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: "application/json", + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + }, + ) + mock_extract.assert_called_once_with(carrier=root_carrier) + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_workflow-run-1") + assert workflow_span_call.kwargs["context"] is root_context + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_uses_workflow_run_id_for_root_span_and_populates_root_inputs_outputs( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + workflow_run_inputs={"prompt": "hello"}, + workflow_run_outputs={"result": "world"}, + metadata={ + "app_id": "app1", + "app_name": "Workflow Name", + }, + workflow_run_id="workflow-run-xyz", + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + root_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow-run-xyz") + assert root_span_call.kwargs["attributes"][SpanAttributes.INPUT_VALUE] == safe_json_dumps(info.workflow_run_inputs) + assert root_span_call.kwargs["attributes"][SpanAttributes.OUTPUT_VALUE] == safe_json_dumps( + info.workflow_run_outputs + ) + assert root_span_call.kwargs["attributes"][SpanAttributes.INPUT_MIME_TYPE] == "application/json" + assert root_span_call.kwargs["attributes"][SpanAttributes.OUTPUT_MIME_TYPE] == "application/json" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_falls_back_to_dify_name_when_workflow_run_id_is_blank( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + metadata={ + "app_id": "app1", + "app_name": "", + }, + workflow_run_id="", + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + root_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="Dify") + assert root_span_call.kwargs["attributes"]["dify_trace_id"] == "" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_reuses_upstream_parent_workflow_context_when_no_parent_node_execution_id_is_available( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + }, + }, + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + parent_carrier = {} + parent_context = object() + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value=parent_carrier) as mock_ensure_root_span, + patch.object(trace_instance.propagator, "extract", return_value=parent_context) as mock_extract, + ): + trace_instance.workflow_trace(info) + + mock_ensure_root_span.assert_called_once_with( + "outer-workflow-run-1", + root_span_name="workflow-run-1", + root_span_attributes={ + SpanAttributes.INPUT_VALUE: safe_json_dumps(info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: "application/json", + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + }, + ) + mock_extract.assert_called_once_with(carrier=parent_carrier) + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_workflow-run-1") + assert workflow_span_call.kwargs["context"] is parent_context + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_uses_published_parent_node_context_for_nested_workflow( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + stored_carrier = '{"traceparent":"00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01"}' + trace_instance._mock_redis_client.get.return_value = stored_carrier + parent_context = object() + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span") as mock_ensure_root_span, + patch.object(trace_instance.propagator, "extract", return_value=parent_context) as mock_extract, + ): + trace_instance.workflow_trace(info) + + trace_instance._mock_redis_client.get.assert_called_once_with( + _phoenix_parent_span_redis_key("outer-node-execution-1") + ) + mock_ensure_root_span.assert_not_called() + mock_extract.assert_called_once_with( + carrier={"traceparent": "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01"} + ) + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_workflow-run-1") + assert workflow_span_call.kwargs["context"] is parent_context + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_raises_pending_parent_error_when_parent_node_context_is_missing( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + ) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + trace_instance._mock_redis_client.get.return_value = None + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span") as mock_ensure_root_span, + pytest.raises(PendingTraceParentContextError) as exc_info, + ): + trace_instance.workflow_trace(info) + + assert exc_info.value.parent_node_execution_id == "outer-node-execution-1" + trace_instance._mock_redis_client.get.assert_called_once_with( + _phoenix_parent_span_redis_key("outer-node-execution-1") + ) + mock_ensure_root_span.assert_not_called() + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_uses_parent_workflow_run_id_for_workflow_and_nodes_when_nested_context_is_present( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + conversation_id=None, + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + }, + }, + ) + repo = MagicMock() + node_execution = MagicMock() + node_execution.node_type = "tool" + node_execution.status = "succeeded" + node_execution.inputs = {"tool_input": "value"} + node_execution.outputs = {"tool_output": "value"} + node_execution.created_at = _dt() + node_execution.elapsed_time = 1.0 + node_execution.process_data = {} + node_execution.metadata = {} + node_execution.title = "Tool node" + node_execution.id = "node-1" + node_execution.error = None + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_r1") + node_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool_Tool node") + + assert workflow_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "outer-workflow-run-1" + assert node_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "outer-workflow-run-1" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_falls_back_to_node_type_when_node_title_is_blank( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + node_execution = _make_node_execution( + id="node-execution-1", + node_execution_id="node-execution-1", + node_id="node-1", + node_type="tool", + title=" ", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + node_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool") + assert node_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "r1" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_prefers_workflow_graph_node_title_over_execution_title( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + workflow_data={ + "graph": { + "nodes": [ + { + "id": "nested-tool-node", + "data": { + "type": "tool", + "title": "nested workflow tool", + }, + } + ] + } + } + ) + repo = MagicMock() + node_execution = _make_node_execution( + id="node-execution-1", + node_execution_id="node-execution-1", + node_id="nested-tool-node", + node_type="tool", + title="2", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + with patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()): + trace_instance.workflow_trace(info) + + node_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool_nested workflow tool") + assert node_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "r1" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_keeps_nested_conversation_session_while_reusing_parent_root_context( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info( + conversation_id="conversation-1", + message_id="message-1", + workflow_run_id="workflow-run-1", + metadata={ + "app_id": "app1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + }, + }, + ) + repo = MagicMock() + node_execution = _make_node_execution( + id="node-execution-1", + node_execution_id="node-execution-1", + node_id="node-1", + node_type="tool", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + parent_carrier = {} + parent_context = object() + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value=parent_carrier) as mock_ensure_root_span, + patch.object(trace_instance.propagator, "extract", return_value=parent_context) as mock_extract, + ): + trace_instance.workflow_trace(info) + + mock_ensure_root_span.assert_called_once_with( + "outer-workflow-run-1", + root_span_name="workflow-run-1", + root_span_attributes={ + SpanAttributes.INPUT_VALUE: safe_json_dumps(info.workflow_run_inputs), + SpanAttributes.INPUT_MIME_TYPE: "application/json", + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(info.workflow_run_outputs), + SpanAttributes.OUTPUT_MIME_TYPE: "application/json", + }, + ) + mock_extract.assert_called_once_with(carrier=parent_carrier) + workflow_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="workflow_workflow-run-1") + node_span_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool_Node") + assert workflow_span_call.kwargs["context"] is parent_context + assert workflow_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "conversation-1" + assert node_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "conversation-1" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_publishes_tool_node_parent_span_context_to_redis( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + node_execution = _make_node_execution( + id="tool-execution-1", + node_execution_id="tool-execution-1", + node_id="tool-node-1", + node_type="tool", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + tool_span = MagicMock(name="tool-span") + tool_span._context_label = "tool" + trace_instance.tracer.start_span.side_effect = [workflow_span, tool_span] + + def inject_side_effect(carrier): + carrier["traceparent"] = "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01" + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch.object(trace_instance.propagator, "inject", side_effect=inject_side_effect) as mock_inject, + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + mock_inject.assert_called_once() + trace_instance._mock_redis_client.setex.assert_called_once_with( + _phoenix_parent_span_redis_key("tool-execution-1"), + 300, + '{"traceparent": "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01"}', + ) + + +@pytest.mark.parametrize( + ("failing_step", "expected_message"), + [ + ("inject", "inject failed"), + ("publish", "publish failed"), + ], +) +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_cleans_up_tool_span_when_parent_context_publish_fails( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, + failing_step, + expected_message, +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + node_execution = _make_node_execution( + id="tool-execution-1", + node_execution_id="tool-execution-1", + node_id="tool-node-1", + node_type="tool", + ) + repo.get_by_workflow_execution.return_value = [node_execution] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + tool_span = MagicMock(name="tool-span") + tool_span._context_label = "tool" + trace_instance.tracer.start_span.side_effect = [workflow_span, tool_span] + + inject_side_effect = None + if failing_step == "inject": + inject_side_effect = RuntimeError(expected_message) + else: + trace_instance._mock_redis_client.setex.side_effect = RuntimeError(expected_message) + + def inject_side_effect(carrier): + carrier["traceparent"] = "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa-bbbbbbbbbbbbbbbb-01" + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch.object(trace_instance.propagator, "inject", side_effect=inject_side_effect), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + pytest.raises(RuntimeError, match=expected_message), + ): + trace_instance.workflow_trace(info) + + tool_span.end.assert_called_once() + workflow_span.end.assert_called_once() + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_parents_serial_nodes_to_resolved_predecessor_span( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + second_node = _make_node_execution( + id="node-execution-2", + node_execution_id="node-execution-2", + node_id="node-2", + node_type="llm", + predecessor_node_id="node-1", + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + }, + ) + first_node = _make_node_execution( + id="node-execution-1", + node_execution_id="node-execution-1", + node_id="node-1", + node_type="tool", + ) + repo.get_by_workflow_execution.return_value = [second_node, first_node] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + first_node_span = MagicMock(name="first-node-span") + first_node_span._context_label = "node-1" + second_node_span = MagicMock(name="second-node-span") + second_node_span._context_label = "node-2" + trace_instance.tracer.start_span.side_effect = [workflow_span, first_node_span, second_node_span] + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + first_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="tool_Node") + second_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="llm_Node") + assert first_node_call.kwargs["context"] == "context:workflow" + assert second_node_call.kwargs["context"] == "context:node-1" + + +@pytest.mark.parametrize( + ("enclosing_node_type", "structured_field"), + [ + ("loop", "loop_id"), + ("iteration", "iteration_id"), + ], +) +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_parents_structured_start_nodes_to_enclosing_structure_span( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, + enclosing_node_type, + structured_field, +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + enclosing_node = _make_node_execution( + id=f"{enclosing_node_type}-execution-1", + node_execution_id=f"{enclosing_node_type}-execution-1", + node_id=f"{enclosing_node_type}-node-1", + node_type=enclosing_node_type, + ) + structured_kwargs = {structured_field: f"{enclosing_node_type}-node-1"} + start_node = _make_node_execution( + id="start-execution-1", + node_execution_id="start-execution-1", + node_id="start-node-1", + node_type="start", + **structured_kwargs, + ) + repo.get_by_workflow_execution.return_value = [start_node, enclosing_node] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + enclosing_node_span = MagicMock(name="enclosing-node-span") + enclosing_node_span._context_label = enclosing_node_type + start_node_span = MagicMock(name="start-node-span") + start_node_span._context_label = "start" + trace_instance.tracer.start_span.side_effect = [workflow_span, enclosing_node_span, start_node_span] + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + start_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="start_Node") + assert start_node_call.kwargs["context"] == f"context:{enclosing_node_type}" + + +@pytest.mark.parametrize( + ("enclosing_node_type", "structured_field"), + [ + ("loop", "loop_id"), + ("iteration", "iteration_id"), + ], +) +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_keeps_duplicate_body_node_children_under_enclosing_structure( + mock_sessionmaker, + mock_repo_factory, + mock_db, + trace_instance, + enclosing_node_type, + structured_field, +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + enclosing_node = _make_node_execution( + id=f"{enclosing_node_type}-execution-1", + node_execution_id=f"{enclosing_node_type}-execution-1", + node_id=f"{enclosing_node_type}-node-1", + node_type=enclosing_node_type, + ) + structured_kwargs = {structured_field: f"{enclosing_node_type}-node-1"} + repeated_body_node_1 = _make_node_execution( + id="body-execution-1", + node_execution_id="body-execution-1", + node_id="body-node-1", + node_type="tool", + **structured_kwargs, + ) + repeated_body_node_2 = _make_node_execution( + id="body-execution-2", + node_execution_id="body-execution-2", + node_id="body-node-1", + node_type="tool", + **structured_kwargs, + ) + child_node = _make_node_execution( + id="child-execution-1", + node_execution_id="child-execution-1", + node_id="child-node-1", + node_type="llm", + predecessor_node_id="body-node-1", + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + }, + **structured_kwargs, + ) + repo.get_by_workflow_execution.return_value = [ + child_node, + repeated_body_node_1, + repeated_body_node_2, + enclosing_node, + ] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + enclosing_node_span = MagicMock(name="enclosing-node-span") + enclosing_node_span._context_label = enclosing_node_type + child_node_span = MagicMock(name="child-node-span") + child_node_span._context_label = "child" + repeated_body_node_1_span = MagicMock(name="repeated-body-node-1-span") + repeated_body_node_1_span._context_label = "body-1" + repeated_body_node_2_span = MagicMock(name="repeated-body-node-2-span") + repeated_body_node_2_span._context_label = "body-2" + trace_instance.tracer.start_span.side_effect = [ + workflow_span, + enclosing_node_span, + child_node_span, + repeated_body_node_1_span, + repeated_body_node_2_span, + ] + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + child_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="llm_Node") + assert child_node_call.kwargs["context"] == f"context:{enclosing_node_type}" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.sessionmaker") +def test_workflow_trace_falls_back_to_workflow_span_for_parallel_like_ambiguous_predecessors( + mock_sessionmaker, mock_repo_factory, mock_db, trace_instance +): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + child_node = _make_node_execution( + id="child-execution-1", + node_execution_id="child-execution-1", + node_id="child-node-1", + node_type="llm", + predecessor_node_id="parallel-node-1", + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + }, + ) + first_parallel_node = _make_node_execution( + id="parallel-execution-1", + node_execution_id="parallel-execution-1", + node_id="parallel-node-1", + node_type="tool", + parallel_id="parallel-1", + ) + second_parallel_node = _make_node_execution( + id="parallel-execution-2", + node_execution_id="parallel-execution-2", + node_id="parallel-node-1", + node_type="tool", + parallel_id="parallel-2", + ) + repo.get_by_workflow_execution.return_value = [child_node, first_parallel_node, second_parallel_node] + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + workflow_span = MagicMock(name="workflow-span") + workflow_span._context_label = "workflow" + child_node_span = MagicMock(name="child-node-span") + child_node_span._context_label = "child" + first_parallel_node_span = MagicMock(name="first-parallel-node-span") + first_parallel_node_span._context_label = "parallel-1" + second_parallel_node_span = MagicMock(name="second-parallel-node-span") + second_parallel_node_span._context_label = "parallel-2" + trace_instance.tracer.start_span.side_effect = [ + workflow_span, + child_node_span, + first_parallel_node_span, + second_parallel_node_span, + ] + + with ( + patch.object(trace_instance, "get_service_account_with_tenant", return_value=MagicMock()), + patch.object(trace_instance, "ensure_root_span", return_value={}), + patch.object(trace_instance.propagator, "extract", return_value="root-context"), + patch( + "dify_trace_arize_phoenix.arize_phoenix_trace.set_span_in_context", + side_effect=lambda span: f"context:{span._context_label}", + ), + ): + trace_instance.workflow_trace(info) + + child_node_call = _get_start_span_call(trace_instance.tracer.start_span, span_name="llm_Node") + assert child_node_call.kwargs["context"] == "context:workflow" + + +@patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") +def test_message_trace_keeps_conversation_id_as_session(mock_db, trace_instance): mock_db.engine = MagicMock() info = _make_message_info() info.message_data = MagicMock() - info.message_data.from_account_id = "acc1" + info.message_data.conversation_id = "conversation-2" + info.message_data.from_account_id = "acc2" info.message_data.from_end_user_id = None - info.message_data.query = "q" - info.message_data.answer = "a" - info.message_data.status = "s" - info.message_data.model_id = "m" - info.message_data.model_provider = "p" + info.message_data.query = "q2" + info.message_data.answer = "a2" + info.message_data.status = "s2" + info.message_data.model_id = "m2" + info.message_data.model_provider = "p2" info.message_data.message_metadata = "{}" info.message_data.error = None info.error = None + root_span = MagicMock() + message_span = MagicMock() + llm_span = MagicMock() + trace_instance.tracer.start_span.side_effect = [root_span, message_span, llm_span] + trace_instance.message_trace(info) - assert trace_instance.tracer.start_span.call_count >= 1 + + message_span_call = _get_start_span_call( + trace_instance.tracer.start_span, span_name=TraceTaskName.MESSAGE_TRACE.value + ) + assert message_span_call.kwargs["attributes"][SpanAttributes.SESSION_ID] == "conversation-2" @patch("dify_trace_arize_phoenix.arize_phoenix_trace.db") @@ -397,3 +1542,30 @@ def test_api_check_success(trace_instance): def test_ensure_root_span_basic(trace_instance): trace_instance.ensure_root_span("tid") assert "tid" in trace_instance.dify_trace_ids + + +def test_ensure_root_span_uses_custom_name_and_attributes(trace_instance): + root_attributes = { + SpanAttributes.INPUT_VALUE: '{"input":"value"}', + SpanAttributes.OUTPUT_VALUE: '{"output":"value"}', + } + + trace_instance.ensure_root_span("tid", root_span_name="Workflow Name", root_span_attributes=root_attributes) + + trace_instance.tracer.start_span.assert_called_once_with( + name="Workflow Name", + attributes={ + SpanAttributes.OPENINFERENCE_SPAN_KIND: "CHAIN", + "dify_project_name": "p", + "dify_trace_id": "tid", + SpanAttributes.INPUT_VALUE: '{"input":"value"}', + SpanAttributes.OUTPUT_VALUE: '{"output":"value"}', + }, + ) + + +def test_ensure_root_span_falls_back_to_dify_name_when_custom_name_is_blank(trace_instance): + trace_instance.ensure_root_span("tid", root_span_name=" ") + + trace_instance.tracer.start_span.assert_called_once() + assert trace_instance.tracer.start_span.call_args.kwargs["name"] == "Dify" diff --git a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py b/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py deleted file mode 100644 index a01c63ae61..0000000000 --- a/api/providers/trace/trace-arize-phoenix/tests/unit_tests/test_arize_phoenix_trace.py +++ /dev/null @@ -1,36 +0,0 @@ -from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from openinference.semconv.trace import OpenInferenceSpanKindValues - -from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes - - -class TestGetNodeSpanKind: - """Tests for _get_node_span_kind helper.""" - - def test_all_node_types_are_mapped_correctly(self): - """Ensure every built-in node type is mapped to the correct span kind.""" - # Mappings for node types that have a specialised span kind. - special_mappings = { - BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM, - BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER, - BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL, - BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT, - } - - # Test that every built-in node type is mapped to the correct span kind. - # Node types not in `special_mappings` should default to CHAIN. - for node_type in BUILT_IN_NODE_TYPES: - expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN) - actual_span_kind = _get_node_span_kind(node_type) - assert actual_span_kind == expected_span_kind, ( - f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected." - ) - - def test_unknown_string_defaults_to_chain(self): - """An unrecognised node type string should still return CHAIN.""" - assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN - - def test_stale_dataset_retrieval_not_in_mapping(self): - """The old 'dataset_retrieval' string was never a valid NodeType value; - make sure it is not present in the mapping dictionary.""" - assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py index f73ba01c8b..be9d64ae01 100644 --- a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py @@ -65,35 +65,18 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): } file_list = values.get("file_list", []) if isinstance(v, str): - if field_name == "inputs": - return { - "messages": { - "role": "user", - "content": v, - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } - elif field_name == "outputs": - return { - "choices": { - "role": "ai", - "content": v, - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } - elif isinstance(v, list): - data = {} - if len(v) > 0 and isinstance(v[0], dict): - # rename text to content - v = replace_text_with_content(data=v) - if field_name == "inputs": - data = { - "messages": v, + match field_name: + case "inputs": + return { + "messages": { + "role": "user", + "content": v, + "usage_metadata": usage_metadata, + "file_list": file_list, + }, } - elif field_name == "outputs": - data = { + case "outputs": + return { "choices": { "role": "ai", "content": v, @@ -101,6 +84,29 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): "file_list": file_list, }, } + case _: + pass + elif isinstance(v, list): + data = {} + if len(v) > 0 and isinstance(v[0], dict): + # rename text to content + v = replace_text_with_content(data=v) + match field_name: + case "inputs": + data = { + "messages": v, + } + case "outputs": + data = { + "choices": { + "role": "ai", + "content": v, + "usage_metadata": usage_metadata, + "file_list": file_list, + }, + } + case _: + pass return data else: return { diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py index 145bd70dbc..045ec44e4e 100644 --- a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/langsmith_trace.py @@ -64,7 +64,9 @@ class LangSmithDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id + # trace_id must equal the root run's run_id (LangSmith protocol); external trace_id + # cannot be used here as it would cause HTTP 400. + trace_id = trace_info.message_id or trace_info.workflow_run_id if trace_info.start_time is None: trace_info.start_time = datetime.now() message_dotted_order = ( @@ -77,6 +79,8 @@ class LangSmithDataTrace(BaseTraceInstance): ) metadata = trace_info.metadata metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id + if trace_info.trace_id: + metadata["external_trace_id"] = trace_info.trace_id if trace_info.message_id: message_run = LangSmithRunModel( diff --git a/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py index ee59acb17e..edc4aafd87 100644 --- a/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py +++ b/api/providers/trace/trace-langsmith/tests/unit_tests/langsmith_trace/test_langsmith_trace.py @@ -208,13 +208,17 @@ def test_workflow_trace(trace_instance, monkeypatch: pytest.MonkeyPatch): assert call_args[0].id == "msg-1" assert call_args[0].name == TraceTaskName.MESSAGE_TRACE + # trace_id must equal root run's id (message_id), not the external trace_id "trace-1" + assert call_args[0].trace_id == "msg-1" assert call_args[1].id == "run-1" assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE assert call_args[1].parent_run_id == "msg-1" + assert call_args[1].trace_id == "msg-1" assert call_args[2].id == "node-llm" assert call_args[2].run_type == LangSmithRunType.llm + assert call_args[2].trace_id == "msg-1" assert call_args[3].id == "node-other" assert call_args[3].run_type == LangSmithRunType.tool @@ -604,3 +608,83 @@ def test_get_project_url_error(trace_instance): trace_instance.langsmith_client.get_run_url.side_effect = Exception("error") with pytest.raises(ValueError, match="LangSmith get run url failed: error"): trace_instance.get_project_url() + + +def _make_workflow_trace_info( + *, message_id: str | None, workflow_run_id: str, trace_id: str | None +) -> WorkflowTraceInfo: + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + return WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id=workflow_run_id, + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=0, + file_list=[], + query="q", + message_id=message_id, + conversation_id="conv-1" if message_id else None, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id=trace_id, + metadata={"app_id": "app-1"}, + workflow_app_log_id=None, + error=None, + workflow_data=workflow_data, + ) + + +def _patch_workflow_trace_deps(monkeypatch, trace_instance): + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_execution.return_value = [] + factory = MagicMock() + factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + trace_instance.add_run = MagicMock() + + +def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch): + """Chatflow with external trace_id: LangSmith trace_id must be message_id, not external.""" + trace_info = _make_workflow_trace_info( + message_id="msg-abc", + workflow_run_id="run-xyz", + trace_id="external-999", + ) + _patch_workflow_trace_deps(monkeypatch, trace_instance) + + trace_instance.workflow_trace(trace_info) + + calls = [c[0][0] for c in trace_instance.add_run.call_args_list] + # message run (root) and workflow run (child) must both use message_id as trace_id + assert calls[0].id == "msg-abc" + assert calls[0].trace_id == "msg-abc" + assert calls[1].id == "run-xyz" + assert calls[1].trace_id == "msg-abc" + # external_trace_id preserved in metadata + assert trace_info.metadata.get("external_trace_id") == "external-999" + + +def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch): + """Pure workflow (no message_id) with external trace_id: trace_id must be workflow_run_id.""" + trace_info = _make_workflow_trace_info( + message_id=None, + workflow_run_id="run-xyz", + trace_id="external-999", + ) + _patch_workflow_trace_deps(monkeypatch, trace_instance) + + trace_instance.workflow_trace(trace_info) + + calls = [c[0][0] for c in trace_instance.add_run.call_args_list] + # workflow run is the root; trace_id must equal its run_id + assert calls[0].id == "run-xyz" + assert calls[0].trace_id == "run-xyz" diff --git a/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py index 843c495d82..d6998f6672 100644 --- a/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py +++ b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py @@ -81,14 +81,15 @@ class OpenSearchConfig(BaseModel): pool_maxsize=20, ) - if self.auth_method == "basic": - logger.info("Using basic authentication for OpenSearch Vector DB") + match self.auth_method: + case AuthMethod.BASIC: + logger.info("Using basic authentication for OpenSearch Vector DB") - params["http_auth"] = (self.user, self.password) - elif self.auth_method == "aws_managed_iam": - logger.info("Using AWS managed IAM role for OpenSearch Vector DB") + params["http_auth"] = (self.user, self.password) + case AuthMethod.AWS_MANAGED_IAM: + logger.info("Using AWS managed IAM role for OpenSearch Vector DB") - params["http_auth"] = self.create_aws_managed_iam_auth() + params["http_auth"] = self.create_aws_managed_iam_auth() return params diff --git a/api/pyproject.toml b/api/pyproject.toml index ed03dc99fe..51cd7516c3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,7 +6,7 @@ requires-python = "~=3.12.0" dependencies = [ # Legacy: mature and widely deployed "bleach>=6.3.0", - "boto3>=1.43.3", + "boto3>=1.43.6", "celery>=5.6.3", "croniter>=6.2.2", "flask>=3.1.3,<4.0.0", @@ -14,8 +14,8 @@ dependencies = [ "gevent>=26.4.0", "gevent-websocket>=0.10.1", "gmpy2>=2.3.0", - "google-api-python-client>=2.195.0", - "gunicorn>=25.3.0", + "google-api-python-client>=2.196.0", + "gunicorn>=26.0.0", "psycogreen>=1.0.2", "psycopg2-binary>=2.9.12", "python-socketio>=5.13.0", @@ -30,7 +30,7 @@ dependencies = [ "flask-migrate>=4.1.0,<5.0.0", "flask-orjson>=2.0.0,<3.0.0", "flask-restx>=1.3.2,<2.0.0", - "google-cloud-aiplatform>=1.149.0,<2.0.0", + "google-cloud-aiplatform>=1.151.0,<2.0.0", "httpx[socks]>=0.28.1,<1.0.0", "opentelemetry-distro>=0.62b1,<1.0.0", "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0", @@ -43,7 +43,7 @@ dependencies = [ "resend>=2.27.0,<3.0.0", # Emerging: newer and fast-moving, use compatible pins "fastopenapi[flask]~=0.7.0", - "graphon~=0.2.2", + "graphon~=0.3.1", "httpx-sse~=0.4.0", "json-repair~=0.59.4", ] @@ -190,7 +190,7 @@ storage = [ "google-cloud-storage>=3.10.1", "opendal>=0.46.0", "oss2>=2.19.1", - "supabase>=2.29.0", + "supabase>=2.30.0", "tos>=2.9.0", ] diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 0229a1f43a..aa6b8ffc6e 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -425,7 +425,7 @@ class AppAnnotationService: return {"deleted_count": deleted_count} @classmethod - def batch_import_app_annotations(cls, app_id, file: FileStorage): + def batch_import_app_annotations(cls, app_id: str, file: FileStorage): """ Batch import annotations from CSV file with enhanced security checks. diff --git a/api/services/app_service.py b/api/services/app_service.py index a046b909b3..6716833f6c 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,9 +1,10 @@ import json import logging -from typing import Any, TypedDict, cast +from typing import Any, Literal, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination +from pydantic import BaseModel, Field from sqlalchemy import select from configs import dify_config @@ -31,39 +32,59 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t logger = logging.getLogger(__name__) +class AppListParams(BaseModel): + page: int = Field(default=1, ge=1) + limit: int = Field(default=20, ge=1, le=100) + mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = "all" + name: str | None = None + tag_ids: list[str] | None = None + is_created_by_me: bool | None = None + + +class CreateAppParams(BaseModel): + name: str = Field(min_length=1) + description: str | None = None + mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] + icon_type: str | None = None + icon: str | None = None + icon_background: str | None = None + api_rph: int = 0 + api_rpm: int = 0 + max_active_requests: int | None = None + + class AppService: - def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict[str, Any]) -> Pagination | None: + def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None: """ Get app list with pagination :param user_id: user id :param tenant_id: tenant id - :param args: request args + :param params: query parameters :return: """ filters = [App.tenant_id == tenant_id, App.is_universal == False] - if args["mode"] == "workflow": + if params.mode == "workflow": filters.append(App.mode == AppMode.WORKFLOW) - elif args["mode"] == "completion": + elif params.mode == "completion": filters.append(App.mode == AppMode.COMPLETION) - elif args["mode"] == "chat": + elif params.mode == "chat": filters.append(App.mode == AppMode.CHAT) - elif args["mode"] == "advanced-chat": + elif params.mode == "advanced-chat": filters.append(App.mode == AppMode.ADVANCED_CHAT) - elif args["mode"] == "agent-chat": + elif params.mode == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT) - if args.get("is_created_by_me", False): + if params.is_created_by_me: filters.append(App.created_by == user_id) - if args.get("name"): + if params.name: from libs.helper import escape_like_pattern - name = args["name"][:30] + name = params.name[:30] escaped_name = escape_like_pattern(name) filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\")) - # Check if tag_ids is not empty to avoid WHERE false condition - if args.get("tag_ids") and len(args["tag_ids"]) > 0: - target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) + if params.tag_ids and len(params.tag_ids) > 0: + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids) if target_ids and len(target_ids) > 0: filters.append(App.id.in_(target_ids)) else: @@ -71,21 +92,21 @@ class AppService: app_models = db.paginate( sa.select(App).where(*filters).order_by(App.created_at.desc()), - page=args["page"], - per_page=args["limit"], + page=params.page, + per_page=params.limit, error_out=False, ) return app_models - def create_app(self, tenant_id: str, args: dict[str, Any], account: Account) -> App: + def create_app(self, tenant_id: str, params: CreateAppParams, account: Account) -> App: """ Create app :param tenant_id: tenant id - :param args: request args + :param params: app creation parameters :param account: Account instance """ - app_mode = AppMode.value_of(args["mode"]) + app_mode = AppMode.value_of(params.mode) app_template = default_app_templates[app_mode] # get model config @@ -143,15 +164,16 @@ class AppService: default_model_config["model"] = json.dumps(default_model_dict) app = App(**app_template["app"]) - app.name = args["name"] - app.description = args.get("description", "") - app.mode = args["mode"] - app.icon_type = args.get("icon_type", "emoji") - app.icon = args["icon"] - app.icon_background = args["icon_background"] + app.name = params.name + app.description = params.description or "" + app.mode = app_mode + app.icon_type = IconType(params.icon_type) if params.icon_type else IconType.EMOJI + app.icon = params.icon + app.icon_background = params.icon_background app.tenant_id = tenant_id - app.api_rph = args.get("api_rph", 0) - app.api_rpm = args.get("api_rpm", 0) + app.api_rph = params.api_rph + app.api_rpm = params.api_rpm + app.max_active_requests = params.max_active_requests app.created_by = account.id app.updated_by = account.id diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 60948e652b..c80b2f43fd 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -54,7 +54,7 @@ class AudioService: if extension not in [f"audio/{ext}" for ext in AUDIO_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() - file_content = file.read() + file_content = file.stream.read() file_size = len(file_content) if file_size > FILE_SIZE_LIMIT: diff --git a/api/services/credit_pool_service.py b/api/services/credit_pool_service.py index 2d210db121..1f419d7a5b 100644 --- a/api/services/credit_pool_service.py +++ b/api/services/credit_pool_service.py @@ -1,7 +1,7 @@ import logging -from sqlalchemy import select, update -from sqlalchemy.orm import sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.errors.error import QuotaExceededError @@ -13,6 +13,18 @@ logger = logging.getLogger(__name__) class CreditPoolService: + @staticmethod + def _get_locked_pool(session: Session, tenant_id: str, pool_type: str) -> TenantCreditPool | None: + return session.scalar( + select(TenantCreditPool) + .where( + TenantCreditPool.tenant_id == tenant_id, + TenantCreditPool.pool_type == pool_type, + ) + .limit(1) + .with_for_update() + ) + @classmethod def create_default_pool(cls, tenant_id: str) -> TenantCreditPool: """create default credit pool for new tenant""" @@ -59,31 +71,57 @@ class CreditPoolService: credits_required: int, pool_type: str = "trial", ) -> int: - """check and deduct credits, returns actual credits deducted""" - - pool = cls.get_pool(tenant_id, pool_type) - if not pool: - raise QuotaExceededError("Credit pool not found") - - if pool.remaining_credits <= 0: - raise QuotaExceededError("No credits remaining") - - # deduct all remaining credits if less than required - actual_credits = min(credits_required, pool.remaining_credits) + """Deduct exactly the requested credits or raise without mutating the pool.""" + if credits_required <= 0: + return 0 try: - with sessionmaker(db.engine).begin() as session: - stmt = ( - update(TenantCreditPool) - .where( - TenantCreditPool.tenant_id == tenant_id, - TenantCreditPool.pool_type == pool_type, - ) - .values(quota_used=TenantCreditPool.quota_used + actual_credits) - ) - session.execute(stmt) + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type) + if not pool: + raise QuotaExceededError("Credit pool not found") + + remaining_credits = pool.remaining_credits + if remaining_credits <= 0: + raise QuotaExceededError("No credits remaining") + if remaining_credits < credits_required: + raise QuotaExceededError("Insufficient credits remaining") + + pool.quota_used += credits_required + except QuotaExceededError: + raise except Exception: logger.exception("Failed to deduct credits for tenant %s", tenant_id) raise QuotaExceededError("Failed to deduct credits") - return actual_credits + return credits_required + + @classmethod + def deduct_credits_capped( + cls, + tenant_id: str, + credits_required: int, + pool_type: str = "trial", + ) -> int: + """Deduct up to the available balance and return the actual deducted credits.""" + if credits_required <= 0: + return 0 + + try: + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + pool = cls._get_locked_pool(session=session, tenant_id=tenant_id, pool_type=pool_type) + if not pool: + logger.warning("Credit pool not found, tenant_id=%s, pool_type=%s", tenant_id, pool_type) + return 0 + + deducted_credits = min(credits_required, pool.remaining_credits) + if deducted_credits <= 0: + return 0 + + pool.quota_used += deducted_credits + return deducted_credits + except QuotaExceededError: + raise + except Exception: + logger.exception("Failed to deduct capped credits for tenant %s", tenant_id) + raise QuotaExceededError("Failed to deduct credits") diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index eef38f1ce2..4f5a95dcde 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,9 +7,10 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal, TypedDict, cast +from typing import Annotated, Any, Literal, TypedDict, cast import sqlalchemy as sa +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator from redis.exceptions import LockNotOwnedError from sqlalchemy import delete, exists, func, select, update from sqlalchemy.orm import Session, sessionmaker @@ -108,7 +109,7 @@ logger = logging.getLogger(__name__) class ProcessRulesDict(TypedDict): - mode: str + mode: ProcessRuleMode rules: dict[str, Any] @@ -117,6 +118,86 @@ class AutoDisableLogsDict(TypedDict): count: int +class _EstimatePreProcessingRule(BaseModel): + id: str = Field(min_length=1) + enabled: bool + + @field_validator("id") + @classmethod + def _validate_id(cls, v: str) -> str: + if v not in DatasetProcessRule.PRE_PROCESSING_RULES: + raise ValueError("Process rule pre_processing_rules id is invalid") + return v + + +class _EstimateSegmentation(BaseModel): + separator: str = Field(min_length=1) + max_tokens: int = Field(gt=0) + + +class _EstimateRules(BaseModel): + pre_processing_rules: list[_EstimatePreProcessingRule] + segmentation: _EstimateSegmentation + + @field_validator("pre_processing_rules") + @classmethod + def _deduplicate(cls, v: list[_EstimatePreProcessingRule]) -> list[_EstimatePreProcessingRule]: + seen: dict[str, _EstimatePreProcessingRule] = {} + for rule in v: + seen[rule.id] = rule + return list(seen.values()) + + +class _SummaryIndexSettingDisabled(BaseModel): + enable: Literal[False] = False + + +class _SummaryIndexSettingEnabled(BaseModel): + enable: Literal[True] + model_name: str = Field(min_length=1) + model_provider_name: str = Field(min_length=1) + + +_SummaryIndexSetting = Annotated[ + _SummaryIndexSettingDisabled | _SummaryIndexSettingEnabled, + Field(discriminator="enable"), +] + + +class _AutomaticProcessRule(BaseModel): + model_config = ConfigDict(extra="allow") + + mode: Literal[ProcessRuleMode.AUTOMATIC] + summary_index_setting: _SummaryIndexSetting | None = None + + +class _CustomProcessRule(BaseModel): + model_config = ConfigDict(extra="allow") + + mode: Literal[ProcessRuleMode.CUSTOM] + rules: _EstimateRules + summary_index_setting: _SummaryIndexSetting | None = None + + +class _HierarchicalProcessRule(BaseModel): + model_config = ConfigDict(extra="allow") + + mode: Literal[ProcessRuleMode.HIERARCHICAL] + rules: _EstimateRules + summary_index_setting: _SummaryIndexSetting | None = None + + +_EstimateProcessRule = Annotated[ + _AutomaticProcessRule | _CustomProcessRule | _HierarchicalProcessRule, + Field(discriminator="mode"), +] + + +class _EstimateArgs(BaseModel): + info_list: dict[str, Any] + process_rule: _EstimateProcessRule + + class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): @@ -204,7 +285,7 @@ class DatasetService: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict or {} else: - mode = str(DocumentService.DEFAULT_RULES["mode"]) + mode = ProcessRuleMode(DocumentService.DEFAULT_RULES["mode"]) rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {}) return {"mode": mode, "rules": rules} @@ -1984,7 +2065,7 @@ class DocumentService: if process_rule.rules: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule.mode, + mode=ProcessRuleMode(process_rule.mode), rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) @@ -1995,7 +2076,7 @@ class DocumentService: elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule.mode, + mode=ProcessRuleMode.AUTOMATIC, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -2572,14 +2653,14 @@ class DocumentService: if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule.mode, + mode=ProcessRuleMode(process_rule.mode), rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule.mode, + mode=ProcessRuleMode.AUTOMATIC, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -2851,94 +2932,16 @@ class DocumentService: @classmethod def estimate_args_validate(cls, args: dict[str, Any]): - if "info_list" not in args or not args["info_list"]: - raise ValueError("Data source info is required") - - if not isinstance(args["info_list"], dict): - raise ValueError("Data info is invalid") - - if "process_rule" not in args or not args["process_rule"]: - raise ValueError("Process rule is required") - - if not isinstance(args["process_rule"], dict): - raise ValueError("Process rule is invalid") - - if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: - raise ValueError("Process rule mode is required") - - if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: - raise ValueError("Process rule mode is invalid") - - if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC: - args["process_rule"]["rules"] = {} - else: - if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: - raise ValueError("Process rule rules is required") - - if not isinstance(args["process_rule"]["rules"], dict): - raise ValueError("Process rule rules is invalid") - - if ( - "pre_processing_rules" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["pre_processing_rules"] is None - ): - raise ValueError("Process rule pre_processing_rules is required") - - if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): - raise ValueError("Process rule pre_processing_rules is invalid") - - unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: - if "id" not in pre_processing_rule or not pre_processing_rule["id"]: - raise ValueError("Process rule pre_processing_rules id is required") - - if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: - raise ValueError("Process rule pre_processing_rules id is invalid") - - if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: - raise ValueError("Process rule pre_processing_rules enabled is required") - - if not isinstance(pre_processing_rule["enabled"], bool): - raise ValueError("Process rule pre_processing_rules enabled is invalid") - - unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - - args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - - if ( - "segmentation" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["segmentation"] is None - ): - raise ValueError("Process rule segmentation is required") - - if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): - raise ValueError("Process rule segmentation is invalid") - - if ( - "separator" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["separator"] - ): - raise ValueError("Process rule segmentation separator is required") - - if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): - raise ValueError("Process rule segmentation separator is invalid") - - if ( - "max_tokens" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] - ): - raise ValueError("Process rule segmentation max_tokens is required") - - if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): - raise ValueError("Process rule segmentation max_tokens is invalid") - - # valid summary index setting - summary_index_setting = args["process_rule"].get("summary_index_setting") - if summary_index_setting and summary_index_setting.get("enable"): - if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]: - raise ValueError("Summary index model name is required") - if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]: - raise ValueError("Summary index model provider name is required") + try: + validated = _EstimateArgs.model_validate(args) + except ValidationError as e: + first = e.errors()[0] + original = first.get("ctx", {}).get("error") + raise ValueError(str(original) if isinstance(original, ValueError) else first["msg"]) from e + process_rule_dict = validated.process_rule.model_dump(exclude_none=True) + if validated.process_rule.mode == ProcessRuleMode.AUTOMATIC: + process_rule_dict["rules"] = {} + args["process_rule"] = process_rule_dict @staticmethod def batch_update_document_status( diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 16d2597963..a88614b251 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -166,7 +166,7 @@ class SystemFeatureModel(BaseModel): enable_email_code_login: bool = False enable_email_password_login: bool = True enable_social_oauth_login: bool = False - enable_collaboration_mode: bool = False + enable_collaboration_mode: bool = True is_allow_register: bool = False is_allow_create_workspace: bool = False is_email_setup: bool = False diff --git a/api/services/trigger/trigger_subscription_builder_service.py b/api/services/trigger/trigger_subscription_builder_service.py index 889717df72..cff735b39d 100644 --- a/api/services/trigger/trigger_subscription_builder_service.py +++ b/api/services/trigger/trigger_subscription_builder_service.py @@ -121,9 +121,7 @@ class TriggerSubscriptionBuilderService: if not subscription_builder.name: raise ValueError("Subscription builder name is required") - credential_type = CredentialType.of( - subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value - ) + credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED) if credential_type == CredentialType.UNAUTHORIZED: # manually create TriggerProviderService.add_trigger_subscription( @@ -321,9 +319,7 @@ class TriggerSubscriptionBuilderService: raise ValueError("Subscription builder name is required") # Build - credential_type = CredentialType.of( - subscription_builder.credential_type or CredentialType.UNAUTHORIZED.value - ) + credential_type = CredentialType.of(subscription_builder.credential_type or CredentialType.UNAUTHORIZED) if credential_type == CredentialType.UNAUTHORIZED: # manually create TriggerProviderService.add_trigger_subscription( diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 5d99900a04..592f678421 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -402,7 +402,7 @@ class WebhookService: for name, file in files.items(): if file and file.filename: try: - file_content = file.read() + file_content = file.stream.read() mimetype = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream" file_obj = cls._create_file_from_binary(file_content, mimetype, webhook_trigger) processed_files[name] = file_obj.to_dict() diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index b71dbcad80..af25edd2bd 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -157,8 +157,8 @@ class DraftVarLoader(VariableLoader): # This approach reduces loading time by querying external systems concurrently. with ThreadPoolExecutor(max_workers=10) as executor: offloaded_variables = executor.map(self._load_offloaded_variable, offloaded_draft_vars) - for selector, variable in offloaded_variables: - variable_by_selector[selector] = variable + for selector, offloaded_variable in offloaded_variables: + variable_by_selector[selector] = offloaded_variable return list(variable_by_selector.values()) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 68cf74ec9f..c9e6ed4c4d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1302,8 +1302,13 @@ class WorkflowService: ) rendered_content = node.render_form_content_before_submission() + selected_action = next( + (user_action for user_action in node_data.user_actions if user_action.id == action), + None, + ) outputs: dict[str, Any] = dict(normalized_form_inputs) outputs["__action_id"] = action + outputs["__action_value"] = selected_action.title if selected_action else "" outputs["__rendered_content"] = node.render_form_content_with_outputs( rendered_content, outputs, @@ -1497,7 +1502,7 @@ class WorkflowService: adapt_human_input_node_data_for_graph(node_config["data"])) node = HumanInputNode( node_id=node_config["id"], - config=node_data, + data=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=DifyHumanInputNodeRuntime(run_context), diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index c95b8db078..49fe68ad7e 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,11 +1,31 @@ +""" +Celery task for asynchronous ops trace dispatch. + +Trace providers may report explicitly retryable dispatch failures through the +core retryable exception contract. The task preserves the payload file only +when Celery accepts the retry request; successful dispatches and terminal +failures clean up the stored payload. + +One concrete producer today is Phoenix nested workflow tracing. The outer +workflow tool span publishes a restorable parent span context asynchronously, +while the nested workflow trace may be picked up by Celery first. In that +ordering window, the provider raises a retryable core exception instead of +dropping the trace or emitting it under the wrong parent. The task intentionally +does not know that the provider is Phoenix; it only honors the core retryable +dispatch contract. +""" + import json import logging from celery import shared_task +from celery.exceptions import Retry from flask import current_app +from configs import dify_config from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY from core.ops.entities.trace_entity import trace_info_info_map +from core.ops.exceptions import RetryableTraceDispatchError from core.rag.models.document import Document from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -14,9 +34,17 @@ from models.workflow import WorkflowRun logger = logging.getLogger(__name__) +_RETRYABLE_TRACE_DISPATCH_LIMIT = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES +_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS -@shared_task(queue="ops_trace") -def process_trace_tasks(file_info): + +@shared_task( + queue="ops_trace", + bind=True, + max_retries=_RETRYABLE_TRACE_DISPATCH_LIMIT, + default_retry_delay=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS, +) +def process_trace_tasks(self, file_info): """ Async process trace tasks Usage: process_trace_tasks.delay(tasks_data) @@ -29,6 +57,7 @@ def process_trace_tasks(file_info): file_data = json.loads(storage.load(file_path)) trace_info = file_data.get("trace_info") trace_info_type = file_data.get("trace_info_type") + enterprise_trace_dispatched = bool(file_data.get("_enterprise_trace_dispatched")) trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) if trace_info.get("message_data"): @@ -38,6 +67,8 @@ def process_trace_tasks(file_info): if trace_info.get("documents"): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] + should_delete_file = True + try: trace_type = trace_info_info_map.get(trace_info_type) if trace_type: @@ -45,30 +76,66 @@ def process_trace_tasks(file_info): from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled - if is_ee_telemetry_enabled(): + if is_ee_telemetry_enabled() and not enterprise_trace_dispatched: from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace try: EnterpriseOtelTrace().trace(trace_info) except Exception: logger.exception("Enterprise trace failed for app_id: %s", app_id) + else: + file_data["_enterprise_trace_dispatched"] = True + enterprise_trace_dispatched = True if trace_instance: with current_app.app_context(): trace_instance.trace(trace_info) logger.info("Processing trace tasks success, app_id: %s", app_id) + except RetryableTraceDispatchError as e: + # Retryable dispatch failures represent a transient provider-side + # ordering gap, not corrupt payload data. Keep the payload only after + # Celery accepts the retry request; otherwise this attempt becomes a + # terminal failure and the stored file is cleaned up in `finally`. + # + # Enterprise telemetry runs before provider dispatch. If it already ran + # and provider dispatch asks for a retry, persist that private flag so + # the next attempt does not emit the same enterprise trace twice. + if self.request.retries >= _RETRYABLE_TRACE_DISPATCH_LIMIT: + logger.exception("Retryable trace dispatch budget exhausted, app_id: %s", app_id) + failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" + redis_client.incr(failed_key) + else: + logger.warning( + "Retryable trace dispatch failure, scheduling retry %s/%s for app_id %s: %s", + self.request.retries + 1, + _RETRYABLE_TRACE_DISPATCH_LIMIT, + app_id, + e, + ) + try: + if enterprise_trace_dispatched: + storage.save(file_path, json.dumps(file_data).encode("utf-8")) + raise self.retry(exc=e, countdown=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS) + except Retry: + should_delete_file = False + raise + except Exception: + logger.exception("Failed to schedule trace dispatch retry, app_id: %s", app_id) + failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" + redis_client.incr(failed_key) except Exception as e: logger.exception("Processing trace tasks failed, app_id: %s", app_id) failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" redis_client.incr(failed_key) finally: - try: - storage.delete(file_path) - except Exception as e: - logger.warning( - "Failed to delete trace file %s for app_id %s: %s", - file_path, - app_id, - e, - ) + if should_delete_file: + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 2c1e667c58..b9f09ccadd 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -73,7 +73,7 @@ def test_node_integration_minimal_stream(mocker: MockerFixture): node = DatasourceNode( node_id="n", - config=DatasourceNodeData( + data=DatasourceNodeData( type="datasource", version="1", title="Datasource", diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index a9a2617bae..a77fe5970a 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from graphon.model_runtime.entities.model_entities import ModelType from models.provider import ProviderType @@ -15,8 +15,9 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") - model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) + model_assembly = create_plugin_model_assembly(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") + model_provider_factory = model_assembly.model_provider_factory + model_type_instance = model_assembly.create_model_type_instance(provider=provider, model_type=ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index aaa6092993..9345113aa3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -45,7 +45,7 @@ def init_code_node(code_config: dict): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -66,7 +66,7 @@ def init_code_node(code_config: dict): node = CodeNode( node_id=str(uuid.uuid4()), - config=CodeNodeData.model_validate(code_config["data"]), + data=CodeNodeData.model_validate(code_config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, code_executor=node_factory._code_executor, diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index b9f7b9575b..7cd7f50b77 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -55,7 +55,7 @@ def init_http_node(config: dict): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -76,7 +76,7 @@ def init_http_node(config: dict): node = HttpRequestNode( node_id=str(uuid.uuid4()), - config=HttpRequestNodeData.model_validate(config["data"]), + data=HttpRequestNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, @@ -204,7 +204,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from graphon.runtime import VariablePool # Create variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="test", files=[]), user_inputs={}, environment_variables=[], @@ -702,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock): ) # Create independent variable pool for this test only - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -724,7 +724,7 @@ def test_nested_object_variable_selector(setup_http_mock): node = HttpRequestNode( node_id=str(uuid.uuid4()), - config=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), + data=HttpRequestNodeData.model_validate(graph_config["nodes"][1]["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 3eead70163..5b7790f6f4 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -53,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode: ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="aaa", app_id=app_id, @@ -77,7 +77,7 @@ def init_llm_node(config: dict) -> LLMNode: node = LLMNode( node_id=str(uuid.uuid4()), - config=LLMNodeData.model_validate(config["data"]), + data=LLMNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), @@ -91,7 +91,11 @@ def init_llm_node(config: dict) -> LLMNode: return node -def test_execute_llm(): +def _mock_db_session_close(monkeypatch) -> None: + monkeypatch.setattr(db.session, "close", MagicMock()) + + +def test_execute_llm(monkeypatch): node = init_llm_node( config={ "id": "llm", @@ -118,7 +122,7 @@ def test_execute_llm(): }, ) - db.session.close = MagicMock() + _mock_db_session_close(monkeypatch) def build_mock_model_instance() -> MagicMock: from decimal import Decimal @@ -195,7 +199,7 @@ def test_execute_llm(): assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0 -def test_execute_llm_with_jinja2(): +def test_execute_llm_with_jinja2(monkeypatch): """ Test execute LLM node with jinja2 """ @@ -233,8 +237,7 @@ def test_execute_llm_with_jinja2(): }, ) - # Mock db.session.close() - db.session.close = MagicMock() + _mock_db_session_close(monkeypatch) def build_mock_model_instance() -> MagicMock: from decimal import Decimal diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index f2eabb86c3..fc230a2a68 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -56,7 +56,7 @@ def init_parameter_extractor_node(config: dict, memory=None): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" ), @@ -71,7 +71,7 @@ def init_parameter_extractor_node(config: dict, memory=None): node = ParameterExtractorNode( node_id=str(uuid.uuid4()), - config=ParameterExtractorNodeData.model_validate(config["data"]), + data=ParameterExtractorNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, credentials_provider=MagicMock(spec=CredentialsProvider), @@ -83,7 +83,11 @@ def init_parameter_extractor_node(config: dict, memory=None): return node -def test_function_calling_parameter_extractor(setup_model_mock): +def _mock_db_session_close(monkeypatch) -> None: + monkeypatch.setattr(db.session, "close", MagicMock()) + + +def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch): """ Test function calling for parameter extractor. """ @@ -114,7 +118,7 @@ def test_function_calling_parameter_extractor(setup_model_mock): mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, )() - db.session.close = MagicMock() + _mock_db_session_close(monkeypatch) result = node._run() @@ -124,7 +128,7 @@ def test_function_calling_parameter_extractor(setup_model_mock): assert result.outputs.get("__reason") == None -def test_instructions(setup_model_mock): +def test_instructions(setup_model_mock, monkeypatch): """ Test chat parameter extractor. """ @@ -155,7 +159,7 @@ def test_instructions(setup_model_mock): mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, )() - db.session.close = MagicMock() + _mock_db_session_close(monkeypatch) result = node._run() @@ -174,7 +178,7 @@ def test_instructions(setup_model_mock): assert "what's the weather in SF" in prompt.get("text") -def test_chat_parameter_extractor(setup_model_mock): +def test_chat_parameter_extractor(setup_model_mock, monkeypatch): """ Test chat parameter extractor. """ @@ -205,7 +209,7 @@ def test_chat_parameter_extractor(setup_model_mock): mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, )() - db.session.close = MagicMock() + _mock_db_session_close(monkeypatch) result = node._run() @@ -225,7 +229,7 @@ def test_chat_parameter_extractor(setup_model_mock): assert '\n{"type": "object"' in prompt.get("text") -def test_completion_parameter_extractor(setup_model_mock): +def test_completion_parameter_extractor(setup_model_mock, monkeypatch): """ Test completion parameter extractor. """ @@ -256,7 +260,7 @@ def test_completion_parameter_extractor(setup_model_mock): mode="completion", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, )() - db.session.close = MagicMock() + _mock_db_session_close(monkeypatch) result = node._run() @@ -350,7 +354,7 @@ def test_extract_json_from_tool_call(): assert result["location"] == "kawaii" -def test_chat_parameter_extractor_with_memory(setup_model_mock): +def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch): """ Test chat parameter extractor with memory. """ @@ -382,7 +386,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock): mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, )() - db.session.close = MagicMock() + _mock_db_session_close(monkeypatch) result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index e2e0723fb8..80489e6809 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -66,7 +66,7 @@ def test_execute_template_transform(): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -88,7 +88,7 @@ def test_execute_template_transform(): node = TemplateTransformNode( node_id=str(uuid.uuid4()), - config=TemplateTransformNodeData.model_validate(config["data"]), + data=TemplateTransformNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, jinja2_template_renderer=_SimpleJinja2Renderer(), diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 493330e02b..78c12e7ea5 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -43,7 +43,7 @@ def init_tool_node(config: dict): ) # construct variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -64,7 +64,7 @@ def init_tool_node(config: dict): node = ToolNode( node_id=str(uuid.uuid4()), - config=ToolNodeData.model_validate(config["data"]), + data=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py index 290be87697..a071d22ee9 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -168,6 +168,7 @@ def test_node_variable_collection_get_success( account, tenant = create_console_account_and_tenant(db_session_with_containers) app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123") + node_variable_id = node_variable.id _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other") response = test_client_with_containers.get( @@ -178,7 +179,7 @@ def test_node_variable_collection_get_success( assert response.status_code == 200 payload = response.get_json() assert payload is not None - assert [item["id"] for item in payload["items"]] == [node_variable.id] + assert [item["id"] for item in payload["items"]] == [node_variable_id] def test_node_variable_collection_get_invalid_node_id( @@ -377,6 +378,7 @@ def test_system_variable_collection_get( account, tenant = create_console_account_and_tenant(db_session_with_containers) app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW) variable = _create_system_variable(db_session_with_containers, app.id, account.id) + variable_id = variable.id response = test_client_with_containers.get( f"/console/api/apps/{app.id}/workflows/draft/system-variables", @@ -386,7 +388,7 @@ def test_system_variable_collection_get( assert response.status_code == 200 payload = response.get_json() assert payload is not None - assert [item["id"] for item in payload["items"]] == [variable.id] + assert [item["id"] for item in payload["items"]] == [variable_id] def test_environment_variable_collection_get( diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py index 81b5423261..f2c45f76da 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_oauth.py @@ -17,6 +17,8 @@ def test_get_oauth_url_successful( test_client_with_containers: FlaskClient, ) -> None: account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id + current_tenant_id = account.current_tenant_id provider = MagicMock() provider.get_authorization_url.return_value = "http://oauth.provider/auth" @@ -29,7 +31,7 @@ def test_get_oauth_url_successful( headers=authenticate_console_client(test_client_with_containers, account), ) - assert tenant.id == account.current_tenant_id + assert tenant_id == current_tenant_id assert response.status_code == 200 assert response.get_json() == {"data": "http://oauth.provider/auth"} provider.get_authorization_url.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index d017e8f2bd..5fc3b3084a 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from sqlalchemy.orm import Session from controllers.console.auth.error import ( EmailCodeError, @@ -20,13 +21,15 @@ from controllers.console.auth.forgot_password import ( ForgotPasswordSendEmailApi, ) from controllers.console.error import AccountNotFound, EmailSendIpLimitError +from tests.test_containers_integration_tests.controllers.console.helpers import ensure_dify_setup class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session): + ensure_dify_setup(db_session_with_containers) return flask_app_with_containers @pytest.fixture @@ -139,7 +142,8 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session): + ensure_dify_setup(db_session_with_containers) return flask_app_with_containers @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @@ -322,7 +326,8 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session): + ensure_dify_setup(db_session_with_containers) return flask_app_with_containers @pytest.fixture diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index c17a83cad3..ba59780d59 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -3,6 +3,7 @@ from __future__ import annotations from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -44,6 +45,35 @@ def unwrap(func): return func +def make_node_execution(**overrides): + payload = { + "id": "node-exec-1", + "index": 1, + "predecessor_node_id": None, + "node_id": "node1", + "node_type": "start", + "title": "Start", + "inputs_dict": {"query": "hello"}, + "process_data_dict": {}, + "outputs_dict": {"answer": "world"}, + "status": "succeeded", + "error": None, + "elapsed_time": 1.0, + "execution_metadata_dict": {}, + "extras": {}, + "created_at": datetime(2026, 1, 1, 0, 0, 0), + "created_by_role": "account", + "created_by_account": None, + "created_by_end_user": None, + "finished_at": datetime(2026, 1, 1, 0, 0, 1), + "inputs_truncated": False, + "outputs_truncated": False, + "process_data_truncated": False, + } + payload.update(overrides) + return SimpleNamespace(**payload) + + class TestDraftWorkflowApi: @pytest.fixture def app(self, flask_app_with_containers: Flask): @@ -743,7 +773,7 @@ class TestRagPipelineWorkflowLastRunApi: pipeline = MagicMock() workflow = MagicMock() - node_exec = MagicMock() + node_exec = make_node_execution() service = MagicMock() service.get_draft_workflow.return_value = workflow @@ -757,7 +787,9 @@ class TestRagPipelineWorkflowLastRunApi: ), ): result = method(api, pipeline, "node1") - assert result == node_exec + assert result["id"] == "node-exec-1" + assert result["inputs"] == {"query": "hello"} + assert result["outputs"] == {"answer": "world"} def test_last_run_not_found(self, app: Flask): api = RagPipelineWorkflowLastRunApi() @@ -799,7 +831,7 @@ class TestRagPipelineDatasourceVariableApi: } service = MagicMock() - service.set_datasource_variables.return_value = MagicMock() + service.set_datasource_variables.return_value = make_node_execution(node_id="n1") with ( app.test_request_context("/", json=payload), @@ -814,4 +846,5 @@ class TestRagPipelineDatasourceVariableApi: ), ): result = method(api, pipeline) - assert result is not None + assert result["node_id"] == "n1" + assert result["process_data"] == {} diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index bd13527e14..66b3392a4b 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -210,7 +210,9 @@ class TestPauseStatePersistenceLayerTestContainers: execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) # Create variable pool - variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id)) + variable_pool = VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id=execution_id) + ) if variables: for (node_id, var_key), value in variables.items(): variable_pool.add([node_id, var_key], value) diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 54cf179341..a69dd99adc 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -66,7 +66,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( workflow_execution_id=workflow_execution_id, app_id=app_id, @@ -102,7 +102,7 @@ def _build_graph( start_data = StartNodeData(title="start", variables=[]) start_node = StartNode( node_id="start", - config=start_data, + data=start_data, graph_init_params=params, graph_runtime_state=runtime_state, ) @@ -117,7 +117,7 @@ def _build_graph( ) human_node = HumanInputNode( node_id="human", - config=human_data, + data=human_data, graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, @@ -131,7 +131,7 @@ def _build_graph( ) end_node = EndNode( node_id="end", - config=end_data, + data=end_data, graph_init_params=params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index cbd939c7a4..670b4d00da 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -11,7 +11,7 @@ from models.enums import ConversationFromSource, MessageFileBelongsTo from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought from services.account_service import AccountService, TenantService from services.agent_service import AgentService -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -119,16 +119,16 @@ class TestAgentService: tenant = account.current_tenant # Create app with realistic data - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "agent-chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="agent-chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 95fc73f45a..bc75562d15 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -9,7 +9,7 @@ from models import Account from models.enums import ConversationFromSource, InvokeFrom from models.model import MessageAnnotation from services.annotation_service import AppAnnotationService -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -86,16 +86,16 @@ class TestAnnotationService: tenant = account.current_tenant # Setup app creation arguments - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) # Create app app_service = AppService() diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index a5ec06dc13..c77bbd3e44 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -37,7 +37,7 @@ from services.app_dsl_service import ( PendingData, _check_version_compatibility, ) -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from tests.test_containers_integration_tests.helpers import generate_valid_password _DEFAULT_TENANT_ID = "00000000-0000-0000-0000-000000000001" @@ -147,16 +147,16 @@ class TestAppDslService: ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) return app, account diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index e2fe6c8476..8be4c040b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -1,4 +1,5 @@ import uuid +from typing import Literal from unittest.mock import ANY, MagicMock, patch import pytest @@ -133,7 +134,10 @@ class TestAppGenerateService: } def _create_test_app_and_account( - self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat" + self, + db_session_with_containers: Session, + mock_external_service_dependencies, + mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = "chat", ): """ Helper method to create a test app and account for testing. @@ -165,20 +169,20 @@ class TestAppGenerateService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant - # Create app with realistic data - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": mode, - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - "max_active_requests": 5, - } + from services.app_service import AppService, CreateAppParams - from services.app_service import AppService + # Create app with realistic data + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode=mode, + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + max_active_requests=5, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 837b63d1ea..c37fce296f 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from pydantic import ValidationError from sqlalchemy.orm import Session from constants.model_template import default_app_templates @@ -12,7 +13,7 @@ from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password # Delay import of AppService to avoid circular dependency -# from services.app_service import AppService +# from services.app_service import AppService, AppListParams, CreateAppParams class TestAppService: @@ -64,34 +65,34 @@ class TestAppService: tenant = account.current_tenant # Setup app creation arguments - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + # Import here to avoid circular dependency + from services.app_service import AppService, CreateAppParams + + app_params = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) # Create app - # Import here to avoid circular dependency - from services.app_service import AppService - app_service = AppService() - app = app_service.create_app(tenant.id, app_args, account) + app = app_service.create_app(tenant.id, app_params, account) # Verify app was created correctly - assert app.name == app_args["name"] - assert app.description == app_args["description"] - assert app.mode == app_args["mode"] - assert app.icon_type == app_args["icon_type"] - assert app.icon == app_args["icon"] - assert app.icon_background == app_args["icon_background"] + assert app.name == app_params.name + assert app.description == app_params.description + assert app.mode == app_params.mode + assert app.icon_type == app_params.icon_type + assert app.icon == app_params.icon + assert app.icon_background == app_params.icon_background assert app.tenant_id == tenant.id - assert app.api_rph == app_args["api_rph"] - assert app.api_rpm == app_args["api_rpm"] + assert app.api_rph == app_params.api_rph + assert app.api_rpm == app_params.api_rpm assert app.created_by == account.id assert app.updated_by == account.id assert app.status == "normal" @@ -120,7 +121,7 @@ class TestAppService: tenant = account.current_tenant # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams app_service = AppService() @@ -129,20 +130,20 @@ class TestAppService: app_modes = [v.value for v in default_app_templates] for mode in app_modes: - app_args = { - "name": f"{fake.company()} {mode}", - "description": f"Test app for {mode} mode", - "mode": mode, - "icon_type": "emoji", - "icon": "🚀", - "icon_background": "#4ECDC4", - } + app_params = CreateAppParams( + name=f"{fake.company()} {mode}", + description=f"Test app for {mode} mode", + mode=mode, + icon_type="emoji", + icon="🚀", + icon_background="#4ECDC4", + ) - app = app_service.create_app(tenant.id, app_args, account) + app = app_service.create_app(tenant.id, app_params, account) # Verify app mode was set correctly assert app.mode == mode - assert app.name == app_args["name"] + assert app.name == app_params.name assert app.tenant_id == tenant.id assert app.created_by == account.id @@ -163,20 +164,20 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🎯", - "icon_background": "#45B7D1", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + + app_params = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🎯", + icon_background="#45B7D1", + ) app_service = AppService() - created_app = app_service.create_app(tenant.id, app_args, account) + created_app = app_service.create_app(tenant.id, app_params, account) # Get app using the service - needs current_user mock mock_current_user = create_autospec(Account, instance=True) @@ -211,31 +212,27 @@ class TestAppService: tenant = account.current_tenant # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppListParams, AppService, CreateAppParams app_service = AppService() # Create multiple apps app_names = [fake.company() for _ in range(5)] for name in app_names: - app_args = { - "name": name, - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "📱", - "icon_background": "#96CEB4", - } - app_service.create_app(tenant.id, app_args, account) + app_params = CreateAppParams( + name=name, + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="📱", + icon_background="#96CEB4", + ) + app_service.create_app(tenant.id, app_params, account) # Get paginated apps - args = { - "page": 1, - "limit": 10, - "mode": "chat", - } + params = AppListParams(page=1, limit=10, mode="chat") - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params) # Verify pagination results assert paginated_apps is not None @@ -267,60 +264,47 @@ class TestAppService: tenant = account.current_tenant # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppListParams, AppService, CreateAppParams app_service = AppService() # Create apps with different modes - chat_app_args = { - "name": "Chat App", - "description": "A chat application", - "mode": "chat", - "icon_type": "emoji", - "icon": "💬", - "icon_background": "#FF6B6B", - } - completion_app_args = { - "name": "Completion App", - "description": "A completion application", - "mode": "completion", - "icon_type": "emoji", - "icon": "✍️", - "icon_background": "#4ECDC4", - } + chat_app_params = CreateAppParams( + name="Chat App", + description="A chat application", + mode="chat", + icon_type="emoji", + icon="💬", + icon_background="#FF6B6B", + ) + completion_app_params = CreateAppParams( + name="Completion App", + description="A completion application", + mode="completion", + icon_type="emoji", + icon="✍️", + icon_background="#4ECDC4", + ) - chat_app = app_service.create_app(tenant.id, chat_app_args, account) - completion_app = app_service.create_app(tenant.id, completion_app_args, account) + chat_app = app_service.create_app(tenant.id, chat_app_params, account) + completion_app = app_service.create_app(tenant.id, completion_app_params, account) # Test filter by mode - chat_args = { - "page": 1, - "limit": 10, - "mode": "chat", - } - chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args) + chat_apps = app_service.get_paginate_apps(account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat")) assert len(chat_apps.items) == 1 assert chat_apps.items[0].mode == "chat" # Test filter by name - name_args = { - "page": 1, - "limit": 10, - "mode": "chat", - "name": "Chat", - } - filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args) + filtered_apps = app_service.get_paginate_apps( + account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat") + ) assert len(filtered_apps.items) == 1 assert "Chat" in filtered_apps.items[0].name # Test filter by created_by_me - created_by_me_args = { - "page": 1, - "limit": 10, - "mode": "completion", - "is_created_by_me": True, - } - my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) + my_apps = app_service.get_paginate_apps( + account.id, tenant.id, AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True) + ) assert len(my_apps.items) == 1 def test_get_paginate_apps_with_tag_filters( @@ -342,34 +326,29 @@ class TestAppService: tenant = account.current_tenant # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppListParams, AppService, CreateAppParams app_service = AppService() # Create an app - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🏷️", - "icon_background": "#FFEAA7", - } - app = app_service.create_app(tenant.id, app_args, account) + app_params = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🏷️", + icon_background="#FFEAA7", + ) + app = app_service.create_app(tenant.id, app_params, account) # Mock TagService to return the app ID for tag filtering with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service: mock_tag_service.return_value = [app.id] # Test with tag filter - args = { - "page": 1, - "limit": 10, - "mode": "chat", - "tag_ids": ["tag1", "tag2"], - } + params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["tag1", "tag2"]) - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params) # Verify tag service was called mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"]) @@ -383,14 +362,9 @@ class TestAppService: with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service: mock_tag_service.return_value = [] - args = { - "page": 1, - "limit": 10, - "mode": "chat", - "tag_ids": ["nonexistent_tag"], - } + params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["nonexistent_tag"]) - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params) # Should return None when no apps match tag filter assert paginated_apps is None @@ -412,20 +386,20 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🎯", - "icon_background": "#45B7D1", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + + app_params = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🎯", + icon_background="#45B7D1", + ) app_service = AppService() - app = app_service.create_app(tenant.id, app_args, account) + app = app_service.create_app(tenant.id, app_params, account) # Store original values original_name = app.name @@ -481,19 +455,19 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams app_service = AppService() app = app_service.create_app( tenant.id, - { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🎯", - "icon_background": "#45B7D1", - }, + CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🎯", + icon_background="#45B7D1", + ), account, ) @@ -533,19 +507,19 @@ class TestAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams app_service = AppService() app = app_service.create_app( tenant.id, - { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🎯", - "icon_background": "#45B7D1", - }, + CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🎯", + icon_background="#45B7D1", + ), account, ) @@ -584,20 +558,20 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🎯", - "icon_background": "#45B7D1", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + + app_params = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🎯", + icon_background="#45B7D1", + ) app_service = AppService() - app = app_service.create_app(tenant.id, app_args, account) + app = app_service.create_app(tenant.id, app_params, account) # Store original name original_name = app.name @@ -637,20 +611,20 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🎯", - "icon_background": "#45B7D1", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + + app_params = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🎯", + icon_background="#45B7D1", + ) app_service = AppService() - app = app_service.create_app(tenant.id, app_args, account) + app = app_service.create_app(tenant.id, app_params, account) # Store original values original_icon = app.icon @@ -698,18 +672,17 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🌐", - "icon_background": "#74B9FF", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🌐", + icon_background="#74B9FF", + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -758,18 +731,17 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🔌", - "icon_background": "#A29BFE", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🔌", + icon_background="#A29BFE", + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -818,18 +790,17 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🔄", - "icon_background": "#FD79A8", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🔄", + icon_background="#FD79A8", + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -869,18 +840,17 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🗑️", - "icon_background": "#E17055", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🗑️", + icon_background="#E17055", + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -921,18 +891,17 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🧹", - "icon_background": "#00B894", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🧹", + icon_background="#00B894", + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -981,18 +950,17 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "📊", - "icon_background": "#6C5CE7", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="📊", + icon_background="#6C5CE7", + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -1020,18 +988,17 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🔗", - "icon_background": "#FDCB6E", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🔗", + icon_background="#FDCB6E", + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -1060,18 +1027,17 @@ class TestAppService: tenant = account.current_tenant # Create app first - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🆔", - "icon_background": "#E84393", - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🆔", + icon_background="#E84393", + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -1107,26 +1073,20 @@ class TestAppService: password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) - tenant = account.current_tenant - - # Setup app creation arguments with invalid mode - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "invalid_mode", # Invalid mode - "icon_type": "emoji", - "icon": "❌", - "icon_background": "#D63031", - } # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import CreateAppParams - app_service = AppService() - - # Attempt to create app with invalid mode - with pytest.raises(ValueError, match="invalid mode value"): - app_service.create_app(tenant.id, app_args, account) + # Attempt to create app with invalid mode - Pydantic will reject invalid literal + with pytest.raises(ValidationError): + CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="invalid_mode", # type: ignore[arg-type] + icon_type="emoji", + icon="❌", + icon_background="#D63031", + ) def test_get_apps_with_special_characters_in_name( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1152,99 +1112,103 @@ class TestAppService: tenant = account.current_tenant # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppListParams, AppService, CreateAppParams app_service = AppService() # Create apps with special characters in names app_with_percent = app_service.create_app( tenant.id, - { - "name": "App with 50% discount", - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - }, + CreateAppParams( + name="App with 50% discount", + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ), account, ) app_with_underscore = app_service.create_app( tenant.id, - { - "name": "test_data_app", - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - }, + CreateAppParams( + name="test_data_app", + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ), account, ) app_with_backslash = app_service.create_app( tenant.id, - { - "name": "path\\to\\app", - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - }, + CreateAppParams( + name="path\\to\\app", + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ), account, ) # Create app that should NOT match app_no_match = app_service.create_app( tenant.id, - { - "name": "100% different", - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - }, + CreateAppParams( + name="100% different", + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ), account, ) # Test 1: Search with % character - args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + paginated_apps = app_service.get_paginate_apps( + account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10) + ) assert paginated_apps is not None assert paginated_apps.total == 1 assert len(paginated_apps.items) == 1 assert paginated_apps.items[0].name == "App with 50% discount" # Test 2: Search with _ character - args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10} - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + paginated_apps = app_service.get_paginate_apps( + account.id, tenant.id, AppListParams(name="test_data", mode="chat", page=1, limit=10) + ) assert paginated_apps is not None assert paginated_apps.total == 1 assert len(paginated_apps.items) == 1 assert paginated_apps.items[0].name == "test_data_app" # Test 3: Search with \ character - args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10} - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + paginated_apps = app_service.get_paginate_apps( + account.id, tenant.id, AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10) + ) assert paginated_apps is not None assert paginated_apps.total == 1 assert len(paginated_apps.items) == 1 assert paginated_apps.items[0].name == "path\\to\\app" # Test 4: Search with % should NOT match 100% (verifies escaping works) - args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10} - paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args) + paginated_apps = app_service.get_paginate_apps( + account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10) + ) assert paginated_apps is not None assert paginated_apps.total == 1 assert all("50%" in app.name for app in paginated_apps.items) diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 09ba041244..07dc3a4e9e 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -90,16 +90,34 @@ class TestCreditPoolService: pool = CreditPoolService.get_pool(tenant_id=tenant_id) assert pool.quota_used == credits_required - def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session): + def test_check_and_deduct_credits_raises_without_deducting_when_insufficient( + self, db_session_with_containers: Session + ): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) remaining = 5 pool.quota_used = pool.quota_limit - remaining + quota_used = pool.quota_used db_session_with_containers.commit() - result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) + with pytest.raises(QuotaExceededError, match="Insufficient credits remaining"): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) + + db_session_with_containers.expire_all() + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool.quota_used == quota_used + + def test_deduct_credits_capped_depletes_available_balance(self, db_session_with_containers: Session): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + remaining = 5 + pool.quota_used = pool.quota_limit - remaining + quota_limit = pool.quota_limit + db_session_with_containers.commit() + + result = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=200) assert result == remaining db_session_with_containers.expire_all() updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) - assert updated_pool.quota_used == pool.quota_limit + assert updated_pool.quota_used == quota_limit diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 2f90d16176..0c610311bb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -16,6 +16,7 @@ from uuid import uuid4 from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType +from models import AccountStatus, CreatorUserRole, TenantStatus from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -25,7 +26,7 @@ from models.dataset import ( DatasetProcessRule, DatasetQuery, ) -from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode +from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode, TagType from models.model import Tag, TagBinding from services.dataset_service import DatasetService, DocumentService @@ -42,11 +43,11 @@ class DatasetRetrievalTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) tenant = Tenant( name=f"tenant-{uuid4()}", - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add_all([account, tenant]) db_session_with_containers.flush() @@ -72,7 +73,7 @@ class DatasetRetrievalTestDataFactory: email=f"{uuid4()}@example.com", name=f"user-{uuid4()}", interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.flush() @@ -130,7 +131,7 @@ class DatasetRetrievalTestDataFactory: @staticmethod def create_process_rule( - db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict + db_session_with_containers: Session, dataset_id: str, created_by: str, mode: ProcessRuleMode, rules: dict ) -> DatasetProcessRule: """Create a dataset process rule.""" process_rule = DatasetProcessRule( @@ -153,7 +154,7 @@ class DatasetRetrievalTestDataFactory: content=content, source=DatasetQuerySource.APP, source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=created_by, ) db_session_with_containers.add(dataset_query) @@ -176,7 +177,7 @@ class DatasetRetrievalTestDataFactory: """Create a knowledge tag and bind it to the target dataset.""" tag = Tag( tenant_id=tenant_id, - type="knowledge", + type=TagType.KNOWLEDGE, name=f"tag-{uuid4()}", created_by=created_by, ) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index bdf6d9b951..6d0d281c6b 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom from models.model import MessageFeedback -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from services.errors.message import ( FirstMessageNotExistsError, LastMessageNotExistsError, @@ -103,16 +103,16 @@ class TestMessageService: tenant = account.current_tenant # Setup app creation arguments - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "advanced-chat", # Use advanced-chat mode to use mocked workflow - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="advanced-chat", # Use advanced-chat mode to use mocked workflow, + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) # Create app app_service = AppService() diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py index e2e1a228b2..ff76bce416 100644 --- a/api/tests/test_containers_integration_tests/services/test_ops_service.py +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session from core.ops.entities.config_entity import TracingProviderEnum from models.model import TraceAppConfig from services.account_service import AccountService, TenantService -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from services.ops_service import OpsService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -57,14 +57,14 @@ class TestOpsService: app_service = AppService() app = app_service.create_app( tenant.id, - { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - }, + CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + ), account, ) return app, account diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 7b9e9924cd..7368ad4249 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -8,7 +8,7 @@ from models import App, CreatorUserRole from models.enums import ConversationFromSource from models.model import EndUser, Message from models.web import SavedMessage -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from services.saved_message_service import SavedMessageService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -73,16 +73,16 @@ class TestSavedMessageService: tenant = account.current_tenant # Create app with realistic data - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 797731d04b..8e53a2d6cd 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -11,7 +11,7 @@ from models.enums import ConversationFromSource from models.model import Conversation, EndUser from models.web import PinnedConversation from services.account_service import AccountService, TenantService -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from services.web_conversation_service import WebConversationService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -77,16 +77,16 @@ class TestWebConversationService: tenant = account.current_tenant # Create app with realistic data - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 6d5c7380b7..52b1229302 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -543,8 +543,8 @@ class TestWebhookService: "bad_file": MagicMock(filename="test.bad", content_type="text/plain"), } - files["good_file"].read.return_value = b"content" - files["bad_file"].read.side_effect = Exception("Read error") + files["good_file"].stream.read.return_value = b"content" + files["bad_file"].stream.read.side_effect = Exception("Read error") webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index a2cdddad61..07a49130d0 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -17,7 +17,7 @@ from models.workflow import WorkflowAppLogCreatedFrom from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency -# from services.app_service import AppService +# from services.app_service import AppService, CreateAppParams from services.workflow_app_service import LogView, WorkflowAppService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -82,20 +82,20 @@ class TestWorkflowAppService: TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant - # Create app with realistic data - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "workflow", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + + # Create app with realistic data + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="workflow", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -146,20 +146,20 @@ class TestWorkflowAppService: """ fake = Faker() - # Create app with realistic data - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "workflow", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } - # Import here to avoid circular dependency - from services.app_service import AppService + from services.app_service import AppService, CreateAppParams + + # Create app with realistic data + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="workflow", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index d02a078281..09fe1570bc 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -13,7 +13,7 @@ from models.model import ( ) from models.workflow import WorkflowRun from services.account_service import AccountService, TenantService -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from services.workflow_run_service import WorkflowRunService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -79,16 +79,16 @@ class TestWorkflowRunService: tenant = account.current_tenant # Create app with realistic data - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "chat", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="chat", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) @@ -535,13 +535,13 @@ class TestWorkflowRunService: tenant = account.current_tenant # Create app - app_args = { - "name": "Test App", - "mode": "chat", - "icon_type": "emoji", - "icon": "🚀", - "icon_background": "#4ECDC4", - } + app_args = CreateAppParams( + name="Test App", + mode="chat", + icon_type="emoji", + icon="🚀", + icon_background="#4ECDC4", + ) app = app_service.create_app(tenant.id, app_args, account) # Create workflow run without node executions @@ -586,13 +586,13 @@ class TestWorkflowRunService: tenant = account.current_tenant # Create app - app_args = { - "name": "Test App", - "mode": "chat", - "icon_type": "emoji", - "icon": "🚀", - "icon_background": "#4ECDC4", - } + app_args = CreateAppParams( + name="Test App", + mode="chat", + icon_type="emoji", + icon="🚀", + icon_background="#4ECDC4", + ) app = app_service.create_app(tenant.id, app_args, account) # Use invalid workflow run ID @@ -637,13 +637,13 @@ class TestWorkflowRunService: tenant = account.current_tenant # Create app - app_args = { - "name": "Test App", - "mode": "chat", - "icon_type": "emoji", - "icon": "🚀", - "icon_background": "#4ECDC4", - } + app_args = CreateAppParams( + name="Test App", + mode="chat", + icon_type="emoji", + icon="🚀", + icon_background="#4ECDC4", + ) app = app_service.create_app(tenant.id, app_args, account) # Create workflow run diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 21a1975879..9b574fe2df 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -11,7 +11,7 @@ from core.tools.errors import WorkflowToolHumanInputNotSupportedError from models.tools import WorkflowToolProvider from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService -from services.app_service import AppService +from services.app_service import AppService, CreateAppParams from services.tools.workflow_tools_manage_service import WorkflowToolManageService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -94,16 +94,16 @@ class TestWorkflowToolManageService: tenant = account.current_tenant # Create app with realistic data - app_args = { - "name": fake.company(), - "description": fake.text(max_nb_chars=100), - "mode": "workflow", - "icon_type": "emoji", - "icon": "🤖", - "icon_background": "#FF6B6B", - "api_rph": 100, - "api_rpm": 10, - } + app_args = CreateAppParams( + name=fake.company(), + description=fake.text(max_nb_chars=100), + mode="workflow", + icon_type="emoji", + icon="🤖", + icon_background="#FF6B6B", + api_rph=100, + api_rpm=10, + ) app_service = AppService() app = app_service.create_app(tenant.id, app_args, account) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 6bfb1e1f1e..6a95bfc425 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -7,6 +7,7 @@ The task is responsible for removing document segments from the search index whe """ from unittest.mock import MagicMock, patch +from uuid import uuid4 from faker import Faker from sqlalchemy import select @@ -82,7 +83,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers: Session, account, fake: Faker | None = None): + def _create_test_dataset(self, db_session_with_containers: Session, account: Account, fake: Faker | None = None): """ Helper method to create a test dataset with realistic data. @@ -117,7 +118,7 @@ class TestDisableSegmentsFromIndexTask: return dataset def _create_test_document( - self, db_session_with_containers: Session, dataset, account: Account, fake: Faker | None = None + self, db_session_with_containers: Session, dataset: Dataset, account: Account, fake: Faker | None = None ): """ Helper method to create a test document with realistic data. @@ -164,7 +165,7 @@ class TestDisableSegmentsFromIndexTask: return document def _create_test_segments( - self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None + self, db_session_with_containers: Session, document, dataset: Dataset, account: Account, count=3, fake=None ): """ Helper method to create test document segments with realistic data. @@ -217,7 +218,9 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake: Faker | None = None): + def _create_dataset_process_rule( + self, db_session_with_containers: Session, dataset: Dataset, fake: Faker | None = None + ): """ Helper method to create a dataset process rule. @@ -230,21 +233,19 @@ class TestDisableSegmentsFromIndexTask: DatasetProcessRule: Created process rule instance """ fake = fake or Faker() - process_rule = DatasetProcessRule() - process_rule.id = fake.uuid4() - process_rule.tenant_id = dataset.tenant_id - process_rule.dataset_id = dataset.id - process_rule.mode = ProcessRuleMode.AUTOMATIC - process_rule.rules = ( - "{" - '"mode": "automatic", ' - '"rules": {' - '"pre_processing_rules": [], "segmentation": ' - '{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}' - "}" + process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=ProcessRuleMode.AUTOMATIC, + rules=( + "{" + '"mode": "automatic", ' + '"rules": {' + '"pre_processing_rules": [], "segmentation": ' + '{"separator": "\\n\\n", "max_tokens": 1000, "chunk_overlap": 50}}' + "}" + ), + created_by=str(uuid4()), ) - process_rule.created_by = dataset.created_by - process_rule.updated_by = dataset.updated_by db_session_with_containers.add(process_rule) db_session_with_containers.commit() diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 57dbf453de..919ebbc656 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -8,6 +8,47 @@ from yarl import URL from configs.app_config import DifyConfig +def _set_basic_config_env(monkeypatch: pytest.MonkeyPatch) -> None: + os.environ.clear() + monkeypatch.setenv("CONSOLE_API_URL", "https://example.com") + monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com") + monkeypatch.setenv("DB_TYPE", "postgresql") + monkeypatch.setenv("DB_USERNAME", "postgres") + monkeypatch.setenv("DB_PASSWORD", "postgres") + monkeypatch.setenv("DB_HOST", "localhost") + monkeypatch.setenv("DB_PORT", "5432") + monkeypatch.setenv("DB_DATABASE", "dify") + + +def test_dify_config_keeps_secret_key_empty_when_missing( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + _set_basic_config_env(monkeypatch) + monkeypatch.delenv("SECRET_KEY", raising=False) + monkeypatch.setenv("OPENDAL_FS_ROOT", str(tmp_path)) + + config = DifyConfig(_env_file=None) + + assert config.SECRET_KEY == "" + assert not hasattr(config, "OPENDAL_FS_ROOT") + assert not (tmp_path / ".dify_secret_key").exists() + + +def test_dify_config_preserves_explicit_secret_key( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + _set_basic_config_env(monkeypatch) + monkeypatch.setenv("SECRET_KEY", "explicit") + monkeypatch.setenv("OPENDAL_FS_ROOT", str(tmp_path)) + + config = DifyConfig(_env_file=None) + + assert config.SECRET_KEY == "explicit" + assert not (tmp_path / ".dify_secret_key").exists() + + def test_dify_config(monkeypatch: pytest.MonkeyPatch): # clear system environment variables os.environ.clear() diff --git a/api/tests/unit_tests/controllers/common/test_schema.py b/api/tests/unit_tests/controllers/common/test_schema.py index 575f8c839c..7cabafba0e 100644 --- a/api/tests/unit_tests/controllers/common/test_schema.py +++ b/api/tests/unit_tests/controllers/common/test_schema.py @@ -47,6 +47,10 @@ class QueryModel(BaseModel): ambiguous: int | str | None = Field(default=None, description="Ambiguous query parameter") +class ResponseAliasModel(BaseModel): + public_name: str = Field(validation_alias="internal_name") + + @pytest.fixture(autouse=True) def mock_console_ns(): """Mock the console_ns to avoid circular imports during test collection.""" @@ -146,6 +150,20 @@ def test_register_schema_models_calls_register_schema_model(monkeypatch: pytest. ] +def test_register_response_schema_model_uses_serialized_field_names(): + from controllers.common.schema import register_response_schema_model + + namespace = MagicMock(spec=Namespace) + + register_response_schema_model(namespace, ResponseAliasModel) + + model_name, schema = namespace.schema_model.call_args.args + + assert model_name == "ResponseAliasModel" + assert "public_name" in schema["properties"] + assert "internal_name" not in schema["properties"] + + def test_get_or_create_model_returns_existing_model(mock_console_ns): from controllers.common.schema import get_or_create_model diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index d06978c3fa..1c4e5864eb 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -111,3 +111,24 @@ def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPa with pytest.raises(NotFoundError): with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + + +def test_pause_details_returns_empty_response_for_non_paused_run(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + account = _make_account() + _patch_console_guards(monkeypatch, account) + + workflow_run = Mock(spec=WorkflowRun) + workflow_run.tenant_id = "tenant-123" + workflow_run.status = WorkflowExecutionStatus.RUNNING + fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run)) + monkeypatch.setattr(workflow_run_module, "db", fake_db) + + with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): + response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") + + assert status == 200 + assert response == {"paused_at": None, "paused_nodes": []} + + +def test_pause_details_response_schema_is_registered() -> None: + assert workflow_run_module.WorkflowPauseDetailsResponse.__name__ in workflow_run_module.console_ns.models diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py new file mode 100644 index 0000000000..e225e31563 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any + +import pytest +from flask import Flask +from flask_restx import marshal + +from controllers.console.app import workflow_run as workflow_run_module + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _serialize_200_response(handler, payload: Any) -> Any: + response_doc = getattr(handler, "__apidoc__", {}).get("responses", {}).get("200") + if response_doc is None: + return payload + + response_model = response_doc[1] + if isinstance(response_model, dict): + return marshal(payload, response_model) + return payload + + +def _account() -> SimpleNamespace: + return SimpleNamespace(id="account-1", name="Alice", email="alice@example.com") + + +def _workflow_run_summary(**overrides) -> SimpleNamespace: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + payload = { + "id": "run-1", + "version": "v1", + "status": "succeeded", + "elapsed_time": 1.5, + "total_tokens": 10, + "total_steps": 2, + "created_by_account": _account(), + "created_at": created_at, + "finished_at": created_at, + "exceptions_count": 0, + "retry_index": 0, + } + payload.update(overrides) + return SimpleNamespace(**payload) + + +def _workflow_run_node_execution(**overrides) -> SimpleNamespace: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + payload = { + "id": "node-exec-1", + "index": 1, + "predecessor_node_id": None, + "node_id": "node-1", + "node_type": "start", + "title": "Start", + "inputs_dict": {"query": "hello"}, + "process_data_dict": {"step": "prepared"}, + "outputs_dict": {"answer": "world"}, + "status": "succeeded", + "error": None, + "elapsed_time": 1.0, + "execution_metadata_dict": {"total_tokens": 3}, + "extras": {}, + "created_at": created_at, + "created_by_role": "account", + "created_by_account": _account(), + "created_by_end_user": None, + "finished_at": created_at, + "inputs_truncated": False, + "outputs_truncated": False, + "process_data_truncated": False, + } + payload.update(overrides) + return SimpleNamespace(**payload) + + +def test_workflow_run_list_returns_frontend_history_contract(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + class WorkflowRunService: + def get_paginate_workflow_runs(self, **_kwargs): + return { + "limit": 10, + "has_more": False, + "data": [_workflow_run_summary()], + } + + monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) + + api = workflow_run_module.WorkflowRunListApi() + handler = _unwrap(api.get) + + with app.test_request_context("/apps/app-1/workflow-runs?limit=10", method="GET"): + payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + response = _serialize_200_response(api.get, payload) + + assert response["limit"] == 10 + assert response["has_more"] is False + assert response["data"][0] == { + "id": "run-1", + "version": "v1", + "status": "succeeded", + "elapsed_time": 1.5, + "total_tokens": 10, + "total_steps": 2, + "created_by_account": {"id": "account-1", "name": "Alice", "email": "alice@example.com"}, + "created_at": 1767323045, + "finished_at": 1767323045, + "exceptions_count": 0, + "retry_index": 0, + } + + +def test_advanced_chat_workflow_run_list_keeps_message_fields(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + class WorkflowRunService: + def get_paginate_advanced_chat_workflow_runs(self, **_kwargs): + return { + "limit": 1, + "has_more": True, + "data": [ + _workflow_run_summary( + conversation_id="conversation-1", + message_id="message-1", + ) + ], + } + + monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) + + api = workflow_run_module.AdvancedChatAppWorkflowRunListApi() + handler = _unwrap(api.get) + + with app.test_request_context("/apps/app-1/advanced-chat/workflow-runs?limit=1", method="GET"): + payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + response = _serialize_200_response(api.get, payload) + + assert response["data"][0]["conversation_id"] == "conversation-1" + assert response["data"][0]["message_id"] == "message-1" + + +def test_workflow_run_detail_returns_frontend_detail_contract(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: + created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) + workflow_run = SimpleNamespace( + id="run-1", + version="v1", + graph_dict={"nodes": []}, + inputs_dict={"query": "hello"}, + status="succeeded", + outputs_dict={"answer": "world"}, + error=None, + elapsed_time=1.5, + total_tokens=10, + total_steps=2, + created_by_role="account", + created_by_account=_account(), + created_by_end_user=None, + created_at=created_at, + finished_at=created_at, + exceptions_count=0, + ) + + class WorkflowRunService: + def get_workflow_run(self, **_kwargs): + return workflow_run + + monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) + + api = workflow_run_module.WorkflowRunDetailApi() + handler = _unwrap(api.get) + + with app.test_request_context("/apps/app-1/workflow-runs/run-1", method="GET"): + payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1") + + response = _serialize_200_response(api.get, payload) + + assert response == { + "id": "run-1", + "version": "v1", + "graph": {"nodes": []}, + "inputs": {"query": "hello"}, + "status": "succeeded", + "outputs": {"answer": "world"}, + "error": None, + "elapsed_time": 1.5, + "total_tokens": 10, + "total_steps": 2, + "created_by_role": "account", + "created_by_account": {"id": "account-1", "name": "Alice", "email": "alice@example.com"}, + "created_by_end_user": None, + "created_at": 1767323045, + "finished_at": 1767323045, + "exceptions_count": 0, + } + + +def test_workflow_run_node_executions_return_frontend_trace_contract( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: + class WorkflowRunService: + def get_workflow_run_node_executions(self, **_kwargs): + return [_workflow_run_node_execution()] + + monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) + monkeypatch.setattr(workflow_run_module, "current_user", SimpleNamespace(id="account-1")) + + api = workflow_run_module.WorkflowRunNodeExecutionListApi() + handler = _unwrap(api.get) + + with app.test_request_context("/apps/app-1/workflow-runs/run-1/node-executions", method="GET"): + payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1") + + response = _serialize_200_response(api.get, payload) + + assert response == { + "data": [ + { + "id": "node-exec-1", + "index": 1, + "predecessor_node_id": None, + "node_id": "node-1", + "node_type": "start", + "title": "Start", + "inputs": {"query": "hello"}, + "process_data": {"step": "prepared"}, + "outputs": {"answer": "world"}, + "status": "succeeded", + "error": None, + "elapsed_time": 1.0, + "execution_metadata": {"total_tokens": 3}, + "extras": {}, + "created_at": 1767323045, + "created_by_role": "account", + "created_by_account": {"id": "account-1", "name": "Alice", "email": "alice@example.com"}, + "created_by_end_user": None, + "finished_at": 1767323045, + "inputs_truncated": False, + "outputs_truncated": False, + "process_data_truncated": False, + } + ] + } diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 14f00e6295..82a063307b 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -88,6 +88,11 @@ def valid_parameters(): } +def test_trial_workflow_uses_trial_scoped_simple_account_model(): + assert module.simple_account_model.name == "TrialSimpleAccount" + assert hasattr(module.simple_account_model, "items") + + class TestTrialAppWorkflowRunApi: def test_not_workflow_app(self, app: Flask): api = module.TrialAppWorkflowRunApi() diff --git a/api/tests/unit_tests/controllers/files/test_upload.py b/api/tests/unit_tests/controllers/files/test_upload.py index e8f3cd4b66..ff6ba0e9a1 100644 --- a/api/tests/unit_tests/controllers/files/test_upload.py +++ b/api/tests/unit_tests/controllers/files/test_upload.py @@ -1,3 +1,4 @@ +import io import types from unittest.mock import patch @@ -30,9 +31,10 @@ class DummyFile: self.filename = filename self.mimetype = mimetype self._content = content + self.stream = io.BytesIO(content) def read(self): - return self._content + return self.stream.read() class DummyToolFile: diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 230c51161f..738238d10a 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -1057,8 +1057,8 @@ class TestDocumentAddByTextApi: """Test error when both dataset and payload lack indexing_technique. When ``indexing_technique`` is ``None`` in the payload, ``model_dump(exclude_none=True)`` - omits the key. The production code accesses ``args["indexing_technique"]`` which raises - ``KeyError`` before the ``ValueError`` guard can fire. + omits the key. The service API should still raise the same validation error as other + document creation paths instead of leaking a ``KeyError`` from the dumped payload dict. """ # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) @@ -1074,7 +1074,7 @@ class TestDocumentAddByTextApi: headers={"Authorization": "Bearer test_token"}, ): api = DocumentAddByTextApi() - with pytest.raises(KeyError): + with pytest.raises(ValueError, match="indexing_technique is required."): api.post(tenant_id=mock_tenant.id, dataset_id=mock_dataset.id) diff --git a/api/tests/unit_tests/controllers/test_swagger.py b/api/tests/unit_tests/controllers/test_swagger.py new file mode 100644 index 0000000000..999f1ae78d --- /dev/null +++ b/api/tests/unit_tests/controllers/test_swagger.py @@ -0,0 +1,72 @@ +"""Swagger JSON rendering tests for Flask-RESTX API blueprints.""" + +import pytest +from flask import Flask + + +def _definition_refs(value: object) -> set[str]: + refs: set[str] = set() + if isinstance(value, dict): + ref = value.get("$ref") + if isinstance(ref, str) and ref.startswith("#/definitions/"): + refs.add(ref.removeprefix("#/definitions/")) + for item in value.values(): + refs.update(_definition_refs(item)) + elif isinstance(value, list): + for item in value: + refs.update(_definition_refs(item)) + return refs + + +@pytest.mark.parametrize( + ("first_kwargs", "second_kwargs"), + [ + ({"min_items": 1}, {"min_items": 2}), + ({"max_items": 1}, {"max_items": 2}), + ({"unique": True}, {"unique": False}), + ], +) +def test_inline_model_name_includes_list_constraints( + first_kwargs: dict[str, object], + second_kwargs: dict[str, object], +): + from flask_restx import fields + + from libs.flask_restx_compat import _inline_model_name + + first_inline_model: dict[object, object] = {"items": fields.List(fields.String, **first_kwargs)} + second_inline_model: dict[object, object] = {"items": fields.List(fields.String, **second_kwargs)} + + assert _inline_model_name(first_inline_model) != _inline_model_name(second_inline_model) + + +def test_swagger_json_endpoints_render(monkeypatch: pytest.MonkeyPatch): + from configs import dify_config + from controllers.console import bp as console_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + monkeypatch.setattr(dify_config, "SWAGGER_UI_ENABLED", True) + + app = Flask(__name__) + app.config["TESTING"] = True + app.config["RESTX_INCLUDE_ALL_MODELS"] = True + app.register_blueprint(console_bp) + app.register_blueprint(web_bp) + app.register_blueprint(service_api_bp) + + client = app.test_client() + + for route in ("/console/api/swagger.json", "/api/swagger.json", "/v1/swagger.json"): + response = client.get(route) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["swagger"] == "2.0" + assert "paths" in payload + assert "definitions" in payload + assert isinstance(payload["definitions"], dict) + missing_refs = _definition_refs(payload) - set(payload["definitions"]) + assert not sorted(ref for ref in missing_refs if ref.startswith("_AnonymousInlineModel")) + + assert app.config["RESTX_INCLUDE_ALL_MODELS"] is True diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py index 3fd21ab22b..62e1d22129 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_prompt_template_manager.py @@ -1,3 +1,4 @@ +from collections import UserString from unittest.mock import MagicMock import pytest @@ -12,21 +13,25 @@ from core.app.app_config.easy_ui_based_app.prompt_template.manager import ( # ----------------------------- -class DummyEnumValue: +class DummyEnumValue(UserString): def __init__(self, value): + super().__init__(value) self.value = value class DummyPromptType: def __init__(self): - self.SIMPLE = "simple" - self.ADVANCED = "advanced" + self.SIMPLE = DummyEnumValue("simple") + self.ADVANCED = DummyEnumValue("advanced") def value_of(self, value): - return value + for enum_value in self: + if enum_value.value == value: + return enum_value + raise ValueError(f"invalid prompt type value {value}") def __iter__(self): - return iter([DummyEnumValue("simple"), DummyEnumValue("advanced")]) + return iter([self.SIMPLE, self.ADVANCED]) # ----------------------------- diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index e5cb8a3383..c26011b758 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -132,7 +132,9 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline._task_state.answer = "partial answer" pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool( + variables=build_system_variables(workflow_execution_id="run-id"), + ), start_at=0.0, total_tokens=7, node_run_steps=3, @@ -372,7 +374,9 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -583,7 +587,9 @@ class TestAdvancedChatGenerateTaskPipeline: self.items = items graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) @@ -617,7 +623,9 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_message_end_event_applies_output_moderation(self, monkeypatch: pytest.MonkeyPatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index aa71f4d9c4..1acebfee17 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -60,7 +60,7 @@ class _StubToolNode(Node[_StubToolNodeData]): def __init__( self, node_id: str, - config: _StubToolNodeData, + data: _StubToolNodeData, *, graph_init_params, graph_runtime_state, @@ -68,7 +68,7 @@ class _StubToolNode(Node[_StubToolNodeData]): ) -> None: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) @@ -169,7 +169,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py index 0e9f8b6f35..2e4e469eb5 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_generator.py @@ -1,3 +1,4 @@ +import contextlib from types import SimpleNamespace from unittest.mock import MagicMock @@ -24,6 +25,76 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false(): assert WorkflowAppGenerator()._should_prepare_user_inputs(args) +def test_generate_includes_parent_trace_context_in_extras(monkeypatch): + generator = WorkflowAppGenerator() + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerator._bind_file_access_scope", + lambda *args, **kwargs: contextlib.nullcontext(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config", + lambda *args, **kwargs: SimpleNamespace( + app_id="app-1", tenant_id="tenant-1", workflow_id="workflow-1", variables=[] + ), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.file_factory.build_from_mappings", lambda *args, **kwargs: [] + ) + monkeypatch.setattr("core.app.apps.workflow.app_generator.TraceQueueManager", MagicMock()) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + MagicMock(return_value=MagicMock()), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + MagicMock(return_value=MagicMock()), + ) + monkeypatch.setattr("core.app.apps.workflow.app_generator.db", SimpleNamespace(engine=MagicMock())) + monkeypatch.setattr(generator, "_prepare_user_inputs", lambda *, user_inputs, **kwargs: user_inputs) + + captured = {} + + def fake_workflow_app_generate_entity(**kwargs): + captured["workflow_app_generate_entity_kwargs"] = kwargs + return SimpleNamespace(**kwargs) + + def fake_generate(**kwargs): + captured["application_generate_entity"] = kwargs["application_generate_entity"] + return {"data": {}} + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateEntity", fake_workflow_app_generate_entity + ) + monkeypatch.setattr(generator, "_generate", fake_generate) + + result = generator.generate( + app_model=SimpleNamespace(tenant_id="tenant-1", id="app-1"), + workflow=SimpleNamespace(features_dict={}), + user=SimpleNamespace(id="user-1", session_id="session-1"), + args={ + "inputs": {"query": "hello"}, + "files": [], + "external_trace_id": "trace-1", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + invoke_from="service-api", + streaming=False, + call_depth=0, + ) + + assert result == {"data": {}} + extras = captured["workflow_app_generate_entity_kwargs"]["extras"] + assert extras["external_trace_id"] == "trace-1" + assert extras["parent_trace_context"].model_dump() == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + + def test_resume_delegates_to_generate(mocker: MockerFixture): generator = WorkflowAppGenerator() mock_generate = mocker.patch.object(generator, "_generate", return_value="ok") diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index dbe846cbc5..2ec2e38465 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -57,7 +57,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) @@ -96,7 +96,7 @@ class TestWorkflowBasedAppRunner: def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch: pytest.MonkeyPatch): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) @@ -167,7 +167,7 @@ class TestWorkflowBasedAppRunner: app_id="app", ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) @@ -246,7 +246,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) graph_runtime_state.register_paused_node("node-1") @@ -289,7 +289,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) @@ -460,7 +460,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 3afdbd2822..bc167d45e8 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -17,7 +17,7 @@ from models.workflow import Workflow def _make_graph_state(): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, environment_variables=[], diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index 1311d5e9cb..ea21a1cc1a 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -95,7 +95,9 @@ class TestWorkflowGenerateTaskPipeline: def test_to_blocking_response_falls_back_to_human_input_required_when_pause_event_missing(self): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool( + variables=build_system_variables(workflow_execution_id="run-id"), + ), start_at=0.0, total_tokens=5, node_run_steps=2, @@ -283,7 +285,9 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -725,7 +729,9 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) @@ -753,7 +759,9 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) @@ -769,7 +777,9 @@ class TestWorkflowGenerateTaskPipeline: def test_process_stream_response_main_match_paths_and_cleanup(self): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-id") + ), start_at=0.0, ) pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index d3bd15b6f3..320a3bc42c 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -21,7 +21,9 @@ class TestTriggerPostLayer: ) runtime_state = SimpleNamespace( outputs={"answer": "ok"}, - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-1") + ), total_tokens=12, ) @@ -60,7 +62,9 @@ class TestTriggerPostLayer: def test_on_event_handles_missing_trigger_log(self): runtime_state = SimpleNamespace( outputs={}, - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-1") + ), total_tokens=0, ) @@ -91,7 +95,9 @@ class TestTriggerPostLayer: def test_on_event_ignores_non_status_events(self): runtime_state = SimpleNamespace( outputs={}, - variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(workflow_execution_id="run-1") + ), total_tokens=0, ) diff --git a/api/tests/unit_tests/core/app/test_llm_quota.py b/api/tests/unit_tests/core/app/test_llm_quota.py new file mode 100644 index 0000000000..d9390a4a8f --- /dev/null +++ b/api/tests/unit_tests/core/app/test_llm_quota.py @@ -0,0 +1,617 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine, select + +from configs import dify_config +from core.app.llm.quota import ( + deduct_llm_quota, + deduct_llm_quota_for_model, + ensure_llm_quota_available, + ensure_llm_quota_available_for_model, +) +from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit +from core.errors.error import QuotaExceededError +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType +from models import TenantCreditPool +from models.enums import ProviderQuotaType as ModelProviderQuotaType +from models.provider import Provider, ProviderType + + +def test_ensure_llm_quota_available_for_model_raises_when_system_model_is_exhausted() -> None: + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + get_provider_model=MagicMock(return_value=SimpleNamespace(status=ModelStatus.QUOTA_EXCEEDED)), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."), + ): + ensure_llm_quota_available_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + ) + + provider_configuration.get_provider_model.assert_called_once_with( + model_type=ModelType.LLM, + model="gpt-4o", + ) + + +def test_ensure_llm_quota_available_for_model_raises_when_provider_is_missing() -> None: + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = None + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + pytest.raises(ValueError, match="Provider openai does not exist."), + ): + ensure_llm_quota_available_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + ) + + +def test_ensure_llm_quota_available_for_model_ignores_custom_provider_configuration() -> None: + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.CUSTOM, + get_provider_model=MagicMock(), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager): + ensure_llm_quota_available_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + ) + + provider_configuration.get_provider_model.assert_not_called() + + +def test_deduct_llm_quota_for_model_uses_identity_based_trial_billing() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 42 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_deduct_credits.assert_called_once_with( + tenant_id="tenant-id", + credits_required=42, + ) + + +def test_deduct_llm_quota_for_model_caps_trial_pool_when_usage_exceeds_remaining() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 3 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + engine = create_engine("sqlite:///:memory:") + TenantCreditPool.__table__.create(engine) + with engine.begin() as connection: + connection.execute( + TenantCreditPool.__table__.insert(), + { + "id": "trial-pool", + "tenant_id": "tenant-id", + "pool_type": ModelProviderQuotaType.TRIAL, + "quota_limit": 10, + "quota_used": 9, + }, + ) + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + with engine.connect() as connection: + quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == "trial-pool")) + + assert quota_used == 10 + + +def test_deduct_llm_quota_for_model_returns_for_unbounded_quota() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 42 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=-1, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_deduct_credits.assert_not_called() + + +def test_deduct_llm_quota_for_model_uses_credit_configuration() -> None: + usage = LLMUsage.empty_usage() + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.CREDITS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch.object(type(dify_config), "get_model_credits", return_value=9) as mock_get_model_credits, + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_get_model_credits.assert_called_once_with("gpt-4o") + mock_deduct_credits.assert_called_once_with( + tenant_id="tenant-id", + credits_required=9, + ) + + +def test_deduct_llm_quota_for_model_uses_single_charge_for_times_quota() -> None: + usage = LLMUsage.empty_usage() + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TIMES, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_deduct_credits.assert_called_once_with( + tenant_id="tenant-id", + credits_required=1, + ) + + +def test_deduct_llm_quota_for_model_uses_paid_billing_pool() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 5 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.PAID, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.PAID, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_deduct_credits.assert_called_once_with( + tenant_id="tenant-id", + credits_required=5, + pool_type="paid", + ) + + +def test_deduct_llm_quota_for_model_updates_free_quota_usage() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 3 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.FREE, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + engine = create_engine("sqlite:///:memory:") + Provider.__table__.create(engine) + with engine.begin() as connection: + connection.execute( + Provider.__table__.insert(), + [ + { + "id": "matching-provider", + "tenant_id": "tenant-id", + "provider_name": "openai", + "provider_type": ProviderType.SYSTEM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 100, + "quota_used": 10, + "is_valid": True, + }, + { + "id": "other-tenant", + "tenant_id": "other-tenant-id", + "provider_name": "openai", + "provider_type": ProviderType.SYSTEM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 100, + "quota_used": 20, + "is_valid": True, + }, + { + "id": "other-provider", + "tenant_id": "tenant-id", + "provider_name": "anthropic", + "provider_type": ProviderType.SYSTEM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 100, + "quota_used": 30, + "is_valid": True, + }, + { + "id": "custom-provider", + "tenant_id": "tenant-id", + "provider_name": "openai", + "provider_type": ProviderType.CUSTOM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 100, + "quota_used": 40, + "is_valid": True, + }, + ], + ) + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)), + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + with engine.connect() as connection: + quota_used_by_id = dict(connection.execute(select(Provider.id, Provider.quota_used)).all()) + + assert quota_used_by_id == { + "matching-provider": 13, + "other-tenant": 20, + "other-provider": 30, + "custom-provider": 40, + } + + with engine.begin() as connection: + connection.execute( + Provider.__table__.update().where(Provider.id == "matching-provider").values(quota_limit=13, quota_used=13) + ) + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)), + pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."), + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + with engine.connect() as connection: + exhausted_quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider")) + + assert exhausted_quota_used == 13 + + +def test_deduct_llm_quota_for_model_caps_free_quota_and_raises_when_usage_exceeds_remaining() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 3 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.FREE, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + engine = create_engine("sqlite:///:memory:") + Provider.__table__.create(engine) + with engine.begin() as connection: + connection.execute( + Provider.__table__.insert(), + { + "id": "matching-provider", + "tenant_id": "tenant-id", + "provider_name": "openai", + "provider_type": ProviderType.SYSTEM, + "quota_type": ProviderQuotaType.FREE, + "quota_limit": 15, + "quota_used": 13, + "is_valid": True, + }, + ) + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("core.app.llm.quota.db", SimpleNamespace(engine=engine)), + pytest.raises(QuotaExceededError, match="Model provider openai quota exceeded."), + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + with engine.connect() as connection: + quota_used = connection.scalar(select(Provider.quota_used).where(Provider.id == "matching-provider")) + + assert quota_used == 15 + + +def test_deduct_llm_quota_for_model_ignores_unknown_quota_type() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 2 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=SimpleNamespace( + current_quota_type="unexpected", + quota_configurations=[ + SimpleNamespace( + quota_type="unexpected", + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + ) + ], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, + patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_deduct_credits.assert_not_called() + mock_sessionmaker.assert_not_called() + + +def test_deduct_llm_quota_for_model_ignores_custom_provider_configuration() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 2 + provider_configuration = SimpleNamespace( + using_provider_type=ProviderType.CUSTOM, + system_configuration=SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[], + ), + ) + provider_manager = MagicMock() + provider_manager.get_configurations.return_value.get.return_value = provider_configuration + + with ( + patch("core.app.llm.quota.create_plugin_provider_manager", return_value=provider_manager), + patch("services.credit_pool_service.CreditPoolService.deduct_credits_capped") as mock_deduct_credits, + patch("core.app.llm.quota.sessionmaker") as mock_sessionmaker, + ): + deduct_llm_quota_for_model( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + mock_deduct_credits.assert_not_called() + mock_sessionmaker.assert_not_called() + + +def test_ensure_llm_quota_available_wrapper_warns_and_delegates() -> None: + model_instance = SimpleNamespace( + provider="openai", + model_name="gpt-4o", + provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")), + model_type_instance=SimpleNamespace(model_type=ModelType.LLM), + ) + + with ( + pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"), + patch("core.app.llm.quota.ensure_llm_quota_available_for_model") as mock_ensure, + ): + ensure_llm_quota_available(model_instance=model_instance) + + mock_ensure.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + ) + + +def test_ensure_llm_quota_available_wrapper_rejects_non_llm_model_instances() -> None: + model_instance = SimpleNamespace( + provider="openai", + model_name="gpt-4o", + provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace(tenant_id="tenant-id")), + model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING), + ) + + with ( + pytest.deprecated_call(match="ensure_llm_quota_available\\(model_instance=.*deprecated"), + pytest.raises(ValueError, match="only support LLM model instances"), + ): + ensure_llm_quota_available(model_instance=model_instance) + + +def test_deduct_llm_quota_wrapper_warns_and_delegates() -> None: + usage = LLMUsage.empty_usage() + usage.total_tokens = 7 + model_instance = SimpleNamespace( + provider="openai", + model_name="gpt-4o", + model_type_instance=SimpleNamespace(model_type=ModelType.LLM), + provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()), + ) + + with ( + pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"), + patch("core.app.llm.quota.deduct_llm_quota_for_model") as mock_deduct, + ): + deduct_llm_quota( + tenant_id="tenant-id", + model_instance=model_instance, + usage=usage, + ) + + mock_deduct.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=usage, + ) + + +def test_deduct_llm_quota_wrapper_rejects_non_llm_model_instances() -> None: + usage = LLMUsage.empty_usage() + model_instance = SimpleNamespace( + provider="openai", + model_name="gpt-4o", + model_type_instance=SimpleNamespace(model_type=ModelType.TEXT_EMBEDDING), + provider_model_bundle=SimpleNamespace(configuration=SimpleNamespace()), + ) + + with ( + pytest.deprecated_call(match="deduct_llm_quota\\(tenant_id=.*deprecated"), + pytest.raises(ValueError, match="only support LLM model instances"), + ): + deduct_llm_quota( + tenant_id="tenant-id", + model_instance=model_instance, + usage=usage, + ) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index 7c9f174129..addce649d5 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -8,9 +8,9 @@ from graphon.enums import BuiltinNodeTypes class DummyNode: - def __init__(self, *, node_id, config, graph_init_params, graph_runtime_state, **kwargs): + def __init__(self, *, node_id, data, graph_init_params, graph_runtime_state, **kwargs): self.id = node_id - self.config = config + self.data = data self.graph_init_params = graph_init_params self.graph_runtime_state = graph_runtime_state self.kwargs = kwargs diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index 23fe682017..9cefa97bef 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -7,6 +7,7 @@ import pytest from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.ops.ops_trace_manager import TraceTask, TraceTaskName from core.workflow.system_variables import SystemVariableKey, build_system_variables from graphon.entities import WorkflowNodeExecution from graphon.entities.pause_reason import SchedulingPause @@ -60,7 +61,10 @@ def _make_layer( workflow_execution_id="run-id", conversation_id="conv-id", ) - runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0) + runtime_state = GraphRuntimeState( + variable_pool=VariablePool.from_bootstrap(system_variables=system_variables), + start_at=0.0, + ) read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state) application_generate_entity = WorkflowAppGenerateEntity.model_construct( @@ -214,6 +218,59 @@ class TestWorkflowPersistenceLayer: assert exec_repo.saved[-1].status == WorkflowExecutionStatus.FAILED assert trace_tasks + def test_handle_graph_run_succeeded_enqueues_parent_trace_context(self, monkeypatch): + trace_tasks: list[TraceTask] = [] + trace_manager = SimpleNamespace(user_id="user", add_trace_task=lambda task: trace_tasks.append(task)) + layer, _, _, _ = _make_layer( + extras={ + "external_trace_id": "trace", + "parent_trace_context": { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + }, + }, + trace_manager=trace_manager, + ) + layer._handle_graph_run_started() + + captured: dict[str, object] = {} + + def fake_workflow_trace( + self: TraceTask, + *, + workflow_run_id: str | None, + conversation_id: str | None, + user_id: str | None, + total_tokens_override: int | None = None, + ): + captured["trace_type"] = self.trace_type + captured["external_trace_id"] = self.kwargs.get("external_trace_id") + captured["parent_trace_context"] = self.kwargs.get("parent_trace_context") + captured["workflow_run_id"] = workflow_run_id + return {"ok": True} + + monkeypatch.setattr(TraceTask, "workflow_trace", fake_workflow_trace) + + layer._handle_graph_run_succeeded(GraphRunSucceededEvent(outputs={"ok": True})) + + assert trace_tasks + trace_task = trace_tasks[0] + assert trace_task.trace_type == TraceTaskName.WORKFLOW_TRACE + assert trace_task.kwargs["external_trace_id"] == "trace" + assert trace_task.kwargs["parent_trace_context"] == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + + trace_task.execute() + + assert captured["trace_type"] == TraceTaskName.WORKFLOW_TRACE + assert captured["external_trace_id"] == "trace" + assert captured["parent_trace_context"] == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + def test_handle_graph_run_aborted_sets_status(self): layer, exec_repo, _, _ = _make_layer() layer._handle_graph_run_started() diff --git a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py index 4f39d38831..cee7d46083 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py @@ -34,20 +34,6 @@ class TestDatasourceFileManager: assert f"nonce={mock_urandom.return_value.hex()}" in signed_url assert "sign=" in signed_url - @patch("core.datasource.datasource_file_manager.time.time") - @patch("core.datasource.datasource_file_manager.os.urandom") - @patch("core.datasource.datasource_file_manager.dify_config") - def test_sign_file_empty_secret(self, mock_config, mock_urandom, mock_time): - # Setup - mock_config.FILES_URL = "http://localhost:5001" - mock_config.SECRET_KEY = None # Empty secret - mock_time.return_value = 1700000000 - mock_urandom.return_value = b"1234567890abcdef" - - # Execute - signed_url = DatasourceFileManager.sign_file("file_id", ".png") - assert "sign=" in signed_url - @patch("core.datasource.datasource_file_manager.time.time") @patch("core.datasource.datasource_file_manager.dify_config") def test_verify_file(self, mock_config, mock_time): @@ -76,25 +62,6 @@ class TestDatasourceFileManager: mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout) assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False - @patch("core.datasource.datasource_file_manager.time.time") - @patch("core.datasource.datasource_file_manager.dify_config") - def test_verify_file_empty_secret(self, mock_config, mock_time): - # Setup - mock_config.SECRET_KEY = "" # Empty string secret - mock_config.FILES_ACCESS_TIMEOUT = 300 - mock_time.return_value = 1700000000 - - datasource_file_id = "file_id_123" - timestamp = "1699999800" - nonce = "some_nonce" - - # Calculate with empty secret - data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" - sign = hmac.new(b"", data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True - @patch("core.datasource.datasource_file_manager.db") @patch("core.datasource.datasource_file_manager.storage") @patch("core.datasource.datasource_file_manager.uuid4") diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index a28143026f..1b714d6830 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -354,7 +354,8 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): with patch( @@ -379,7 +380,10 @@ def test_validate_provider_credentials_without_credential_id() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), + ): validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} @@ -426,23 +430,37 @@ def test_switch_preferred_provider_type_creates_record_when_missing() -> None: def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: configuration = _build_provider_configuration() - mock_factory = Mock() mock_model_type_instance = Mock() mock_schema = _build_ai_model("gpt-4o") - mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = configuration.provider mock_factory.get_model_schema.return_value = mock_schema + mock_assembly = Mock() + mock_assembly.model_runtime = Mock() + mock_assembly.model_provider_factory = mock_factory - with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", - return_value=mock_factory, - ) as mock_factory_builder: + with ( + patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=mock_assembly, + ) as mock_assembly_builder, + patch( + "core.entities.provider_configuration.create_model_type_instance", + return_value=mock_model_type_instance, + ) as mock_model_builder, + ): model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema - assert mock_factory_builder.call_count == 2 - mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) + assert mock_assembly_builder.call_count == 2 + mock_factory.get_provider_schema.assert_called_once_with(provider="openai") + mock_model_builder.assert_called_once_with( + runtime=mock_assembly.model_runtime, + provider_schema=configuration.provider, + model_type=ModelType.LLM, + ) mock_factory.get_model_schema.assert_called_once_with( provider="openai", model_type=ModelType.LLM, @@ -456,17 +474,21 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non bound_runtime = Mock() configuration.bind_model_runtime(bound_runtime) - mock_factory = Mock() mock_model_type_instance = Mock() mock_schema = _build_ai_model("gpt-4o") - mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = configuration.provider mock_factory.get_model_schema.return_value = mock_schema with ( patch( "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory ) as mock_factory_cls, - patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder, + patch("core.entities.provider_configuration.create_plugin_model_assembly") as mock_assembly_builder, + patch( + "core.entities.provider_configuration.create_model_type_instance", + return_value=mock_model_type_instance, + ) as mock_model_builder, ): model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) @@ -474,8 +496,14 @@ def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> Non assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema assert mock_factory_cls.call_count == 2 - mock_factory_cls.assert_called_with(model_runtime=bound_runtime) - mock_factory_builder.assert_not_called() + mock_factory_cls.assert_called_with(runtime=bound_runtime) + mock_assembly_builder.assert_not_called() + mock_factory.get_provider_schema.assert_called_once_with(provider="openai") + mock_model_builder.assert_called_once_with( + runtime=bound_runtime, + provider_schema=configuration.provider, + model_type=ModelType.LLM, + ) def test_get_provider_model_returns_none_when_model_not_found() -> None: @@ -504,7 +532,10 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N mock_factory = Mock() mock_factory.get_provider_schema.return_value = provider_schema - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), + ): all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) @@ -722,7 +753,8 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None: with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): validated = configuration.validate_provider_credentials( @@ -1069,7 +1101,8 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): @@ -1083,7 +1116,10 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory2 = Mock() mock_factory2.model_credentials_validate.return_value = {"region": "us"} - with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory2): + with patch( + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory2), + ): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, model="gpt-4o", @@ -1575,7 +1611,8 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing() with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_provider_credentials( @@ -1701,7 +1738,8 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No with _patched_session(mock_session): with patch( - "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + "core.entities.provider_configuration.create_plugin_model_assembly", + return_value=SimpleNamespace(model_runtime=Mock(), model_provider_factory=mock_factory), ): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( diff --git a/api/tests/unit_tests/core/helper/test_moderation.py b/api/tests/unit_tests/core/helper/test_moderation.py index a0dfa86d20..c33002329b 100644 --- a/api/tests/unit_tests/core/helper/test_moderation.py +++ b/api/tests/unit_tests/core/helper/test_moderation.py @@ -68,8 +68,8 @@ def test_check_moderation_returns_true_when_model_accepts_text(mocker: MockerFix mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") moderation_model = SimpleNamespace(invoke=lambda **invoke_kwargs: invoke_kwargs["text"] == "chunk") - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) - mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model) + mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly) assert ( check_moderation( @@ -91,7 +91,7 @@ def test_check_moderation_returns_true_when_text_is_empty(mocker: MockerFixture) provider_map={openai_provider: hosting_openai}, ), ) - factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_provider_factory") + factory_mock = mocker.patch("core.helper.moderation.create_plugin_model_assembly") choice_mock = mocker.patch("core.helper.moderation.secrets.choice") assert ( @@ -119,8 +119,8 @@ def test_check_moderation_returns_false_when_model_rejects_text(mocker: MockerFi mocker.patch("core.helper.moderation.secrets.choice", return_value="chunk") moderation_model = SimpleNamespace(invoke=lambda **_invoke_kwargs: False) - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: moderation_model) - mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: moderation_model) + mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly) assert ( check_moderation( @@ -147,8 +147,8 @@ def test_check_moderation_raises_bad_request_when_provider_call_fails(mocker: Mo failing_model = SimpleNamespace( invoke=lambda **_invoke_kwargs: (_ for _ in ()).throw(RuntimeError("boom")), ) - factory = SimpleNamespace(get_model_type_instance=lambda **_factory_kwargs: failing_model) - mocker.patch("core.helper.moderation.create_plugin_model_provider_factory", return_value=factory) + assembly = SimpleNamespace(create_model_type_instance=lambda **_factory_kwargs: failing_model) + mocker.patch("core.helper.moderation.create_plugin_model_assembly", return_value=assembly) with pytest.raises(InvokeBadRequestError, match="Rate limit exceeded, please try again later."): check_moderation( diff --git a/api/tests/unit_tests/core/helper/test_trace_id_helper.py b/api/tests/unit_tests/core/helper/test_trace_id_helper.py index 27bfe1af05..96e2d44730 100644 --- a/api/tests/unit_tests/core/helper/test_trace_id_helper.py +++ b/api/tests/unit_tests/core/helper/test_trace_id_helper.py @@ -1,6 +1,12 @@ import pytest -from core.helper.trace_id_helper import extract_external_trace_id_from_args, get_external_trace_id, is_valid_trace_id +from core.helper.trace_id_helper import ( + ParentTraceContext, + extract_external_trace_id_from_args, + extract_parent_trace_context_from_args, + get_external_trace_id, + is_valid_trace_id, +) class DummyRequest: @@ -84,3 +90,92 @@ class TestTraceIdHelper: def test_extract_external_trace_id_from_args(self, args, expected): """Test extraction of external_trace_id from args mapping""" assert extract_external_trace_id_from_args(args) == expected + + @pytest.mark.parametrize( + ("args", "expected"), + [ + ( + { + "parent_trace_context": { + "parent_workflow_run_id": "workflow-run-1", + "parent_node_execution_id": "node-execution-1", + } + }, + { + "parent_trace_context": ParentTraceContext( + parent_workflow_run_id="workflow-run-1", + parent_node_execution_id="node-execution-1", + ) + }, + ), + ( + { + "parent_trace_context": { + "parent_workflow_run_id": "workflow-run-1", + } + }, + {}, + ), + ( + { + "parent_trace_context": { + "parent_node_execution_id": "node-execution-1", + } + }, + {}, + ), + ( + { + "parent_trace_context": { + "parent_workflow_run_id": 123, + "parent_node_execution_id": "node-execution-1", + } + }, + {}, + ), + ( + { + "parent_trace_context": { + "parent_workflow_run_id": "workflow-run-1", + "parent_node_execution_id": None, + } + }, + {}, + ), + ({}, {}), + ], + ) + def test_extract_parent_trace_context_from_args(self, args, expected): + """Test extraction of parent_trace_context from args mapping""" + assert extract_parent_trace_context_from_args(args) == expected + + def test_extract_parent_trace_context_returns_typed_context(self): + """Parent trace context is parsed into a Pydantic value object.""" + result = extract_parent_trace_context_from_args( + { + "parent_trace_context": { + "parent_workflow_run_id": "workflow-run-1", + "parent_node_execution_id": "node-execution-1", + } + } + ) + + assert result == { + "parent_trace_context": ParentTraceContext( + parent_workflow_run_id="workflow-run-1", + parent_node_execution_id="node-execution-1", + ) + } + + def test_extract_parent_trace_context_rejects_incomplete_typed_context(self): + """Typed parent trace context follows the same completeness rule as raw mappings.""" + result = extract_parent_trace_context_from_args( + { + "parent_trace_context": ParentTraceContext( + parent_workflow_run_id="workflow-run-1", + parent_node_execution_id=None, + ) + } + ) + + assert result == {} diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index f459250b8e..72c24bda96 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -198,6 +198,48 @@ class TestBuildPromptMessageWithFiles: assert isinstance(result.content[-1], TextPromptMessageContent) assert result.content[-1].data == "user text" + def test_replay_does_not_pass_config_to_file_factory(self): + """Replay contract: history files were validated on upload, so this + path must not forward a FileUploadConfig. The factory's signature + no longer accepts ``config``; this test guards against a future + regression that re-introduces it.""" + conv = _make_conversation(AppMode.CHAT) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_file_extra_config = MagicMock() + mock_file_extra_config.image_config = None + + real_image_content = ImagePromptMessageContent( + url="http://example.com/img.png", format="png", mime_type="image/png" + ) + mock_app_record = MagicMock() + mock_app_record.tenant_id = "tenant-1" + + with ( + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ), + patch( + "core.memory.token_buffer_memory.file_factory.build_from_message_file", + return_value=MagicMock(), + ) as mock_build, + patch( + "core.memory.token_buffer_memory.file_manager.to_prompt_message_content", + return_value=real_image_content, + ), + ): + mem._build_prompt_message_with_files( + message_files=[MagicMock()], + text_content="user text", + message=_make_message(), + app_record=mock_app_record, + is_user_message=True, + ) + + mock_build.assert_called_once() + assert "config" not in mock_build.call_args.kwargs + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) def test_chat_mode_with_files_assistant_message(self, mode): """When files are present, returns AssistantPromptMessage with list content.""" diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py index c4fd970562..2b51dc8182 100644 --- a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -2,6 +2,7 @@ from unittest.mock import Mock import pytest +from core.plugin.impl.model_runtime_factory import create_model_type_instance from graphon.model_runtime.entities.common_entities import I18nObject from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ( @@ -73,7 +74,7 @@ def test_model_provider_factory_resolves_runtime_provider_name() -> None: supported_model_types=[ModelType.LLM], configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], ) - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider])) provider_schema = factory.get_model_provider("openai") @@ -98,7 +99,7 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], ), ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) provider_schema = factory.get_model_provider("openai") @@ -107,8 +108,8 @@ def test_model_provider_factory_resolves_canonical_short_name_independent_of_pro def test_model_provider_factory_requires_runtime() -> None: - with pytest.raises(ValueError, match="model_runtime is required"): - ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] + with pytest.raises(ValueError, match="runtime is required"): + ModelProviderFactory(runtime=None) # type: ignore[arg-type] def test_model_provider_factory_get_providers_returns_runtime_providers() -> None: @@ -119,7 +120,7 @@ def test_model_provider_factory_get_providers_returns_runtime_providers() -> Non supported_model_types=[ModelType.LLM], ) ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) result = factory.get_providers() @@ -133,7 +134,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup provider_name="openai", supported_model_types=[ModelType.LLM], ) - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + factory = ModelProviderFactory(runtime=_FakeModelRuntime([provider])) result = factory.get_provider_schema("openai") @@ -142,7 +143,7 @@ def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup def test_model_provider_factory_raises_for_unknown_provider() -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( + runtime=_FakeModelRuntime( [ _build_provider( provider="langgenius/openai/openai", @@ -172,7 +173,7 @@ def test_model_provider_factory_get_models_filters_provider_and_model_type() -> models=[_build_model("rerank-v3", ModelType.RERANK)], ), ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) results = factory.get_models(provider="openai", model_type=ModelType.LLM) @@ -196,7 +197,7 @@ def test_model_provider_factory_get_models_skips_providers_without_requested_mod models=[_build_model("eleven_multilingual_v2", ModelType.TTS)], ), ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) results = factory.get_models(model_type=ModelType.TTS) @@ -214,7 +215,7 @@ def test_model_provider_factory_get_models_without_model_type_keeps_all_provider models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], ) ] - factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + factory = ModelProviderFactory(runtime=_FakeModelRuntime(providers)) results = factory.get_models(provider="openai") @@ -242,7 +243,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None: ) ] ) - factory = ModelProviderFactory(model_runtime=runtime) + factory = ModelProviderFactory(runtime=runtime) filtered = factory.provider_credentials_validate( provider="openai", @@ -258,7 +259,7 @@ def test_model_provider_factory_validates_provider_credentials() -> None: def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( + runtime=_FakeModelRuntime( [ _build_provider( provider="langgenius/openai/openai", @@ -294,7 +295,7 @@ def test_model_provider_factory_validates_model_credentials() -> None: ) ] ) - factory = ModelProviderFactory(model_runtime=runtime) + factory = ModelProviderFactory(runtime=runtime) filtered = factory.model_credentials_validate( provider="openai", @@ -314,7 +315,7 @@ def test_model_provider_factory_validates_model_credentials() -> None: def test_model_provider_factory_model_credentials_validate_requires_schema() -> None: factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( + runtime=_FakeModelRuntime( [ _build_provider( provider="langgenius/openai/openai", @@ -346,7 +347,7 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider ) runtime.get_model_schema.return_value = "schema" runtime.get_provider_icon.return_value = (b"icon", "image/png") - factory = ModelProviderFactory(model_runtime=runtime) + factory = ModelProviderFactory(runtime=runtime) assert ( factory.get_model_schema( @@ -382,39 +383,43 @@ def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider (ModelType.TTS, TTSModel), ], ) -def test_model_provider_factory_builds_model_type_instances( +def test_create_model_type_instance_builds_model_wrappers( model_type: ModelType, expected_type: type[object], ) -> None: - factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( - [ - _build_provider( - provider="langgenius/openai/openai", - provider_name="openai", - supported_model_types=[model_type], - ) - ] - ) + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[model_type], + ) + ] ) - instance = factory.get_model_type_instance("openai", model_type) + instance = create_model_type_instance( + runtime=runtime, + provider_schema=runtime.fetch_model_providers()[0], + model_type=model_type, + ) assert isinstance(instance, expected_type) -def test_model_provider_factory_rejects_unsupported_model_type() -> None: - factory = ModelProviderFactory( - model_runtime=_FakeModelRuntime( - [ - _build_provider( - provider="langgenius/openai/openai", - provider_name="openai", - supported_model_types=[ModelType.LLM], - ) - ] - ) +def test_create_model_type_instance_rejects_unsupported_model_type() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] ) with pytest.raises(ValueError, match="Unsupported model type: unsupported"): - factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type] + create_model_type_instance( + runtime=runtime, + provider_schema=runtime.fetch_model_providers()[0], + model_type="unsupported", # type: ignore[arg-type] + ) diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py index 7491e79f30..52da674f06 100644 --- a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -31,6 +31,6 @@ def test_plugin_model_assembly_reuses_single_runtime_across_views(): assert assembly.model_manager is model_manager mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") - mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime) + mock_provider_factory_cls.assert_called_once_with(runtime=runtime) mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py index 88bf555594..b1ecaa4ead 100644 --- a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -3,7 +3,7 @@ import datetime import uuid from types import SimpleNamespace -from unittest.mock import Mock, sentinel +from unittest.mock import Mock, patch, sentinel import pytest @@ -13,6 +13,8 @@ from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity @@ -146,7 +148,31 @@ class TestPluginModelRuntime: def test_invoke_llm_resolves_plugin_fields(self) -> None: client = Mock(spec=PluginModelClient) - client.invoke_llm.return_value = sentinel.result + usage = LLMUsage.empty_usage() + client.invoke_llm.return_value = iter( + [ + LLMResultChunk( + model="gpt-4o-mini", + prompt_messages=[], + system_fingerprint="fp-plugin", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content="plugin "), + ), + ), + LLMResultChunk( + model="gpt-4o-mini", + prompt_messages=[], + system_fingerprint="fp-plugin", + delta=LLMResultChunkDelta( + index=1, + message=AssistantPromptMessage(content="response"), + usage=usage, + finish_reason="stop", + ), + ), + ] + ) runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) result = runtime.invoke_llm( @@ -160,7 +186,11 @@ class TestPluginModelRuntime: stream=False, ) - assert result is sentinel.result + assert result.model == "gpt-4o-mini" + assert result.prompt_messages == [] + assert result.message.content == "plugin response" + assert result.usage == usage + assert result.system_fingerprint == "fp-plugin" client.invoke_llm.assert_called_once_with( tenant_id="tenant", user_id="user", @@ -175,6 +205,38 @@ class TestPluginModelRuntime: stream=False, ) + def test_invoke_llm_returns_plugin_stream_directly(self) -> None: + client = Mock(spec=PluginModelClient) + stream_result = iter([]) + client.invoke_llm.return_value = stream_result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.invoke_llm( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=("END",), + stream=True, + ) + + assert result is stream_result + client.invoke_llm.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=["END"], + stream=True, + ) + def test_invoke_llm_rejects_per_call_user_override(self) -> None: client = Mock(spec=PluginModelClient) client.invoke_llm.return_value = sentinel.result @@ -267,6 +329,129 @@ def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: client.get_model_schema.assert_not_called() +def test_structured_output_adapter_invokes_bound_runtime_streaming() -> None: + runtime = Mock() + runtime.invoke_llm.return_value = sentinel.stream_result + adapter = model_runtime_module._PluginStructuredOutputModelInstance( + runtime=runtime, + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + tool = Mock() + + result = adapter.invoke_llm( + prompt_messages=[], + model_parameters=None, + tools=[tool], + stop=["END"], + stream=True, + callbacks=sentinel.callbacks, + ) + + assert result is sentinel.stream_result + runtime.invoke_llm.assert_called_once_with( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={}, + prompt_messages=[], + tools=[tool], + stop=["END"], + stream=True, + ) + + +def test_structured_output_adapter_invokes_bound_runtime_non_streaming() -> None: + runtime = Mock() + runtime.invoke_llm.return_value = sentinel.result + adapter = model_runtime_module._PluginStructuredOutputModelInstance( + runtime=runtime, + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + result = adapter.invoke_llm( + prompt_messages=[], + model_parameters={"temperature": 0}, + tools=None, + stop=None, + stream=False, + ) + + assert result is sentinel.result + runtime.invoke_llm.assert_called_once_with( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + +def test_invoke_llm_with_structured_output_delegates_with_bound_adapter() -> None: + client = Mock(spec=PluginModelClient) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + schema = _build_model_schema() + runtime.get_model_schema = Mock(return_value=schema) # type: ignore[method-assign] + + with patch.object( + model_runtime_module, + "invoke_llm_with_structured_output_helper", + return_value=sentinel.structured_result, + ) as mock_helper: + result = runtime.invoke_llm_with_structured_output( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + json_schema={"type": "object"}, + model_parameters={"temperature": 0}, + prompt_messages=[], + stop=("END",), + stream=False, + ) + + assert result is sentinel.structured_result + runtime.get_model_schema.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + helper_kwargs = mock_helper.call_args.kwargs + assert helper_kwargs["provider"] == "langgenius/openai/openai" + assert helper_kwargs["model_schema"] == schema + assert helper_kwargs["json_schema"] == {"type": "object"} + assert helper_kwargs["model_parameters"] == {"temperature": 0} + assert helper_kwargs["prompt_messages"] == [] + assert helper_kwargs["tools"] is None + assert helper_kwargs["stop"] == ["END"] + assert helper_kwargs["stream"] is False + assert isinstance(helper_kwargs["model_instance"], model_runtime_module._PluginStructuredOutputModelInstance) + + +def test_invoke_llm_with_structured_output_raises_when_model_schema_is_missing() -> None: + client = Mock(spec=PluginModelClient) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + runtime.get_model_schema = Mock(return_value=None) # type: ignore[method-assign] + + with pytest.raises(ValueError, match="Model schema not found for gpt-4o-mini"): + runtime.invoke_llm_with_structured_output( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + json_schema={"type": "object"}, + model_parameters={}, + prompt_messages=[], + stop=None, + stream=False, + ) + + def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None: client = Mock(spec=PluginModelClient) schema = _build_model_schema() diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index d38213dd89..f72351ffa2 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -1038,7 +1038,7 @@ class TestRetrievalServiceInternals: assert any(doc.metadata["doc_id"] == "processed-doc" for doc in all_documents) processor_instance.invoke.assert_called_once() - @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + @patch("core.rag.datasource.retrieval_service.sign_upload_file_preview_url", return_value="signed://file") def test_get_segment_attachment_info_success(self, mock_sign): upload_file = SimpleNamespace( id="upload-1", @@ -1118,7 +1118,7 @@ class TestRetrievalServiceInternals: assert result == [] - @patch("core.rag.datasource.retrieval_service.sign_upload_file", return_value="signed://file") + @patch("core.rag.datasource.retrieval_service.sign_upload_file_preview_url", return_value="signed://file") def test_get_segment_attachment_infos_success(self, mock_sign): upload_file_1 = SimpleNamespace( id="upload-1", diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index b556ddf528..9334ad9b2f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -4562,7 +4562,7 @@ class TestRetrieveCoverage: "core.rag.retrieval.dataset_retrieval.RetrievalService.format_retrieval_documents", return_value=[record], ), - patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), + patch("core.rag.retrieval.dataset_retrieval.sign_upload_file_preview_url", return_value="https://signed"), patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, ): bound_model_instance = Mock() diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 02f12fb3b4..e84fcba3d9 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -289,7 +289,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc result = manager.get_default_model("tenant-id", ModelType.LLM) - mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime) assert result is not None assert result.model == "gpt-4" assert result.provider.provider == "openai" @@ -316,7 +316,7 @@ def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mock result = manager.get_configurations("tenant-id") expected_alias = str(ModelProviderID("openai")) - mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime) assert result.tenant_id == "tenant-id" assert expected_alias in provider_records assert expected_alias in provider_model_records @@ -402,7 +402,7 @@ def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerF assert first is second mock_get_all_providers.assert_called_once_with("tenant-id") - mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_factory_cls.assert_called_once_with(runtime=manager._model_runtime) mock_provider_configuration.assert_called_once() provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py index 353988d7a6..a75fdee908 100644 --- a/api/tests/unit_tests/core/tools/test_signature.py +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -9,7 +9,7 @@ import pytest from core.tools.signature import ( get_signed_file_url_for_plugin, sign_tool_file, - sign_upload_file, + sign_upload_file_preview_url, verify_plugin_file_signature, verify_tool_file_signature, ) @@ -89,32 +89,32 @@ def test_verify_tool_file_signature_rejects_expired_signature(monkeypatch: pytes assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is False -def test_sign_upload_file_prefers_internal_url(monkeypatch: pytest.MonkeyPatch) -> None: +def test_sign_upload_file_preview_url_uses_files_url(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x03" * 16) monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") - url = sign_upload_file("upload-id", ".png") + url = sign_upload_file_preview_url("upload-id", ".png") parsed = urlparse(url) query = parse_qs(parsed.query) - assert parsed.netloc == "internal.example.com" + assert parsed.netloc == "files.example.com" assert parsed.path == "/files/upload-id/image-preview" assert query["timestamp"][0] assert query["nonce"][0] assert query["sign"][0] -def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatch) -> None: +def test_sign_upload_file_preview_url_ignores_internal_files_url(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x05" * 16) monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") - monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") - url = sign_upload_file("upload-id", ".png") + url = sign_upload_file_preview_url("upload-id", ".png") parsed = urlparse(url) query = parse_qs(parsed.query) diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 72a73dd936..6c563b0912 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -147,6 +147,142 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke assert call_kwargs["pause_state_config"] is None +def test_workflow_tool_passes_parent_trace_context_from_runtime(monkeypatch: pytest.MonkeyPatch): + """Ensure nested workflow runtime metadata is forwarded as parent trace context.""" + tool = _build_tool() + tool.set_parent_trace_context( + parent_workflow_run_id="outer-workflow-run-1", + parent_node_execution_id="outer-node-execution-1", + ) + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert call_kwargs["args"]["parent_trace_context"].model_dump() == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + + +def test_workflow_tool_keeps_user_inputs_named_like_trace_runtime_keys(monkeypatch: pytest.MonkeyPatch): + """Ensure private trace context does not overwrite same-named workflow inputs.""" + tool = _build_tool() + tool.entity.parameters = [ + ToolParameter.get_simple_instance( + name="outer_workflow_run_id", + llm_description="User workflow input", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ), + ToolParameter.get_simple_instance( + name="outer_node_execution_id", + llm_description="User node input", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ), + ] + tool.set_parent_trace_context( + parent_workflow_run_id="outer-workflow-run-1", + parent_node_execution_id="outer-node-execution-1", + ) + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list( + tool.invoke( + "test_user", + { + "outer_workflow_run_id": "user-workflow-input", + "outer_node_execution_id": "user-node-input", + }, + ) + ) + + call_kwargs = generate_mock.call_args.kwargs + assert call_kwargs["args"]["inputs"]["outer_workflow_run_id"] == "user-workflow-input" + assert call_kwargs["args"]["inputs"]["outer_node_execution_id"] == "user-node-input" + assert call_kwargs["args"]["parent_trace_context"].model_dump() == { + "parent_workflow_run_id": "outer-workflow-run-1", + "parent_node_execution_id": "outer-node-execution-1", + } + + +def test_workflow_tool_can_clear_parent_trace_context(monkeypatch: pytest.MonkeyPatch): + """Ensure reused WorkflowTool instances do not keep stale parent trace context.""" + tool = _build_tool() + tool.set_parent_trace_context( + parent_workflow_run_id="outer-workflow-run-1", + parent_node_execution_id="outer-node-execution-1", + ) + tool.clear_parent_trace_context() + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert "parent_trace_context" not in call_kwargs["args"] + + +@pytest.mark.parametrize( + "runtime_parameters", + [ + {}, + {"outer_workflow_run_id": "outer-workflow-run-1"}, + {"outer_node_execution_id": "outer-node-execution-1"}, + {"outer_workflow_run_id": None, "outer_node_execution_id": None}, + ], +) +def test_workflow_tool_omits_parent_trace_context_when_runtime_is_incomplete( + monkeypatch: pytest.MonkeyPatch, + runtime_parameters: dict[str, Any], +): + """Ensure incomplete runtime metadata does not leak parent trace context into generator args.""" + tool = _build_tool() + tool.runtime.runtime_parameters = runtime_parameters + + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + + generate_mock = MagicMock(return_value={"data": {}}) + monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock) + monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) + + list(tool.invoke("test_user", {})) + + call_kwargs = generate_mock.call_args.kwargs + assert "parent_trace_context" not in call_kwargs["args"] + + def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should generate variable messages when there are outputs""" tool = _build_tool() diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index baa2ac2dc7..009899a92d 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -233,8 +233,6 @@ class TestSegmentTypeAdditionalMethods: assert SegmentType.GROUP.is_valid([StringSegment(value="b")]) is True assert SegmentType.GROUP.is_valid(["not-segment"]) is False - def test_unreachable_assertion_branch(self, monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(SegmentType, "is_array_type", lambda self: False) - - with pytest.raises(AssertionError, match="unreachable"): - SegmentType.ARRAY_STRING.is_valid(["a"]) + def test_unreachable_assertion_branch(self): + with pytest.raises(AssertionError, match="Expected code to be unreachable"): + SegmentType.is_valid("not-a-segment-type", None) # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 5d6667257f..12c7f8113c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -1,12 +1,11 @@ +import logging import threading from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from core.model_manager import ModelInstance from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from graphon.graph_engine.entities.commands import CommandType from graphon.graph_events import NodeRunSucceededEvent @@ -14,17 +13,7 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import NodeRunResult -def _build_dify_context() -> DifyRunContext: - return DifyRunContext( - tenant_id="tenant-id", - app_id="app-id", - user_id="user-id", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ) - - -def _build_succeeded_event() -> NodeRunSucceededEvent: +def _build_succeeded_event(*, provider: str = "openai", model_name: str = "gpt-4o") -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="execution-id", node_id="llm-node-id", @@ -32,113 +21,162 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: start_at=datetime.now(), node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={"question": "hello"}, + inputs={ + "question": "hello", + "model_provider": provider, + "model_name": model_name, + }, llm_usage=LLMUsage.empty_usage(), ), ) -def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: - raw_model_instance = ModelInstance.__new__(ModelInstance) - return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance +def _build_public_model_identity(*, provider: str = "openai", model_name: str = "gpt-4o") -> SimpleNamespace: + return SimpleNamespace(provider=provider, name=model_name) + + +def _build_node_data(*, model: SimpleNamespace | None = None) -> SimpleNamespace: + return SimpleNamespace( + error_strategy=None, + retry_config=SimpleNamespace(retry_enabled=False), + model=model, + ) + + +def _build_node(*, node_type: BuiltinNodeTypes = BuiltinNodeTypes.LLM) -> MagicMock: + node = MagicMock() + node.id = "node-id" + node.execution_id = "execution-id" + node.node_type = node_type + node.node_data = _build_node_data(model=_build_public_model_identity()) + node.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model") + return node + + +class _RunnableQuotaNode: + id = "node-id" + execution_id = "execution-id" + node_type = BuiltinNodeTypes.LLM + title = "LLM node" + + def __init__(self, *, stop_event: threading.Event, node_data: SimpleNamespace | None = None) -> None: + self.node_data = node_data or _build_node_data(model=_build_public_model_identity()) + self.graph_runtime_state = SimpleNamespace(stop_event=stop_event) + self.original_run_called = False + + def _run(self) -> NodeRunResult: + self.original_run_called = True + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) def test_deduct_quota_called_for_successful_llm_node() -> None: - layer = LLMQuotaLayer() - node = MagicMock() - node.id = "llm-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.LLM - node.tenant_id = "tenant-id" - node.require_run_context_value.return_value = _build_dify_context() - node.model_instance, raw_model_instance = _build_wrapped_model_instance() - + layer = LLMQuotaLayer(tenant_id="tenant-id") + node = _build_node(node_type=BuiltinNodeTypes.LLM) result_event = _build_succeeded_event() - with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + + with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct: layer.on_node_run_end(node=node, error=None, result_event=result_event) mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=raw_model_instance, + provider="openai", + model="gpt-4o", usage=result_event.node_run_result.llm_usage, ) def test_deduct_quota_called_for_question_classifier_node() -> None: - layer = LLMQuotaLayer() - node = MagicMock() - node.id = "question-classifier-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER - node.tenant_id = "tenant-id" - node.require_run_context_value.return_value = _build_dify_context() - node.model_instance, raw_model_instance = _build_wrapped_model_instance() + layer = LLMQuotaLayer(tenant_id="tenant-id") + node = _build_node(node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER) + result_event = _build_succeeded_event(provider="anthropic", model_name="claude-3-7-sonnet") - result_event = _build_succeeded_event() - with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct: layer.on_node_run_end(node=node, error=None, result_event=result_event) mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=raw_model_instance, + provider="anthropic", + model="claude-3-7-sonnet", usage=result_event.node_run_result.llm_usage, ) def test_non_llm_node_is_ignored() -> None: - layer = LLMQuotaLayer() - node = MagicMock() - node.id = "start-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.START - node.tenant_id = "tenant-id" - node.require_run_context_value.return_value = _build_dify_context() - node._model_instance = object() - + layer = LLMQuotaLayer(tenant_id="tenant-id") + node = _build_node(node_type=BuiltinNodeTypes.START) result_event = _build_succeeded_event() - with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: + + with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct: layer.on_node_run_end(node=node, error=None, result_event=result_event) mock_deduct.assert_not_called() -def test_quota_error_is_handled_in_layer() -> None: - layer = LLMQuotaLayer() - node = MagicMock() - node.id = "llm-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.LLM - node.tenant_id = "tenant-id" - node.require_run_context_value.return_value = _build_dify_context() - node.model_instance = object() +def test_precheck_ignores_non_quota_node() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + node = _build_node(node_type=BuiltinNodeTypes.START) - result_event = _build_succeeded_event() - with patch( - "core.app.workflow.layers.llm_quota.deduct_llm_quota", - autospec=True, - side_effect=ValueError("quota exceeded"), - ): - layer.on_node_run_end(node=node, error=None, result_event=result_event) + with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check: + layer.on_node_run_start(node) + + mock_check.assert_not_called() -def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: - layer = LLMQuotaLayer() +def test_quota_error_is_handled_in_layer(caplog) -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") stop_event = threading.Event() layer.command_channel = MagicMock() - node = MagicMock() - node.id = "llm-node-id" - node.execution_id = "execution-id" - node.node_type = BuiltinNodeTypes.LLM - node.tenant_id = "tenant-id" - node.require_run_context_value.return_value = _build_dify_context() - node.model_instance, _ = _build_wrapped_model_instance() + node = _build_node(node_type=BuiltinNodeTypes.LLM) + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.stop_event = stop_event + result_event = _build_succeeded_event() + + with ( + caplog.at_level(logging.ERROR, logger="core.app.workflow.layers.llm_quota"), + patch( + "core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", + autospec=True, + side_effect=ValueError("quota exceeded"), + ) as mock_deduct, + ): + layer.on_node_run_end(node=node, error=None, result_event=result_event) + + mock_deduct.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + usage=result_event.node_run_result.llm_usage, + ) + assert "LLM quota deduction failed, node_id=node-id" in caplog.text + assert not stop_event.is_set() + layer.command_channel.send_command.assert_not_called() + + +def test_send_abort_command_is_noop_without_channel_or_after_abort() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + + layer._send_abort_command(reason="no channel") + + layer.command_channel = MagicMock() + layer._abort_sent = True + layer._send_abort_command(reason="already aborted") + + layer.command_channel.send_command.assert_not_called() + + +def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = _build_node(node_type=BuiltinNodeTypes.LLM) node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event result_event = _build_succeeded_event() with patch( - "core.app.workflow.layers.llm_quota.deduct_llm_quota", + "core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True, side_effect=QuotaExceededError("No credits remaining"), ): @@ -152,19 +190,16 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: def test_quota_precheck_failure_aborts_workflow_immediately() -> None: - layer = LLMQuotaLayer() + layer = LLMQuotaLayer(tenant_id="tenant-id") stop_event = threading.Event() layer.command_channel = MagicMock() - node = MagicMock() - node.id = "llm-node-id" - node.node_type = BuiltinNodeTypes.LLM - node.model_instance, _ = _build_wrapped_model_instance() + node = _build_node(node_type=BuiltinNodeTypes.LLM) node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event with patch( - "core.app.workflow.layers.llm_quota.ensure_llm_quota_available", + "core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True, side_effect=QuotaExceededError("Model provider openai quota exceeded."), ): @@ -177,21 +212,140 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: assert abort_command.reason == "Model provider openai quota exceeded." -def test_quota_precheck_passes_without_abort() -> None: - layer = LLMQuotaLayer() +def test_quota_precheck_failure_blocks_current_node_run() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") stop_event = threading.Event() layer.command_channel = MagicMock() - node = MagicMock() - node.id = "llm-node-id" - node.node_type = BuiltinNodeTypes.LLM - node.model_instance, raw_model_instance = _build_wrapped_model_instance() + node = _RunnableQuotaNode(stop_event=stop_event) + + with patch( + "core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", + autospec=True, + side_effect=QuotaExceededError("Model provider openai quota exceeded."), + ): + layer.on_node_run_start(node) + + result = node._run() + assert not node.original_run_called + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "Model provider openai quota exceeded." + assert result.error_type == QuotaExceededError.__name__ + + +def test_missing_model_identity_blocks_current_node_run() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = _RunnableQuotaNode(stop_event=stop_event, node_data=_build_node_data()) + + with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check: + layer.on_node_run_start(node) + + result = node._run() + assert not node.original_run_called + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert result.error == "LLM quota check requires public node model identity before execution." + assert result.error_type == "LLMQuotaIdentityError" + mock_check.assert_not_called() + + +def test_quota_precheck_passes_without_abort() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = _build_node(node_type=BuiltinNodeTypes.LLM) node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event - with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check: + with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check: layer.on_node_run_start(node) assert not stop_event.is_set() - mock_check.assert_called_once_with(model_instance=raw_model_instance) + mock_check.assert_called_once_with( + tenant_id="tenant-id", + provider="openai", + model="gpt-4o", + ) layer.command_channel.send_command.assert_not_called() + + +def test_precheck_reads_model_identity_from_data_when_node_data_is_absent() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + node = SimpleNamespace( + id="node-id", + node_type=BuiltinNodeTypes.LLM, + data=_build_node_data(model=_build_public_model_identity(provider="anthropic", model_name="claude")), + ) + + with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check: + layer.on_node_run_start(node) + + mock_check.assert_called_once_with( + tenant_id="tenant-id", + provider="anthropic", + model="claude", + ) + + +def test_precheck_rejects_invalid_public_model_identity() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = _build_node(node_type=BuiltinNodeTypes.LLM) + node.node_data = _build_node_data(model=_build_public_model_identity(provider="", model_name="gpt-4o")) + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.stop_event = stop_event + + with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check: + layer.on_node_run_start(node) + + assert stop_event.is_set() + mock_check.assert_not_called() + layer.command_channel.send_command.assert_called_once() + + +def test_precheck_requires_public_node_model_config() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = _build_node(node_type=BuiltinNodeTypes.LLM) + node.node_data = _build_node_data() + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.stop_event = stop_event + + with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model", autospec=True) as mock_check: + layer.on_node_run_start(node) + + assert stop_event.is_set() + mock_check.assert_not_called() + layer.command_channel.send_command.assert_called_once() + abort_command = layer.command_channel.send_command.call_args.args[0] + assert abort_command.command_type == CommandType.ABORT + assert abort_command.reason == "LLM quota check requires public node model identity before execution." + + +def test_deduction_requires_public_event_model_identity() -> None: + layer = LLMQuotaLayer(tenant_id="tenant-id") + stop_event = threading.Event() + layer.command_channel = MagicMock() + + node = _build_node(node_type=BuiltinNodeTypes.LLM) + node.graph_runtime_state = MagicMock() + node.graph_runtime_state.stop_event = stop_event + result_event = _build_succeeded_event() + result_event.node_run_result.inputs = {"question": "hello"} + + with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model", autospec=True) as mock_deduct: + layer.on_node_run_end(node=node, error=None, result_event=result_event) + + assert stop_event.is_set() + mock_deduct.assert_not_called() + layer.command_channel.send_command.assert_called_once() + abort_command = layer.command_channel.send_command.call_args.args[0] + assert abort_command.command_type == CommandType.ABORT + assert abort_command.reason == "LLM quota deduction requires model identity in the node result event." diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 9f3e3b00b9..c721c7b0eb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -96,7 +96,7 @@ class MockNodeFactory(DifyNodeFactory): if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( node_id=node_id, - config=resolved_node_data, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -106,7 +106,7 @@ class MockNodeFactory(DifyNodeFactory): elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( node_id=node_id, - config=resolved_node_data, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -122,7 +122,7 @@ class MockNodeFactory(DifyNodeFactory): }: mock_instance = mock_class( node_id=node_id, - config=resolved_node_data, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -132,7 +132,7 @@ class MockNodeFactory(DifyNodeFactory): else: mock_instance = mock_class( node_id=node_id, - config=resolved_node_data, + data=resolved_node_data, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index f9819c47ec..e0eb4e7361 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -56,7 +56,7 @@ class MockNodeMixin: def __init__( self, node_id: str, - config: Any, + data: Any, *, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", @@ -98,7 +98,7 @@ class MockNodeMixin: super().__init__( node_id=node_id, - config=config, + data=data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, **kwargs, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index ad8c0b2a04..60c6d0aaa6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -122,7 +122,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -151,7 +151,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} start_node = StartNode( node_id=start_config["id"], - config=StartNodeData(title="Start", variables=[]), + data=StartNodeData(title="Start", variables=[]), graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) @@ -173,7 +173,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_a_config = {"id": "human_a", "data": human_data.model_dump()} human_a = HumanInputNode( node_id=human_a_config["id"], - config=human_data, + data=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -183,7 +183,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor human_b_config = {"id": "human_b", "data": human_data.model_dump()} human_b = HumanInputNode( node_id=human_b_config["id"], - config=human_data, + data=human_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, @@ -205,7 +205,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor end_config = {"id": "end", "data": end_data.model_dump()} end_node = EndNode( node_id=end_config["id"], - config=end_data, + data=end_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index ae9dae0646..2603e29be6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -1,41 +1,36 @@ import time import uuid -from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory from core.workflow.system_variables import build_system_variables -from extensions.ext_database import db from graphon.enums import WorkflowNodeExecutionStatus -from graphon.graph import Graph from graphon.nodes.answer.answer_node import AnswerNode from graphon.nodes.answer.entities import AnswerNodeData from graphon.runtime import GraphRuntimeState, VariablePool from tests.workflow_test_utils import build_test_graph_init_params -def test_execute_answer(): +def _build_variable_pool() -> VariablePool: + return VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="aaa", files=[]), + user_inputs={}, + ) + + +def _build_answer_node(*, answer: str, variable_pool: VariablePool) -> AnswerNode: graph_config = { - "edges": [ - { - "id": "start-source-answer-target", - "source": "start", - "target": "answer", - }, - ], + "edges": [], "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, { "data": { - "title": "123", + "title": "Answer", "type": "answer", - "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + "answer": answer, }, "id": "answer", - }, + } ], } - init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, @@ -46,42 +41,31 @@ def test_execute_answer(): invoke_from=InvokeFrom.DEBUGGER, call_depth=0, ) - - # construct variable pool - variable_pool = VariablePool( - system_variables=build_system_variables(user_id="aaa", files=[]), - user_inputs={}, - environment_variables=[], - conversation_variables=[], + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), ) - variable_pool.add(["start", "weather"], "sunny") - variable_pool.add(["llm", "text"], "You are a helpful AI.") - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # create node factory - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node = AnswerNode( + return AnswerNode( node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=AnswerNodeData( - title="123", + data=AnswerNodeData( + title="Answer", type="answer", - answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + answer=answer, ), ) - # Mock db.session.close() - db.session.close = MagicMock() - # execute node +def test_execute_answer_renders_variable_selectors() -> None: + variable_pool = _build_variable_pool() + variable_pool.add(["start", "weather"], "sunny") + variable_pool.add(["llm", "text"], "You are a helpful AI.") + node = _build_answer_node( + answer="Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.", + variable_pool=variable_pool, + ) + result = node._run() assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED @@ -89,36 +73,11 @@ def test_execute_answer(): def test_execute_answer_renders_structured_output_object_as_json() -> None: - init_params = build_test_graph_init_params( - workflow_id="1", - graph_config={"nodes": [], "edges": []}, - tenant_id="1", - app_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=build_system_variables(user_id="aaa", files=[]), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool() variable_pool.add(["1777539038857", "structured_output"], {"type": "greeting"}) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - node = AnswerNode( - node_id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=AnswerNodeData( - title="123", - type="answer", - answer="{{#1777539038857.structured_output#}}", - ), + node = _build_answer_node( + answer="{{#1777539038857.structured_output#}}", + variable_pool=variable_pool, ) result = node._run() @@ -128,35 +87,9 @@ def test_execute_answer_renders_structured_output_object_as_json() -> None: def test_execute_answer_falls_back_to_plain_selector_text_when_structured_output_missing() -> None: - init_params = build_test_graph_init_params( - workflow_id="1", - graph_config={"nodes": [], "edges": []}, - tenant_id="1", - app_id="1", - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables=build_system_variables(user_id="aaa", files=[]), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - node = AnswerNode( - node_id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=AnswerNodeData( - title="123", - type="answer", - answer="{{#1777539038857.structured_output#}}", - ), + node = _build_answer_node( + answer="{{#1777539038857.structured_output#}}", + variable_pool=_build_variable_pool(), ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index a18a36a099..235d56e989 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -81,7 +81,7 @@ def test_datasource_node_delegates_to_manager_stream(mocker: MockerFixture): node = DatasourceNode( node_id="n", - config=DatasourceNodeData( + data=DatasourceNodeData( type="datasource", version="1", title="Datasource", diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index be7cc073db..796fc7719d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -29,7 +29,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -85,7 +85,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -143,7 +143,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -230,7 +230,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -320,7 +320,7 @@ def test_init_headers(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -357,7 +357,7 @@ def test_init_params(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=default_system_variables()), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -390,7 +390,7 @@ def test_init_params(): def test_empty_api_key_raises_error_bearer(): """Test that empty API key raises AuthorizationConfigError for bearer auth.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer(): def test_empty_api_key_raises_error_basic(): """Test that empty API key raises AuthorizationConfigError for basic auth.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic(): def test_empty_api_key_raises_error_custom(): """Test that empty API key raises AuthorizationConfigError for custom auth.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom(): def test_whitespace_only_api_key_raises_error(): """Test that whitespace-only API key raises AuthorizationConfigError.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error(): def test_valid_api_key_works(): """Test that valid API key works correctly for bearer auth.""" - variable_pool = VariablePool(system_variables=default_system_variables()) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -536,7 +536,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): # UUID that triggers the json_repair truncation bug test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -583,7 +583,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): """ test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -624,7 +624,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): def test_executor_with_json_body_preserves_numbers_and_strings(): """Test that numbers are preserved and string values are properly quoted.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 2e89a2da3c..afde541beb 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -110,12 +110,15 @@ def _build_http_node( call_depth=0, ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="user", files=[]), + user_inputs={}, + ), start_at=time.perf_counter(), ) return HttpRequestNode( node_id="http-node", - config=HttpRequestNodeData.model_validate(node_data), + data=HttpRequestNodeData.model_validate(node_data), graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index fc7cdd64f9..902f7215c8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -33,6 +33,7 @@ from core.workflow.human_input_adapter import ( from core.workflow.node_runtime import DifyHumanInputNodeRuntime from core.workflow.system_variables import build_system_variables from graphon.entities import GraphInitParams +from graphon.file import File, FileTransferMethod, FileType from graphon.node_events import PauseRequestedEvent from graphon.node_events.node import StreamCompletedEvent from graphon.nodes.human_input.entities import ( @@ -53,6 +54,7 @@ from graphon.nodes.human_input.enums import ( ValueSourceType, ) from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.runtime import GraphRuntimeState, VariablePool from graphon.variables.segments import ArrayFileSegment, FileSegment, StringSegment from libs.datetime_utils import naive_utc_now @@ -141,6 +143,23 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): entity.status_value = HumanInputFormStatus.SUBMITTED +class _TestFileReferenceFactory(FileReferenceFactoryProtocol): + """Build graph-layer file objects without touching Dify persistence in unit tests.""" + + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + remote_url=mapping.get("remote_url") or mapping.get("url"), + related_id=mapping.get("related_id") or mapping.get("upload_file_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + ) + + def _build_human_input_node( *, node_id: str, @@ -154,10 +173,11 @@ def _build_human_input_node( ) return HumanInputNode( node_id=node_id, - config=typed_node_data, + data=typed_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, runtime=runtime, + file_reference_factory=_TestFileReferenceFactory(), ) @@ -206,6 +226,7 @@ class TestFormInput: assert form_input.type == FormInputType.PARAGRAPH assert form_input.output_variable_name == "user_input" + assert form_input.default is not None assert form_input.default.type == ValueSourceType.CONSTANT assert form_input.default.value == "Enter your response here..." @@ -215,6 +236,7 @@ class TestFormInput: form_input = ParagraphInputConfig(output_variable_name="user_input", default=default) + assert form_input.default is not None assert form_input.default.type == ValueSourceType.VARIABLE assert form_input.default.selector == ["node_123", "output_var"] @@ -246,16 +268,16 @@ class TestUserAction: def test_user_action_length_boundaries(self): """Test user action id and title length boundaries.""" - action = UserActionConfig(id="a" * 20, title="b" * 20) + action = UserActionConfig(id="a" * 20, title="b" * 100) assert action.id == "a" * 20 - assert action.title == "b" * 20 + assert action.title == "b" * 100 @pytest.mark.parametrize( ("field_name", "value"), [ ("id", "a" * 21), - ("title", "b" * 21), + ("title", "b" * 101), ], ) def test_user_action_length_limits(self, field_name: str, value: str): @@ -431,7 +453,7 @@ class TestHumanInputNodeVariableResolution: """Tests for resolving variable-based defaults in HumanInputNode.""" def test_resolves_variable_defaults(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -506,7 +528,7 @@ class TestHumanInputNodeVariableResolution: assert params.resolved_default_values == expected_values def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -567,7 +589,7 @@ class TestHumanInputNodeVariableResolution: assert not hasattr(pause_event.reason, "form_token") def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -633,7 +655,7 @@ class TestHumanInputNodeVariableResolution: assert params.display_in_ui is True def test_debugger_debug_mode_overrides_email_recipients(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user-123", app_id="app", @@ -752,7 +774,7 @@ class TestHumanInputNodeRenderedContent: """Tests for rendering submitted content.""" def test_replaces_outputs_placeholders_after_submission(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", @@ -814,7 +836,7 @@ class TestHumanInputNodeRenderedContent: assert node_run_result.outputs["__rendered_content"] == StringSegment(value="Name: Alice") def test_resume_restores_file_outputs_as_runtime_segments(self): - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="user", app_id="app", diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 71c1f113a2..b334dccd01 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -51,7 +51,7 @@ def _create_human_input_node( ) return HumanInputNode( node_id=config["id"], - config=node_data, + data=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, @@ -69,7 +69,11 @@ def _build_node( ) -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), + variable_pool=VariablePool.from_bootstrap( + system_variables=system_variables, + user_inputs={}, + environment_variables=[], + ), start_at=0.0, ) graph_init_params = GraphInitParams( @@ -149,7 +153,11 @@ def _build_node( def _build_timeout_node() -> HumanInputNode: system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), + variable_pool=VariablePool.from_bootstrap( + system_variables=system_variables, + user_inputs={}, + environment_variables=[], + ), start_at=0.0, ) graph_init_params = GraphInitParams( diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 8ffce39cd6..18ed7a0b1d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -32,7 +32,7 @@ class _MissingGraphBuilder: def _build_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( - variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}), + variable_pool=VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}), start_at=0.0, ) @@ -46,7 +46,7 @@ def _build_iteration_node( init_params = build_test_graph_init_params(graph_config=graph_config) return IterationNode( node_id="iteration-node", - config=IterationNodeData( + data=IterationNodeData( type="iteration", title="Iteration", iterator_selector=["start", "items"], diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index 89433b34e6..0d760a2db7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -41,7 +41,7 @@ def mock_graph_init_params(): @pytest.fixture def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], @@ -103,7 +103,7 @@ def _build_node( ) -> KnowledgeIndexNode: return KnowledgeIndexNode( node_id=node_id, - config=( + data=( node_data if isinstance(node_data, KnowledgeIndexNodeData) else KnowledgeIndexNodeData.model_validate(node_data) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index d77a2ce363..3c821e75ba 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -47,7 +47,7 @@ def mock_graph_init_params(): @pytest.fixture def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], @@ -118,7 +118,7 @@ class TestKnowledgeRetrievalNode: # Act node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -147,7 +147,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -206,7 +206,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -250,7 +250,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -286,7 +286,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -321,7 +321,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -362,7 +362,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -401,7 +401,7 @@ class TestKnowledgeRetrievalNode: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -482,7 +482,7 @@ class TestFetchDatasetRetriever: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -519,7 +519,7 @@ class TestFetchDatasetRetriever: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -574,7 +574,7 @@ class TestFetchDatasetRetriever: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -622,7 +622,7 @@ class TestFetchDatasetRetriever: node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -683,7 +683,7 @@ class TestFetchDatasetRetriever: config = {"id": node_id, "data": node_data.model_dump()} node = KnowledgeRetrievalNode( node_id=node_id, - config=KnowledgeRetrievalNodeData.model_validate(config["data"]), + data=KnowledgeRetrievalNodeData.model_validate(config["data"]), graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 388654f279..20b94d5d50 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -16,10 +16,10 @@ class TestListOperatorNode: """Comprehensive tests for ListOperatorNode.""" @staticmethod - def _build_node(*, config, graph_init_params, graph_runtime_state): + def _build_node(*, data, graph_init_params, graph_runtime_state): return ListOperatorNode( node_id="test", - config=config if isinstance(config, ListOperatorNodeData) else ListOperatorNodeData.model_validate(config), + data=data if isinstance(data, ListOperatorNodeData) else ListOperatorNodeData.model_validate(data), graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) @@ -65,7 +65,7 @@ class TestListOperatorNode: def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable return self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -83,7 +83,7 @@ class TestListOperatorNode: } node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -127,7 +127,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -153,7 +153,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -177,7 +177,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -201,7 +201,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -228,7 +228,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -255,7 +255,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -282,7 +282,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -312,7 +312,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -335,7 +335,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = None node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -359,7 +359,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -384,7 +384,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -408,7 +408,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -432,7 +432,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -456,7 +456,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) @@ -483,7 +483,7 @@ class TestListOperatorNode: mock_graph_runtime_state.variable_pool.get.return_value = mock_var node = self._build_node( - config=config, + data=config, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 212ad07bd3..6a2fc81fef 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -613,7 +613,7 @@ def test_combine_message_content_with_role_handles_all_supported_roles(): SystemPromptMessage(content=contents) ) - with pytest.raises(NotImplementedError, match="Role custom is not supported"): + with pytest.raises(AssertionError, match="Expected code to be unreachable"): llm_utils.combine_message_content_with_role(contents=contents, role="custom") # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index c09f2d3fb6..fb50723402 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -15,7 +15,7 @@ from core.app.llm.model_access import ( ) from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.workflow.system_variables import default_system_variables from graphon.entities import GraphInitParams @@ -187,7 +187,7 @@ def graph_init_params() -> GraphInitParams: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -208,7 +208,7 @@ def llm_node( http_client = mock.MagicMock() node = LLMNode( node_id="1", - config=llm_node_data, + data=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, @@ -241,9 +241,10 @@ def model_config(monkeypatch: pytest.MonkeyPatch): ) # Create actual provider and model type instances - model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test")) + model_assembly = create_plugin_model_assembly(tenant_id="test") + model_provider_factory = model_assembly.model_provider_factory provider_instance = model_provider_factory.get_model_provider("openai") - model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) + model_type_instance = model_assembly.create_model_type_instance(provider="openai", model_type=ModelType.LLM) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( @@ -1173,7 +1174,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat http_client = mock.MagicMock() node = LLMNode( node_id="1", - config=llm_node_data, + data=llm_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 892f6cc586..dd57dde1fe 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -28,7 +28,7 @@ def _build_template_transform_node( ) return TemplateTransformNode( node_id=node_id, - config=typed_node_data, + data=typed_node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, **kwargs, diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py index a846efbb43..c25ac7da0f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -39,7 +39,7 @@ def mock_graph_runtime_state(): def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): node = TemplateTransformNode( node_id="test_node", - config=TemplateTransformNodeData( + data=TemplateTransformNodeData( title="Template Transform", type="template-transform", variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 364408ead6..a05151f79b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -35,7 +35,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="user", files=[]), + user_inputs={}, + ), start_at=0.0, ) return init_params, runtime_state @@ -62,7 +65,7 @@ def test_node_hydrates_data_during_initialization(): node = _SampleNode( node_id="node-1", - config=_build_node_data(), + data=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -82,13 +85,16 @@ def test_node_accepts_invoke_from_enum(): invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="user", files=[]), + user_inputs={}, + ), start_at=0.0, ) node = _SampleNode( node_id="node-1", - config=_build_node_data(), + data=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) @@ -140,7 +146,7 @@ def test_node_hydration_preserves_compatibility_extra_fields(): node = _SampleNode( node_id="node-1", - config=node_config["data"], + data=node_config["data"], graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index dd75b32593..4c67f3fb02 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -49,7 +49,7 @@ def document_extractor_node(graph_init_params): http_client = Mock() node = DocumentExtractorNode( node_id="test_node_id", - config=node_data, + data=node_data, graph_init_params=graph_init_params, graph_runtime_state=Mock(), http_client=http_client, @@ -186,12 +186,13 @@ def test_run_extract_text( monkeypatch.setattr("graphon.file.file_manager.download", mock_download) + dispatch_mock = None if mime_type == "application/pdf": - mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + dispatch_mock = Mock(return_value=expected_text[0]) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_file_extension", dispatch_mock) elif mime_type.startswith("application/vnd.openxmlformats"): - mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + dispatch_mock = Mock(return_value=expected_text[0]) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_by_mime_type", dispatch_mock) result = document_extractor_node._run() @@ -200,6 +201,19 @@ def test_run_extract_text( assert result.outputs is not None assert result.outputs["text"] == ArrayStringSegment(value=expected_text) + if mime_type == "application/pdf": + dispatch_mock.assert_called_once_with( + file_content=file_content, + file_extension=extension, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + elif mime_type.startswith("application/vnd.openxmlformats"): + dispatch_mock.assert_called_once_with( + file_content=file_content, + mime_type=mime_type, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + if transfer_method == FileTransferMethod.REMOTE_URL: document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt") elif transfer_method == FileTransferMethod.LOCAL_FILE: @@ -439,24 +453,42 @@ def test_extract_text_from_file_routes_excel_inputs(document_extractor_node, ext file.extension = extension file.mime_type = mime_type - with ( - patch( - "graphon.nodes.document_extractor.node._download_file_content", - return_value=b"excel", - ), - patch( - "graphon.nodes.document_extractor.node._extract_text_from_excel", - return_value="excel text", - ) as mock_extract, + with patch( + "graphon.nodes.document_extractor.node._download_file_content", + return_value=b"excel", ): - result = _extract_text_from_file( - document_extractor_node.http_client, - file, - unstructured_api_config=document_extractor_node._unstructured_api_config, - ) + if extension: + with patch( + "graphon.nodes.document_extractor.node._extract_text_by_file_extension", + return_value="excel text", + ) as mock_extract: + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + mock_extract.assert_called_once_with( + file_content=b"excel", + file_extension=extension, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + else: + with patch( + "graphon.nodes.document_extractor.node._extract_text_by_mime_type", + return_value="excel text", + ) as mock_extract: + result = _extract_text_from_file( + document_extractor_node.http_client, + file, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) + mock_extract.assert_called_once_with( + file_content=b"excel", + mime_type=mime_type, + unstructured_api_config=document_extractor_node._unstructured_api_config, + ) assert result == "excel text" - mock_extract.assert_called_once_with(b"excel") def test_extract_text_from_file_rejects_missing_extension_and_mime_type(document_extractor_node): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index aa9a1360b0..5965645c4f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -29,7 +29,7 @@ def _build_if_else_node( node_id=str(uuid.uuid4()), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - config=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), + data=node_data if isinstance(node_data, IfElseNodeData) else IfElseNodeData.model_validate(node_data), ) @@ -48,7 +48,10 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}) + pool = VariablePool.from_bootstrap( + system_variables=build_system_variables(user_id="aaa", files=[]), + user_inputs={}, + ) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -148,7 +151,7 @@ def test_execute_if_else_result_false(): ) # construct variable pool - pool = VariablePool( + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], @@ -305,7 +308,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): ) # construct variable pool with boolean values - pool = VariablePool( + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) @@ -359,7 +362,7 @@ def test_execute_if_else_boolean_false_conditions(): ) # construct variable pool with boolean values - pool = VariablePool( + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) @@ -424,7 +427,7 @@ def test_execute_if_else_boolean_cases_structure(): ) # construct variable pool with boolean values - pool = VariablePool( + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 465a4c0ff4..1b4cecc757 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -22,7 +22,7 @@ from graphon.variables import ArrayFileSegment def _build_list_operator_node(node_data: ListOperatorNodeData, graph_init_params) -> ListOperatorNode: return ListOperatorNode( node_id="test_node_id", - config=node_data, + data=node_data, graph_init_params=graph_init_params, graph_runtime_state=MagicMock(), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 5655f80737..f890f79511 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -31,7 +31,7 @@ def make_start_node(user_inputs, variables): return StartNode( node_id="start", - config=node_data, + data=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, @@ -260,7 +260,7 @@ def test_start_node_outputs_full_variable_pool_snapshot(): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node = StartNode( node_id="start", - config=node_data, + data=node_data, graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 284af68319..4d30746e5c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -15,16 +15,23 @@ from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage -from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.runtime import GraphRuntimeState from graphon.variables.segments import ArrayFileSegment -from tests.workflow_test_utils import build_test_graph_init_params +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool if TYPE_CHECKING: # pragma: no cover - imported for type checking only from graphon.nodes.tool.tool_node import ToolNode class _StubToolRuntime: - def get_runtime(self, *, node_id: str, node_data: Any, variable_pool: Any) -> ToolRuntimeHandle: + def get_runtime( + self, + *, + node_id: str, + node_data: Any, + variable_pool: Any, + node_execution_id: str | None = None, + ) -> ToolRuntimeHandle: raise NotImplementedError def get_runtime_parameters(self, *, tool_runtime: ToolRuntimeHandle) -> list[Any]: @@ -99,7 +106,7 @@ def tool_node(monkeypatch) -> ToolNode: call_depth=0, ) - variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id")) + variable_pool = build_test_variable_pool(variables=build_system_variables(user_id="user-id")) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] @@ -110,7 +117,7 @@ def tool_node(monkeypatch) -> ToolNode: node = ToolNode( node_id="node-instance", - config=ToolNodeData.model_validate(config["data"]), + data=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, @@ -227,3 +234,22 @@ def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): files_segment = completed_events[0].node_run_result.outputs["files"] assert isinstance(files_segment, ArrayFileSegment) assert files_segment.value == [file_obj] + + +def test_tool_node_passes_node_execution_id_when_runtime_accepts_it(tool_node: ToolNode): + runtime_handle = ToolRuntimeHandle(raw=object()) + tool_node._runtime.get_runtime = MagicMock(return_value=runtime_handle) + tool_node.ensure_execution_id = MagicMock(return_value="node-execution-id") + + result = tool_node._get_tool_runtime( + variable_pool=tool_node.graph_runtime_state.variable_pool, + node_execution_id="node-execution-id", + ) + + assert result is runtime_handle + tool_node._runtime.get_runtime.assert_called_once_with( + node_id="node-instance", + node_data=tool_node.node_data, + variable_pool=tool_node.graph_runtime_state.variable_pool, + node_execution_id="node-execution-id", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py index 438af211f3..aece73ce8c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -147,6 +147,69 @@ def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: Dify assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN +def test_get_runtime_stores_parent_trace_context_for_workflow_tools( + runtime: DifyToolNodeRuntime, +) -> None: + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables( + conversation_id="conversation-id", + workflow_execution_id="workflow-run-id", + ) + ) + workflow_runtime = MagicMock() + workflow_runtime.runtime.runtime_parameters = {} + node_data = ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.WORKFLOW, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + "tool_configurations": {}, + "tool_parameters": {}, + } + ) + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_runtime): + tool_runtime = runtime.get_runtime( + node_id="node-id", + node_data=node_data, + variable_pool=variable_pool, + node_execution_id="node-execution-id", + ) + + assert tool_runtime.raw.parent_trace_context.model_dump() == { + "parent_workflow_run_id": "workflow-run-id", + "parent_node_execution_id": "node-execution-id", + } + assert workflow_runtime.runtime.runtime_parameters == {} + + +def test_get_runtime_leaves_non_workflow_tool_runtime_parameters_unchanged( + runtime: DifyToolNodeRuntime, +) -> None: + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables( + conversation_id="conversation-id", + workflow_execution_id="workflow-run-id", + ) + ) + builtin_runtime = MagicMock() + builtin_runtime.runtime.runtime_parameters = {} + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=builtin_runtime): + runtime.get_runtime( + node_id="node-id", + node_data=_build_tool_node_data(), + variable_pool=variable_pool, + node_execution_id="node-execution-id", + ) + + assert builtin_runtime.runtime.runtime_parameters == {} + + def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None: tool_runtime = ToolRuntimeHandle( raw=SimpleNamespace( diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index e3b5e3b591..c5ac8d2ce2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -44,7 +44,7 @@ def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: init_params, runtime_state = _build_context(graph_config={}) node = TriggerEventNode( node_id="node-1", - config=_build_node_data(), + data=_build_node_data(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 07d03bec05..fccb5ab1c3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -52,7 +52,7 @@ def create_webhook_node( node = TriggerWebhookNode( node_id="webhook-node-1", - config=webhook_data, + data=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index b839490d3c..c5ae542d8b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -44,7 +44,7 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) ) node = TriggerWebhookNode( node_id="1", - config=webhook_data, + data=webhook_data, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index e93a7c7ccd..d6159e84d4 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from types import SimpleNamespace from unittest.mock import MagicMock, patch, sentinel @@ -11,19 +12,20 @@ from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType from graphon.nodes.code.entities import CodeLanguage from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.llm.node import LLMNode from graphon.variables.segments import StringSegment -def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None: +def _assert_constructor_node_data(data, *, node_id: str, node_type: NodeType, version: str = "1") -> None: _ = node_id - if isinstance(config, BaseNodeData): - assert config.type == node_type - assert config.version == version + if isinstance(data, BaseNodeData): + assert data.type == node_type + assert data.version == version return - assert isinstance(config, dict) - assert config["type"] == node_type - assert config["version"] == version + assert isinstance(data, Mapping) + assert data["type"] == node_type + assert data.get("version", "1") == version def _node_constructor(*, return_value): @@ -470,7 +472,7 @@ class TestDifyNodeFactoryCreateNode: matched_node_class.assert_called_once() kwargs = matched_node_class.call_args.kwargs assert kwargs["node_id"] == "node-id" - _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") + _assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") assert kwargs["graph_init_params"] is sentinel.graph_init_params assert kwargs["graph_runtime_state"] is factory.graph_runtime_state latest_node_class.assert_not_called() @@ -492,7 +494,7 @@ class TestDifyNodeFactoryCreateNode: latest_node_class.assert_called_once() kwargs = latest_node_class.call_args.kwargs assert kwargs["node_id"] == "node-id" - _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") + _assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") assert kwargs["graph_init_params"] is sentinel.graph_init_params assert kwargs["graph_runtime_state"] is factory.graph_runtime_state @@ -530,7 +532,7 @@ class TestDifyNodeFactoryCreateNode: assert result is created_node kwargs = constructor.call_args.kwargs assert kwargs["node_id"] == "node-id" - _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type) + _assert_constructor_node_data(kwargs["data"], node_id="node-id", node_type=node_type) assert kwargs["graph_init_params"] is sentinel.graph_init_params assert kwargs["graph_runtime_state"] is factory.graph_runtime_state @@ -599,11 +601,12 @@ class TestDifyNodeFactoryCreateNode: prepared_llm.assert_called_once_with(sentinel.model_instance) assert kwargs["model_instance"] is wrapped_model_instance - def test_create_node_passes_alias_preserving_llm_config_to_constructor( - self, monkeypatch: pytest.MonkeyPatch, factory - ): + def test_create_node_passes_alias_preserving_llm_data_to_constructor(self, monkeypatch, factory): created_node = object() constructor = _node_constructor(return_value=created_node) + constructor.validate_node_data.side_effect = lambda node_data: LLMNodeData.model_validate( + node_data.model_dump(mode="python") if isinstance(node_data, BaseNodeData) else node_data + ) monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=constructor)) monkeypatch.setattr(factory, "_build_llm_compatible_node_init_kwargs", MagicMock(return_value={})) @@ -629,10 +632,56 @@ class TestDifyNodeFactoryCreateNode: factory.create_node(node_config) - config = constructor.call_args.kwargs["config"] - assert isinstance(config, dict) - assert config["structured_output_enabled"] is True - assert "structured_output_switch_on" not in config + data = constructor.call_args.kwargs["data"] + assert isinstance(data, Mapping) + assert data["structured_output_enabled"] is True + assert "structured_output_switch_on" not in data + assert LLMNodeData.model_validate(data).structured_output_enabled is True + + def test_create_node_preserves_structured_output_switch_after_graphon_constructor(self, monkeypatch, factory): + factory.graph_init_params = SimpleNamespace( + workflow_id="workflow-id", + graph_config={}, + run_context={}, + call_depth=0, + ) + monkeypatch.setattr(factory, "_resolve_node_class", MagicMock(return_value=LLMNode)) + monkeypatch.setattr( + factory, + "_build_llm_compatible_node_init_kwargs", + MagicMock( + return_value={ + "model_instance": sentinel.model_instance, + "llm_file_saver": sentinel.llm_file_saver, + "prompt_message_serializer": sentinel.prompt_message_serializer, + } + ), + ) + + node_config = { + "id": "llm-node-id", + "data": { + "type": BuiltinNodeTypes.LLM, + "title": "LLM", + "model": {"provider": "provider", "name": "model", "mode": "chat", "completion_params": {}}, + "prompt_template": [{"role": "system", "text": "x"}], + "context": {"enabled": False, "variable_selector": []}, + "vision": {"enabled": False}, + "structured_output_enabled": True, + "structured_output": { + "schema": { + "type": "object", + "properties": {"type": {"type": "string"}}, + "required": ["type"], + } + }, + }, + } + + node = factory.create_node(node_config) + + assert node.node_data.structured_output_switch_on is True + assert node.node_data.structured_output_enabled is True @pytest.mark.parametrize( ("node_type", "constructor_name", "expected_extra_kwargs"), @@ -711,7 +760,7 @@ class TestDifyNodeFactoryCreateNode: constructor_kwargs = constructor.call_args.kwargs assert constructor_kwargs["node_id"] == "node-id" - _assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type) + _assert_constructor_node_data(constructor_kwargs["data"], node_id="node-id", node_type=node_type) assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params assert constructor_kwargs["graph_runtime_state"] is factory.graph_runtime_state assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider diff --git a/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py b/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py index d18fc262ef..2dd3953d9a 100644 --- a/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py +++ b/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py @@ -7,6 +7,17 @@ from pathlib import Path def test_moved_core_nodes_resolve_after_importing_production_entrypoints(): api_root = Path(__file__).resolve().parents[4] + + # `PYTHONSAFEPATH=1` enables Python's safe-path mode, which suppresses the + # usual implicit insertion of the working directory into `sys.path`. + # Set `PYTHONPATH` explicitly so this subprocess test stays deterministic in + # both CI and local shells that may export `PYTHONSAFEPATH`. + env = os.environ.copy() + existing_pythonpath = env.get("PYTHONPATH") + env["PYTHONPATH"] = ( + str(api_root) if not existing_pythonpath else os.pathsep.join([str(api_root), existing_pythonpath]) + ) + env["PYTHONSAFEPATH"] = "1" script = textwrap.dedent( """ from core.app.apps import workflow_app_runner @@ -34,7 +45,7 @@ def test_moved_core_nodes_resolve_after_importing_production_entrypoints(): completed = subprocess.run( [sys.executable, "-c", script], cwd=api_root, - env=os.environ.copy(), + env=env, capture_output=True, text=True, check=False, diff --git a/api/tests/unit_tests/core/workflow/test_node_runtime.py b/api/tests/unit_tests/core/workflow/test_node_runtime.py index 874873d57e..f5e409cf3c 100644 --- a/api/tests/unit_tests/core/workflow/test_node_runtime.py +++ b/api/tests/unit_tests/core/workflow/test_node_runtime.py @@ -317,6 +317,81 @@ def test_dify_tool_file_manager_delegates_file_generator_lookup(monkeypatch: pyt get_file_generator.assert_called_once_with("tool-file-id") +def test_dify_tool_node_runtime_injects_outer_workflow_run_id_for_workflow_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + runtime_tool = SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={})) + get_runtime = MagicMock(return_value=runtime_tool) + monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime) + monkeypatch.setattr( + node_runtime, + "get_system_text", + lambda _pool, key: ( + "outer-workflow-run-id" if key == node_runtime.SystemVariableKey.WORKFLOW_EXECUTION_ID else None + ), + ) + + runtime = node_runtime.DifyToolNodeRuntime(_build_run_context()) + node_data = ToolNodeData( + title="Workflow Tool Node", + desc=None, + provider_id="workflow-provider-id", + provider_type=ToolProviderType.WORKFLOW, + provider_name="workflow-provider", + tool_name="workflow-tool", + tool_label="Workflow Tool", + tool_configurations={}, + tool_parameters={}, + ) + + handle = runtime.get_runtime( + node_id="tool-node", + node_data=node_data, + variable_pool=object(), + node_execution_id="node-execution-id", + ) + + assert handle.raw.tool is runtime_tool + assert handle.raw.parent_trace_context.model_dump() == { + "parent_workflow_run_id": "outer-workflow-run-id", + "parent_node_execution_id": "node-execution-id", + } + assert runtime_tool.runtime.runtime_parameters == {} + get_runtime.assert_called_once() + + +def test_dify_tool_node_runtime_does_not_inject_outer_workflow_run_id_for_non_workflow_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + runtime_tool = SimpleNamespace(runtime=SimpleNamespace(runtime_parameters={})) + get_runtime = MagicMock(return_value=runtime_tool) + monkeypatch.setattr(node_runtime.ToolManager, "get_workflow_tool_runtime", get_runtime) + monkeypatch.setattr(node_runtime, "get_system_text", lambda _pool, _key: None) + + runtime = node_runtime.DifyToolNodeRuntime(_build_run_context()) + node_data = ToolNodeData( + title="Builtin Tool Node", + desc=None, + provider_id="builtin-provider-id", + provider_type=ToolProviderType.BUILT_IN, + provider_name="builtin-provider", + tool_name="builtin-tool", + tool_label="Builtin Tool", + tool_configurations={}, + tool_parameters={}, + ) + + handle = runtime.get_runtime( + node_id="tool-node", + node_data=node_data, + variable_pool=object(), + ) + + assert handle.raw.tool is runtime_tool + assert "outer_workflow_run_id" not in runtime_tool.runtime.runtime_parameters + get_runtime.assert_called_once() + + def test_dify_human_input_runtime_builds_debug_repository(monkeypatch: pytest.MonkeyPatch) -> None: repository = MagicMock() repository_cls = MagicMock(return_value=repository) diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 9dab38ed8e..0017cd8d3f 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -109,8 +109,8 @@ class TestVariablePool: assert pool.get([ENVIRONMENT_VARIABLE_NODE_ID, "env_var_1"]) is not None assert pool.get([CONVERSATION_VARIABLE_NODE_ID, "conv_var_1"]) is not None - def test_constructor_loads_legacy_bootstrap_kwargs(self): - pool = VariablePool( + def test_from_bootstrap_loads_legacy_bootstrap_kwargs(self): + pool = VariablePool.from_bootstrap( system_variables=build_system_variables(user_id="test_user_id"), environment_variables=[StringVariable(name="env_var", value="env-value")], conversation_variables=[StringVariable(name="conv_var", value="conv-value")], diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 2e9e3468fd..661882f013 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -55,7 +55,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_to_variable_pool_with_system_variables(self): """Test mapping system variables from user inputs to variable pool.""" # Initialize variable pool with system variables - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="test_user_id", app_id="test_app_id", @@ -128,7 +128,7 @@ class TestWorkflowEntry: return NodeConfigDictAdapter.validate_python(node_config) workflow = StubWorkflow() - variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={}) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}) expected_limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, max_number=dify_config.CODE_MAX_NUMBER, @@ -157,7 +157,7 @@ class TestWorkflowEntry: """Test mapping environment variables from user inputs to variable pool.""" # Initialize variable pool with environment variables env_var = StringVariable(name="API_KEY", value="existing_key") - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), environment_variables=[env_var], user_inputs={}, @@ -198,7 +198,7 @@ class TestWorkflowEntry: """Test mapping conversation variables from user inputs to variable pool.""" # Initialize variable pool with conversation variables conv_var = StringVariable(name="last_message", value="Hello") - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), conversation_variables=[conv_var], user_inputs={}, @@ -239,7 +239,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_to_variable_pool_with_regular_variables(self): """Test mapping regular node variables from user inputs to variable pool.""" # Initialize empty variable pool - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -281,7 +281,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_file_handling(self): """Test mapping file inputs from user inputs to variable pool.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -340,7 +340,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_missing_variable_error(self): """Test that mapping raises error when required variable is missing.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -366,7 +366,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_alternative_key_format(self): """Test mapping with alternative key format (without node prefix).""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -396,7 +396,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_with_complex_selectors(self): """Test mapping with complex node variable keys.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -432,7 +432,7 @@ class TestWorkflowEntry: def test_mapping_user_inputs_invalid_node_variable(self): """Test that mapping handles invalid node variable format.""" - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=default_system_variables(), user_inputs={}, ) @@ -463,7 +463,7 @@ class TestWorkflowEntry: env_var = StringVariable(name="API_KEY", value="existing_key") conv_var = StringVariable(name="session_id", value="session123") - variable_pool = VariablePool( + variable_pool = VariablePool.from_bootstrap( system_variables=build_system_variables( user_id="test_user", app_id="test_app", diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py index 3978cbb1a0..a57cdd1337 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -7,7 +7,6 @@ import pytest from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.model_manager import ModelInstance from core.workflow import workflow_entry from core.workflow.system_variables import default_system_variables from graphon.entities.base_node_data import BaseNodeData @@ -16,10 +15,12 @@ from graphon.errors import WorkflowNodeRunFailedError from graphon.file import File, FileTransferMethod, FileType from graphon.graph import Graph from graphon.graph_events import GraphRunFailedEvent -from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMUsage from graphon.node_events import NodeRunResult from graphon.nodes import BuiltinNodeTypes from graphon.nodes.base.node import Node +from graphon.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from graphon.runtime import ChildGraphNotFoundError, VariablePool from graphon.variables.variables import StringVariable from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool @@ -29,9 +30,30 @@ def _build_typed_node_config(node_type: NodeType): return {"id": "node-id", "data": BaseNodeData(type=node_type)} -def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: - raw_model_instance = ModelInstance.__new__(ModelInstance) - return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance +def _build_model_config(*, provider: str = "openai", model_name: str = "gpt-4o") -> ModelConfig: + return ModelConfig(provider=provider, name=model_name, mode=LLMMode.CHAT) + + +def _build_llm_node_data(*, provider: str = "openai", model_name: str = "gpt-4o") -> LLMNodeData: + return LLMNodeData( + type=BuiltinNodeTypes.LLM, + title="Child Model", + model=_build_model_config(provider=provider, model_name=model_name), + prompt_template=[], + context=ContextConfig(enabled=False), + ) + + +def _build_question_classifier_node_data( + *, provider: str = "openai", model_name: str = "gpt-4o" +) -> QuestionClassifierNodeData: + return QuestionClassifierNodeData( + type=BuiltinNodeTypes.QUESTION_CLASSIFIER, + title="Child Model", + query_variable_selector=["sys", "query"], + model=_build_model_config(provider=provider, model_name=model_name), + classes=[], + ) class _FakeModelNodeMixin: @@ -40,22 +62,26 @@ class _FakeModelNodeMixin: return "1" def post_init(self) -> None: - self.model_instance, self.raw_model_instance = _build_wrapped_model_instance() + self.model_instance = SimpleNamespace(provider="stale-provider", model_name="stale-model") self.usage_snapshot = LLMUsage.empty_usage() self.usage_snapshot.total_tokens = 1 def _run(self) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + "model_provider": self.node_data.model.provider, + "model_name": self.node_data.model.name, + }, llm_usage=self.usage_snapshot, ) -class _FakeLLMNode(_FakeModelNodeMixin, Node[BaseNodeData]): +class _FakeLLMNode(_FakeModelNodeMixin, Node[LLMNodeData]): node_type = BuiltinNodeTypes.LLM -class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[BaseNodeData]): +class _FakeQuestionClassifierNode(_FakeModelNodeMixin, Node[QuestionClassifierNodeData]): node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER @@ -75,7 +101,7 @@ class TestWorkflowChildEngineBuilder: assert result is expected def test_build_child_engine_raises_when_root_node_is_missing(self): - builder = workflow_entry._WorkflowChildEngineBuilder() + builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id") graph_init_params = SimpleNamespace(graph_config={"nodes": []}) parent_graph_runtime_state = SimpleNamespace( execution_context=sentinel.execution_context, @@ -92,7 +118,7 @@ class TestWorkflowChildEngineBuilder: ) def test_build_child_engine_constructs_graph_engine_with_quota_layer_only(self): - builder = workflow_entry._WorkflowChildEngineBuilder() + builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id") graph_init_params = SimpleNamespace(graph_config={"nodes": [{"id": "root"}]}) parent_graph_runtime_state = SimpleNamespace( execution_context=sentinel.execution_context, @@ -114,7 +140,7 @@ class TestWorkflowChildEngineBuilder: patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls, patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), - patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer), + patch.object(workflow_entry, "LLMQuotaLayer", return_value=sentinel.llm_quota_layer) as llm_quota_layer_cls, ): result = builder.build_child_engine( workflow_id="workflow-id", @@ -147,11 +173,12 @@ class TestWorkflowChildEngineBuilder: config=sentinel.graph_engine_config, child_engine_builder=builder, ) + llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id") assert child_engine.layer.call_args_list == [((sentinel.llm_quota_layer,), {})] @pytest.mark.parametrize("node_cls", [_FakeLLMNode, _FakeQuestionClassifierNode]) def test_build_child_engine_runs_llm_quota_layer_for_child_model_nodes(self, node_cls): - builder = workflow_entry._WorkflowChildEngineBuilder() + builder = workflow_entry._WorkflowChildEngineBuilder(tenant_id="tenant-id") graph_init_params = build_test_graph_init_params( graph_config={"nodes": [{"id": "root"}], "edges": []}, ) @@ -163,12 +190,10 @@ class TestWorkflowChildEngineBuilder: def build_graph(*, graph_config, node_factory, root_node_id): _ = graph_config + node_data = _build_llm_node_data() if node_cls is _FakeLLMNode else _build_question_classifier_node_data() node = node_cls( node_id=root_node_id, - config=BaseNodeData( - type=node_cls.node_type, - title="Child Model", - ), + data=node_data, graph_init_params=node_factory.graph_init_params, graph_runtime_state=node_factory.graph_runtime_state, ) @@ -191,8 +216,8 @@ class TestWorkflowChildEngineBuilder: ), ), patch.object(workflow_entry.Graph, "init", side_effect=build_graph), - patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available") as ensure_quota, - patch("core.app.workflow.layers.llm_quota.deduct_llm_quota") as deduct_quota, + patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available_for_model") as ensure_quota, + patch("core.app.workflow.layers.llm_quota.deduct_llm_quota_for_model") as deduct_quota, ): child_engine = builder.build_child_engine( workflow_id="workflow-id", @@ -203,10 +228,15 @@ class TestWorkflowChildEngineBuilder: list(child_engine.run()) node = created_node["node"] - ensure_quota.assert_called_once_with(model_instance=node.raw_model_instance) + ensure_quota.assert_called_once_with( + tenant_id="tenant-id", + provider=node.node_data.model.provider, + model=node.node_data.model.name, + ) deduct_quota.assert_called_once_with( - tenant_id="tenant", - model_instance=node.raw_model_instance, + tenant_id="tenant-id", + provider=node.node_data.model.provider, + model=node.node_data.model.name, usage=node.usage_snapshot, ) @@ -252,7 +282,7 @@ class TestWorkflowEntryInit: "ExecutionLimitsLayer", return_value=execution_limits_layer, ) as execution_limits_layer_cls, - patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer), + patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer) as llm_quota_layer_cls, patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer), ): entry = workflow_entry.WorkflowEntry( @@ -291,6 +321,7 @@ class TestWorkflowEntryInit: max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME, ) + llm_quota_layer_cls.assert_called_once_with(tenant_id="tenant-id") assert graph_engine.layer.call_args_list == [ ((debug_layer,), {}), ((execution_limits_layer,), {}), @@ -334,7 +365,7 @@ class TestWorkflowEntrySingleStepRun: def extract_variable_selector_to_variable_mapping(**_kwargs): return {} - variable_pool = VariablePool(system_variables=default_system_variables(), user_inputs={}) + variable_pool = VariablePool.from_bootstrap(system_variables=default_system_variables(), user_inputs={}) variable_loader = MagicMock() variable_loader.load_variables.return_value = [ StringVariable( diff --git a/api/tests/unit_tests/events/test_update_provider_when_message_created.py b/api/tests/unit_tests/events/test_update_provider_when_message_created.py new file mode 100644 index 0000000000..9cb8ca7854 --- /dev/null +++ b/api/tests/unit_tests/events/test_update_provider_when_message_created.py @@ -0,0 +1,130 @@ +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +from sqlalchemy import create_engine, select + +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.entities.provider_entities import ProviderQuotaType, QuotaUnit +from events.event_handlers import update_provider_when_message_created +from models import TenantCreditPool +from models.provider import ProviderType + + +def test_message_created_trial_credit_accounting_does_not_raise_when_balance_is_insufficient() -> None: + engine = create_engine("sqlite:///:memory:") + TenantCreditPool.__table__.create(engine) + tenant_id = str(uuid4()) + pool_id = str(uuid4()) + with engine.begin() as connection: + connection.execute( + TenantCreditPool.__table__.insert(), + { + "id": pool_id, + "tenant_id": tenant_id, + "pool_type": ProviderQuotaType.TRIAL, + "quota_limit": 10, + "quota_used": 9, + }, + ) + + system_configuration = SimpleNamespace( + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=10, + ) + ], + ) + application_generate_entity = ChatAppGenerateEntity.model_construct( + app_config=SimpleNamespace(tenant_id=tenant_id), + model_conf=SimpleNamespace( + provider="openai", + model="gpt-4o", + provider_model_bundle=SimpleNamespace( + configuration=SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + ) + ), + ), + ) + message = SimpleNamespace(message_tokens=2, answer_tokens=1) + + with ( + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + patch.object(update_provider_when_message_created, "_execute_provider_updates"), + ): + update_provider_when_message_created.handle( + sender=message, + application_generate_entity=application_generate_entity, + ) + + with engine.connect() as connection: + quota_used = connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)) + + assert quota_used == 10 + + +def test_message_created_paid_credit_accounting_uses_paid_pool() -> None: + tenant_id = str(uuid4()) + system_configuration = SimpleNamespace( + current_quota_type=ProviderQuotaType.PAID, + quota_configurations=[ + SimpleNamespace( + quota_type=ProviderQuotaType.PAID, + quota_unit=QuotaUnit.TOKENS, + quota_limit=10, + ) + ], + ) + application_generate_entity = ChatAppGenerateEntity.model_construct( + app_config=SimpleNamespace(tenant_id=tenant_id), + model_conf=SimpleNamespace( + provider="openai", + model="gpt-4o", + provider_model_bundle=SimpleNamespace( + configuration=SimpleNamespace( + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + ) + ), + ), + ) + message = SimpleNamespace(message_tokens=2, answer_tokens=1) + + with ( + patch.object(update_provider_when_message_created, "_deduct_credit_pool_quota_capped") as mock_deduct, + patch.object(update_provider_when_message_created, "_execute_provider_updates"), + ): + update_provider_when_message_created.handle( + sender=message, + application_generate_entity=application_generate_entity, + ) + + mock_deduct.assert_called_once_with( + tenant_id=tenant_id, + credits_required=3, + pool_type="paid", + ) + + +def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None: + with patch( + "services.credit_pool_service.CreditPoolService.deduct_credits_capped", + return_value=3, + ) as mock_deduct: + update_provider_when_message_created._deduct_credit_pool_quota_capped( + tenant_id="tenant-id", + credits_required=3, + pool_type="trial", + ) + + mock_deduct.assert_called_once_with( + tenant_id="tenant-id", + credits_required=3, + pool_type="trial", + ) + assert "Credit pool exhausted during message-created accounting" not in caplog.text diff --git a/api/tests/unit_tests/extensions/test_set_secretkey.py b/api/tests/unit_tests/extensions/test_set_secretkey.py new file mode 100644 index 0000000000..8a8e4e2b19 --- /dev/null +++ b/api/tests/unit_tests/extensions/test_set_secretkey.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import pytest +from flask import Flask + +from extensions import ext_set_secretkey + + +class InMemoryStorage: + def __init__(self, files: dict[str, bytes] | None = None) -> None: + self.files = files or {} + self.saved_files: list[tuple[str, bytes]] = [] + + def load_once(self, filename: str) -> bytes: + try: + return self.files[filename] + except KeyError: + raise FileNotFoundError(filename) + + def save(self, filename: str, data: bytes) -> None: + self.files[filename] = data + self.saved_files.append((filename, data)) + + +def test_init_app_uses_configured_secret_key(monkeypatch: pytest.MonkeyPatch) -> None: + secret_key = "configured-secret-key" + storage = InMemoryStorage() + monkeypatch.setattr("extensions.ext_set_secretkey.dify_config.SECRET_KEY", secret_key) + monkeypatch.setattr("configs.secret_key.storage", storage) + app = Flask(__name__) + app.config["SECRET_KEY"] = secret_key + + ext_set_secretkey.init_app(app) + + assert app.secret_key == secret_key + assert app.config["SECRET_KEY"] == secret_key + assert storage.saved_files == [] + + +def test_init_app_generates_and_persists_secret_key_when_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + monkeypatch.setattr("extensions.ext_set_secretkey.dify_config.SECRET_KEY", "") + monkeypatch.setattr("configs.secret_key.storage", storage) + app = Flask(__name__) + app.config["SECRET_KEY"] = "" + + ext_set_secretkey.init_app(app) + + persisted_key = storage.files[".dify_secret_key"].decode("utf-8").strip() + assert persisted_key + assert storage.saved_files == [(".dify_secret_key", f"{persisted_key}\n".encode())] + assert persisted_key == ext_set_secretkey.dify_config.SECRET_KEY + assert persisted_key == app.config["SECRET_KEY"] + assert persisted_key == app.secret_key + + +def test_init_app_reuses_persisted_secret_key_when_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + persisted_key = "persisted-secret-key" + storage = InMemoryStorage({".dify_secret_key": f"{persisted_key}\n".encode()}) + monkeypatch.setattr("extensions.ext_set_secretkey.dify_config.SECRET_KEY", "") + monkeypatch.setattr("configs.secret_key.storage", storage) + app = Flask(__name__) + app.config["SECRET_KEY"] = "" + + ext_set_secretkey.init_app(app) + + assert persisted_key == ext_set_secretkey.dify_config.SECRET_KEY + assert persisted_key == app.config["SECRET_KEY"] + assert persisted_key == app.secret_key + assert storage.saved_files == [] diff --git a/api/tests/unit_tests/factories/test_file_validation.py b/api/tests/unit_tests/factories/test_file_validation.py new file mode 100644 index 0000000000..61337fcf10 --- /dev/null +++ b/api/tests/unit_tests/factories/test_file_validation.py @@ -0,0 +1,159 @@ +"""Unit tests for is_file_valid_with_config.""" + +from __future__ import annotations + +import pytest + +from factories.file_factory.validation import is_file_valid_with_config +from graphon.file import FileTransferMethod, FileType, FileUploadConfig + + +def _validate( + *, + input_file_type: str, + file_extension: str = ".png", + file_transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE, + config: FileUploadConfig, +) -> bool: + return is_file_valid_with_config( + input_file_type=input_file_type, + file_extension=file_extension, + file_transfer_method=file_transfer_method, + config=config, + ) + + +@pytest.mark.parametrize( + ("input_file_type", "file_extension", "allowed_file_types", "allowed_file_extensions", "expected"), + [ + # round-1 happy path: literal "custom" mapping, ext whitelisted + ("custom", ".png", [FileType.CUSTOM], [".png"], True), + # round-2 replay: MessageFile.type is the resolved type, but config still allows CUSTOM + ("image", ".png", [FileType.CUSTOM], [".png"], True), + ("document", ".pdf", [FileType.CUSTOM], [".pdf"], True), + # mixed bucket [IMAGE, CUSTOM]: document falls into CUSTOM bucket via extension + ("document", ".pdf", [FileType.IMAGE, FileType.CUSTOM], [".pdf"], True), + ("document", ".exe", [FileType.IMAGE, FileType.CUSTOM], [".pdf"], False), + ("image", ".jpg", [FileType.IMAGE], [], True), + ("video", ".mp4", [FileType.IMAGE, FileType.DOCUMENT], [], False), + ("custom", ".exe", [FileType.CUSTOM], [".png"], False), + # empty allowed_file_types == no type restriction + ("video", ".mp4", [], [], True), + ], +) +def test_bucket_semantics(input_file_type, file_extension, allowed_file_types, allowed_file_extensions, expected): + config = FileUploadConfig( + allowed_file_types=allowed_file_types, + allowed_file_extensions=allowed_file_extensions, + ) + assert _validate(input_file_type=input_file_type, file_extension=file_extension, config=config) is expected + + +@pytest.mark.parametrize("whitelist_entry", [".png", ".PNG", "png", "PNG", " .Png ", "PnG"]) +def test_extension_match_is_case_and_dot_insensitive(whitelist_entry): + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[whitelist_entry], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is True + + +def test_extension_mismatch_still_rejected_after_normalization(): + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[".png", ".jpg"], + ) + assert _validate(input_file_type="custom", file_extension=".pdf", config=config) is False + + +def test_mixed_case_whitelist_replicating_real_user_config(): + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[".PNG", "png", "JPG", ".WEBP", "SVG", "GIF"], + ) + for ext in (".png", ".jpg", ".webp", ".svg", ".gif"): + assert _validate(input_file_type="custom", file_extension=ext, config=config) is True + + +def test_tool_file_always_passes(): + config = FileUploadConfig(allowed_file_types=[FileType.CUSTOM], allowed_file_extensions=[".pdf"]) + assert ( + _validate( + input_file_type="image", + file_extension=".png", + file_transfer_method=FileTransferMethod.TOOL_FILE, + config=config, + ) + is True + ) + + +def test_transfer_method_gate_for_non_image(): + config = FileUploadConfig( + allowed_file_types=[FileType.DOCUMENT], + allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE], + ) + assert ( + _validate( + input_file_type="document", + file_extension=".pdf", + file_transfer_method=FileTransferMethod.LOCAL_FILE, + config=config, + ) + is True + ) + assert ( + _validate( + input_file_type="document", + file_extension=".pdf", + file_transfer_method=FileTransferMethod.REMOTE_URL, + config=config, + ) + is False + ) + + +def test_history_replay_matches_round_1_outcome_under_unchanged_config(): + """A file that passes round 1 must pass history replay when config is unchanged.""" + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[".png"], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is True + assert _validate(input_file_type="image", file_extension=".png", config=config) is True + + +def test_empty_whitelist_in_custom_bucket_denies_by_default(): + """Defensive: when a file lands in the CUSTOM bucket, an empty + allowed_file_extensions list rejects. The UI never submits empty; + this guards DSL / API paths that bypass the UI from accidentally + widening what's accepted.""" + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is False + assert _validate(input_file_type="image", file_extension=".png", config=config) is False + + +def test_normalize_handles_whitespace_and_empty_consistently(): + """Whitespace-only or empty entries in the whitelist must not match real + extensions (regression guard for _normalize_extension edge cases).""" + for noisy_entry in ("", " ", "\t"): + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=[noisy_entry], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is False + + +def test_empty_extension_does_not_spuriously_match_empty_whitelist_entry(): + """Defensive: even if the whitelist contains an empty / whitespace entry + (e.g., a stray comma in DSL), an extensionless file must not pass via + a both-sides-empty match. Real entries in the same whitelist still match.""" + config = FileUploadConfig( + allowed_file_types=[FileType.CUSTOM], + allowed_file_extensions=["", ".png"], + ) + assert _validate(input_file_type="custom", file_extension=".png", config=config) is True + assert _validate(input_file_type="custom", file_extension="", config=config) is False diff --git a/api/tests/unit_tests/libs/test_passport.py b/api/tests/unit_tests/libs/test_passport.py index f33484c18d..90b58ae548 100644 --- a/api/tests/unit_tests/libs/test_passport.py +++ b/api/tests/unit_tests/libs/test_passport.py @@ -143,28 +143,13 @@ class TestPassportService: assert str(exc_info.value) == "401 Unauthorized: Token has expired." # Configuration tests - def test_should_handle_empty_secret_key(self): - """Test behavior when SECRET_KEY is empty""" + def test_should_use_configured_secret_key_without_policy_validation(self): + """Test that policy decisions are owned by config, not PassportService.""" with patch("libs.passport.dify_config") as mock_config: - mock_config.SECRET_KEY = "" + mock_config.SECRET_KEY = "configured" service = PassportService() - # Empty secret key should still work but is insecure - payload = {"test": "data"} - token = service.issue(payload) - decoded = service.verify(token) - assert decoded == payload - - def test_should_handle_none_secret_key(self): - """Test behavior when SECRET_KEY is None""" - with patch("libs.passport.dify_config") as mock_config: - mock_config.SECRET_KEY = None - service = PassportService() - - payload = {"test": "data"} - # JWT library will raise TypeError when secret is None - with pytest.raises((TypeError, jwt.exceptions.InvalidKeyError)): - service.issue(payload) + assert service.sk == "configured" # Boundary condition tests def test_should_handle_large_payload(self, passport_service): diff --git a/api/tests/unit_tests/models/test_comment_models.py b/api/tests/unit_tests/models/test_comment_models.py index 277335cbef..8c8985aff8 100644 --- a/api/tests/unit_tests/models/test_comment_models.py +++ b/api/tests/unit_tests/models/test_comment_models.py @@ -4,7 +4,15 @@ from models.comment import WorkflowComment, WorkflowCommentMention, WorkflowComm def test_workflow_comment_account_properties_and_cache() -> None: - comment = WorkflowComment(created_by="user-1", resolved_by="user-2", content="hello", position_x=1, position_y=2) + comment = WorkflowComment( + created_by="user-1", + resolved_by="user-2", + content="hello", + position_x=1, + position_y=2, + tenant_id="xxx", + app_id="yyy", + ) created_account = Mock(id="user-1") resolved_account = Mock(id="user-2") @@ -21,6 +29,8 @@ def test_workflow_comment_account_properties_and_cache() -> None: get_mock.assert_not_called() comment_without_resolver = WorkflowComment( + tenant_id="xxx", + app_id="yyy", created_by="user-1", resolved_by=None, content="hello", @@ -37,7 +47,15 @@ def test_workflow_comment_counts_and_participants() -> None: reply_2 = WorkflowCommentReply(comment_id="comment-1", content="reply-2", created_by="user-2") mention_1 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3") mention_2 = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-4") - comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2) + comment = WorkflowComment( + created_by="user-1", + resolved_by=None, + content="hello", + position_x=1, + position_y=2, + tenant_id="xxx", + app_id="yyy", + ) comment.replies = [reply_1, reply_2] comment.mentions = [mention_1, mention_2] @@ -63,7 +81,15 @@ def test_workflow_comment_counts_and_participants() -> None: def test_workflow_comment_participants_use_cached_accounts() -> None: reply = WorkflowCommentReply(comment_id="comment-1", content="reply-1", created_by="user-2") mention = WorkflowCommentMention(comment_id="comment-1", mentioned_user_id="user-3") - comment = WorkflowComment(created_by="user-1", resolved_by=None, content="hello", position_x=1, position_y=2) + comment = WorkflowComment( + created_by="user-1", + resolved_by=None, + content="hello", + position_x=1, + position_y=2, + tenant_id="xxx", + app_id="yyy", + ) comment.replies = [reply] comment.mentions = [mention] diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 51d95c4239..f4ccfb4191 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -12,7 +12,9 @@ This test suite covers: import json import pickle from datetime import UTC, datetime +from types import SimpleNamespace from unittest.mock import Mock, patch +from urllib.parse import parse_qs, urlparse from uuid import uuid4 from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -676,6 +678,51 @@ class TestDocumentSegmentIndexing: # Assert assert segment.hit_count == 5 + def test_document_segment_attachments_prefers_files_url_for_source_url(self, monkeypatch): + """Test attachment source URLs use FILES_URL before falling back to CONSOLE_API_URL.""" + # Arrange + segment = DocumentSegment( + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="document-1", + position=1, + content="Test", + word_count=1, + tokens=2, + created_by="user-1", + ) + segment.id = "segment-1" + attachment = SimpleNamespace( + id="upload-1", + name="image.png", + size=128, + extension="png", + mime_type="image/png", + ) + + monkeypatch.setattr("models.dataset.time.time", lambda: 1700000000) + monkeypatch.setattr("models.dataset.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("models.dataset.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("models.dataset.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("models.dataset.dify_config.CONSOLE_API_URL", "https://console.example.com") + + with patch("models.dataset.db") as mock_db: + mock_db.session.execute.return_value.all.return_value = [(Mock(), attachment)] + + # Act + attachments = segment.attachments + + # Assert + assert len(attachments) == 1 + source_url = attachments[0]["source_url"] + parsed = urlparse(source_url) + query = parse_qs(parsed.query) + assert parsed.netloc == "files.example.com" + assert parsed.path == "/files/upload-1/image-preview" + assert query["timestamp"] == ["1700000000"] + assert query["nonce"] == ["01010101010101010101010101010101"] + assert query["sign"][0] + def test_document_segment_error_tracking(self): """Test document segment error tracking.""" # Arrange @@ -800,9 +847,7 @@ class TestDatasetProcessRule: # Act process_rule = DatasetProcessRule( - dataset_id=dataset_id, - mode=ProcessRuleMode.AUTOMATIC, - created_by=created_by, + dataset_id=dataset_id, mode=ProcessRuleMode.AUTOMATIC, created_by=created_by, rules=None ) # Assert diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 83258fd1b7..5d148974f8 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -173,7 +173,8 @@ class AudioServiceTestDataFactory: file = Mock(spec=FileStorage) file.filename = filename file.mimetype = mimetype - file.read = Mock(return_value=content) + file.stream = Mock() + file.stream.read = Mock(return_value=content) for key, value in kwargs.items(): setattr(file, key, value) return file @@ -216,7 +217,7 @@ class TestAudioServiceASR: """Test speech-to-text (ASR) operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory): + def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful ASR transcription in CHAT mode.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -241,7 +242,9 @@ class TestAudioServiceASR: mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory): + def test_transcript_asr_success_advanced_chat_mode( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test successful ASR transcription in ADVANCED_CHAT mode.""" # Arrange workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": True}}) @@ -263,7 +266,7 @@ class TestAudioServiceASR: # Assert assert result == {"text": "Workflow transcribed text"} - def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory): + def test_transcript_asr_raises_error_when_feature_disabled_chat_mode(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when speech-to-text is disabled in CHAT mode.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": False}) @@ -277,7 +280,9 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode(self, factory): + def test_transcript_asr_raises_error_when_feature_disabled_workflow_mode( + self, factory: AudioServiceTestDataFactory + ): """Test that ASR raises error when speech-to-text is disabled in WORKFLOW mode.""" # Arrange workflow = factory.create_workflow_mock(features_dict={"speech_to_text": {"enabled": False}}) @@ -291,7 +296,7 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_workflow_missing(self, factory): + def test_transcript_asr_raises_error_when_workflow_missing(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when workflow is missing in WORKFLOW mode.""" # Arrange app = factory.create_app_mock( @@ -304,7 +309,7 @@ class TestAudioServiceASR: with pytest.raises(ValueError, match="Speech to text is not enabled"): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory): + def test_transcript_asr_raises_error_when_no_file_uploaded(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when no file is uploaded.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -317,7 +322,7 @@ class TestAudioServiceASR: with pytest.raises(NoAudioUploadedServiceError): AudioService.transcript_asr(app_model=app, file=None) - def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory): + def test_transcript_asr_raises_error_for_unsupported_audio_type(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error for unsupported audio file types.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -331,7 +336,7 @@ class TestAudioServiceASR: with pytest.raises(UnsupportedAudioTypeServiceError): AudioService.transcript_asr(app_model=app, file=file) - def test_transcript_asr_raises_error_for_large_file(self, factory): + def test_transcript_asr_raises_error_for_large_file(self, factory: AudioServiceTestDataFactory): """Test that ASR raises error when file exceeds size limit (30MB).""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -348,7 +353,9 @@ class TestAudioServiceASR: AudioService.transcript_asr(app_model=app, file=file) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): + def test_transcript_asr_raises_error_when_no_model_instance( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that ASR raises error when no model instance is available.""" # Arrange app_model_config = factory.create_app_model_config_mock(speech_to_text_dict={"enabled": True}) @@ -371,7 +378,7 @@ class TestAudioServiceTTS: """Test text-to-speech (TTS) operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory): + def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful TTS with text input.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -405,7 +412,7 @@ class TestAudioServiceTTS: ) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory): + def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test TTS uses default voice when none specified.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -435,7 +442,9 @@ class TestAudioServiceTTS: assert call_args.kwargs["voice"] == "default-voice" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory): + def test_transcript_tts_gets_first_available_voice_when_none_configured( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test TTS gets first available voice when none is configured.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -467,7 +476,7 @@ class TestAudioServiceTTS: @patch("services.audio_service.WorkflowService", autospec=True) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_workflow_mode_with_draft( - self, mock_model_manager_class, mock_workflow_service_class, factory + self, mock_model_manager_class, mock_workflow_service_class, factory: AudioServiceTestDataFactory ): """Test TTS in WORKFLOW mode with draft workflow.""" # Arrange @@ -499,7 +508,7 @@ class TestAudioServiceTTS: assert result == b"draft audio" mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app) - def test_transcript_tts_raises_error_when_text_missing(self, factory): + def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory): """Test that TTS raises error when text is missing.""" # Arrange app = factory.create_app_mock() @@ -509,7 +518,9 @@ class TestAudioServiceTTS: AudioService.transcript_tts(app_model=app, text=None) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory): + def test_transcript_tts_raises_error_when_no_voices_available( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS raises error when no voices are available.""" # Arrange app_model_config = factory.create_app_model_config_mock( @@ -535,7 +546,7 @@ class TestAudioServiceTTSVoices: """Test TTS voice listing operations.""" @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_success(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_success(self, mock_model_manager_class, factory: AudioServiceTestDataFactory): """Test successful retrieval of TTS voices.""" # Arrange tenant_id = "tenant-123" @@ -560,7 +571,9 @@ class TestAudioServiceTTSVoices: mock_model_instance.get_tts_voices.assert_called_once_with(language) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_raises_error_when_no_model_instance( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS voices raises error when no model instance is available.""" # Arrange tenant_id = "tenant-123" @@ -575,7 +588,9 @@ class TestAudioServiceTTSVoices: AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) - def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory): + def test_transcript_tts_voices_propagates_exceptions( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): """Test that TTS voices propagates exceptions from model instance.""" # Arrange tenant_id = "tenant-123" diff --git a/api/tests/unit_tests/services/test_credit_pool_service.py b/api/tests/unit_tests/services/test_credit_pool_service.py new file mode 100644 index 0000000000..e77ef894e7 --- /dev/null +++ b/api/tests/unit_tests/services/test_credit_pool_service.py @@ -0,0 +1,158 @@ +from types import SimpleNamespace +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import create_engine, select +from sqlalchemy.engine import Engine + +from core.errors.error import QuotaExceededError +from models import TenantCreditPool +from models.enums import ProviderQuotaType +from services.credit_pool_service import CreditPoolService + + +def _create_engine_with_pool(*, quota_limit: int, quota_used: int) -> tuple[Engine, str, str]: + engine = create_engine("sqlite:///:memory:") + TenantCreditPool.__table__.create(engine) + tenant_id = str(uuid4()) + pool_id = str(uuid4()) + with engine.begin() as connection: + connection.execute( + TenantCreditPool.__table__.insert(), + { + "id": pool_id, + "tenant_id": tenant_id, + "pool_type": ProviderQuotaType.TRIAL, + "quota_limit": quota_limit, + "quota_used": quota_used, + }, + ) + return engine, tenant_id, pool_id + + +def _get_quota_used(*, engine: Engine, pool_id: str) -> int | None: + with engine.connect() as connection: + return connection.scalar(select(TenantCreditPool.quota_used).where(TenantCreditPool.id == pool_id)) + + +def test_check_and_deduct_credits_deducts_exact_amount_when_sufficient() -> None: + engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) + + with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)): + deducted_credits = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3) + + assert deducted_credits == 3 + assert _get_quota_used(engine=engine, pool_id=pool_id) == 5 + + +def test_check_and_deduct_credits_returns_zero_for_non_positive_request() -> None: + assert CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=0) == 0 + + +def test_check_and_deduct_credits_raises_when_pool_is_missing() -> None: + engine = create_engine("sqlite:///:memory:") + TenantCreditPool.__table__.create(engine) + + with ( + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + pytest.raises(QuotaExceededError, match="Credit pool not found"), + ): + CreditPoolService.check_and_deduct_credits(tenant_id=str(uuid4()), credits_required=1) + + +def test_check_and_deduct_credits_raises_when_pool_is_empty() -> None: + engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10) + + with ( + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + pytest.raises(QuotaExceededError, match="No credits remaining"), + ): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1) + + assert _get_quota_used(engine=engine, pool_id=pool_id) == 10 + + +def test_check_and_deduct_credits_raises_without_partial_deduction_when_insufficient() -> None: + engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9) + + with ( + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + pytest.raises(QuotaExceededError, match="Insufficient credits remaining"), + ): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=3) + + assert _get_quota_used(engine=engine, pool_id=pool_id) == 9 + + +def test_check_and_deduct_credits_wraps_unexpected_deduction_errors() -> None: + engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) + + with ( + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")), + pytest.raises(QuotaExceededError, match="Failed to deduct credits"), + ): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=1) + + assert _get_quota_used(engine=engine, pool_id=pool_id) == 2 + + +def test_deduct_credits_capped_returns_zero_for_non_positive_request() -> None: + assert CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=0) == 0 + + +def test_deduct_credits_capped_returns_zero_when_pool_is_missing() -> None: + engine = create_engine("sqlite:///:memory:") + TenantCreditPool.__table__.create(engine) + + with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)): + deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=str(uuid4()), credits_required=1) + + assert deducted_credits == 0 + + +def test_deduct_credits_capped_returns_zero_when_pool_is_empty() -> None: + engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=10) + + with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)): + deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) + + assert deducted_credits == 0 + assert _get_quota_used(engine=engine, pool_id=pool_id) == 10 + + +def test_deduct_credits_capped_deducts_only_remaining_balance_when_insufficient() -> None: + engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=9) + + with patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)): + deducted_credits = CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=3) + + assert deducted_credits == 1 + assert _get_quota_used(engine=engine, pool_id=pool_id) == 10 + + +def test_deduct_credits_capped_wraps_unexpected_deduction_errors() -> None: + engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) + + with ( + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + patch.object(CreditPoolService, "_get_locked_pool", side_effect=RuntimeError("database unavailable")), + pytest.raises(QuotaExceededError, match="Failed to deduct credits"), + ): + CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) + + assert _get_quota_used(engine=engine, pool_id=pool_id) == 2 + + +def test_deduct_credits_capped_reraises_quota_exceeded_errors() -> None: + engine, tenant_id, pool_id = _create_engine_with_pool(quota_limit=10, quota_used=2) + + with ( + patch("services.credit_pool_service.db", SimpleNamespace(engine=engine)), + patch.object(CreditPoolService, "_get_locked_pool", side_effect=QuotaExceededError("quota unavailable")), + pytest.raises(QuotaExceededError, match="quota unavailable"), + ): + CreditPoolService.deduct_credits_capped(tenant_id=tenant_id, credits_required=1) + + assert _get_quota_used(engine=engine, pool_id=pool_id) == 2 diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py index 1633194aa8..a78bc7f9d6 100644 --- a/api/tests/unit_tests/services/test_dataset_service_document.py +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -1297,7 +1297,7 @@ class TestDocumentServiceEstimateValidation: """Unit tests for estimate_args_validate branches.""" def test_estimate_args_validate_rejects_missing_info_list(self): - with pytest.raises(ValueError, match="Data source info is required"): + with pytest.raises(ValueError, match="Field required"): DocumentService.estimate_args_validate({}) def test_estimate_args_validate_sets_empty_rules_for_automatic_mode(self): @@ -1357,7 +1357,7 @@ class TestDocumentServiceEstimateValidation: }, } - with pytest.raises(ValueError, match="Summary index model provider name is required"): + with pytest.raises(ValueError, match="Field required"): DocumentService.estimate_args_validate(args) diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index 12e7b606fd..44ef5e3acc 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -268,8 +268,8 @@ class TestWebhookServiceUnit: } # Mock file reads - files["file1"].read.return_value = b"content1" - files["file2"].read.return_value = b"content2" + files["file1"].stream.read.return_value = b"content1" + files["file2"].stream.read.return_value = b"content2" webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" @@ -304,8 +304,8 @@ class TestWebhookServiceUnit: "bad_file": MagicMock(filename="test.bad", content_type="text/plain"), } - files["good_file"].read.return_value = b"content" - files["bad_file"].read.side_effect = Exception("Read error") + files["good_file"].stream.read.return_value = b"content" + files["bad_file"].stream.read.side_effect = Exception("Read error") webhook_trigger = MagicMock() webhook_trigger.tenant_id = "test_tenant" diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 642a459e0b..edc952b11f 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -11,6 +11,7 @@ This test suite covers: import json import uuid +from types import SimpleNamespace from typing import Any, cast from unittest.mock import ANY, MagicMock, Mock, patch, sentinel @@ -2649,8 +2650,13 @@ class TestWorkflowServiceHumanInputOperations: mock_node = MagicMock() mock_node.node_data = MagicMock() + mock_node.node_data.user_actions = [ + SimpleNamespace(id="submit", title="card_visa_enterprise_001"), + ] mock_node.node_data.outputs_field_names.return_value = ["field1"] mock_node.node_data.inputs = [] + mock_node.render_form_content_before_submission.return_value = "Ticket: {{#$output.field1#}}" + mock_node.render_form_content_with_outputs.return_value = "Ticket: val1" with ( patch("services.workflow_service.db"), @@ -2670,6 +2676,8 @@ class TestWorkflowServiceHumanInputOperations: ) assert result["__action_id"] == "submit" mock_validate.assert_called_once() + assert result["__action_value"] == "card_visa_enterprise_001" + assert result["__rendered_content"] == "Ticket: val1" mock_saver_cls.return_value.save.assert_called_once() def test_test_human_input_delivery_success(self, service: WorkflowService) -> None: @@ -2850,7 +2858,7 @@ class TestWorkflowServiceFreeNodeExecution: mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data) mock_node_cls.assert_called_once_with( node_id="n-1", - config=sentinel.node_data, + data=sentinel.node_data, graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value, graph_runtime_state=ANY, runtime=mock_runtime_cls.return_value, diff --git a/api/tests/unit_tests/tasks/test_ops_trace_task.py b/api/tests/unit_tests/tasks/test_ops_trace_task.py new file mode 100644 index 0000000000..5844c55c04 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_ops_trace_task.py @@ -0,0 +1,301 @@ +import json +import sys +from contextlib import contextmanager +from types import ModuleType +from unittest.mock import MagicMock, patch + +import pytest +from celery.exceptions import Retry + +from core.ops.entities.config_entity import OPS_TRACE_FAILED_KEY +from core.ops.exceptions import RetryableTraceDispatchError +from tasks.ops_trace_task import process_trace_tasks + + +@contextmanager +def fake_app_context(): + yield + + +class FakeCurrentApp: + def app_context(self): + return fake_app_context() + + +def _install_trace_manager( + trace_instance: MagicMock, + *, + enterprise_enabled: bool = False, + enterprise_trace_cls: MagicMock | None = None, +) -> dict[str, ModuleType]: + ops_trace_manager_module = ModuleType("core.ops.ops_trace_manager") + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id: str) -> MagicMock: + return trace_instance + + telemetry_module = ModuleType("extensions.ext_enterprise_telemetry") + telemetry_module.is_enabled = lambda: enterprise_enabled + + ops_trace_manager_module.OpsTraceManager = StubOpsTraceManager + modules = { + "core.ops.ops_trace_manager": ops_trace_manager_module, + "extensions.ext_enterprise_telemetry": telemetry_module, + } + if enterprise_trace_cls is not None: + enterprise_module = ModuleType("enterprise") + enterprise_telemetry_module = ModuleType("enterprise.telemetry") + enterprise_trace_module = ModuleType("enterprise.telemetry.enterprise_trace") + enterprise_trace_module.EnterpriseOtelTrace = enterprise_trace_cls + modules.update( + { + "enterprise": enterprise_module, + "enterprise.telemetry": enterprise_telemetry_module, + "enterprise.telemetry.enterprise_trace": enterprise_trace_module, + } + ) + return modules + + +def _make_payload() -> str: + return json.dumps({"trace_info": {}, "trace_info_type": None}) + + +def _decode_saved_payload(payload: bytes | str) -> dict[str, object]: + if isinstance(payload, bytes): + payload = payload.decode("utf-8") + return json.loads(payload) + + +def _retryable_dispatch_error() -> RetryableTraceDispatchError: + return RetryableTraceDispatchError("transient trace dispatch failure") + + +def _run_task(file_info: dict[str, str], retries: int = 0) -> None: + process_trace_tasks.push_request(retries=retries) + try: + process_trace_tasks.run(file_info) + finally: + process_trace_tasks.pop_request() + + +def test_process_trace_tasks_retries_retryable_dispatch_failure_and_preserves_payload(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + trace_instance.trace.side_effect = pending_error + retry_error = Retry() + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry, + pytest.raises(Retry), + ): + _run_task(file_info) + + mock_retry.assert_called_once_with( + exc=pending_error, + countdown=process_trace_tasks.default_retry_delay, + ) + mock_delete.assert_not_called() + mock_incr.assert_not_called() + + +def test_process_trace_tasks_marks_enterprise_trace_dispatched_before_retryable_dispatch_retry(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + trace_instance.trace.side_effect = pending_error + retry_error = Retry() + enterprise_tracer = MagicMock() + enterprise_trace_cls = MagicMock(return_value=enterprise_tracer) + + with ( + patch.dict( + sys.modules, + _install_trace_manager( + trace_instance, + enterprise_enabled=True, + enterprise_trace_cls=enterprise_trace_cls, + ), + ), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.save") as mock_save, + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry, + pytest.raises(Retry), + ): + _run_task(file_info) + + enterprise_tracer.trace.assert_called_once_with({}) + saved_path, saved_payload = mock_save.call_args.args + assert saved_path == "ops_trace/app-id/file-id.json" + assert _decode_saved_payload(saved_payload)["_enterprise_trace_dispatched"] is True + mock_retry.assert_called_once_with( + exc=pending_error, + countdown=process_trace_tasks.default_retry_delay, + ) + mock_delete.assert_not_called() + mock_incr.assert_not_called() + + +def test_process_trace_tasks_does_not_mark_failed_enterprise_trace_as_dispatched_before_retry(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + trace_instance.trace.side_effect = pending_error + retry_error = Retry() + enterprise_tracer = MagicMock() + enterprise_tracer.trace.side_effect = RuntimeError("enterprise trace failed") + enterprise_trace_cls = MagicMock(return_value=enterprise_tracer) + + with ( + patch.dict( + sys.modules, + _install_trace_manager( + trace_instance, + enterprise_enabled=True, + enterprise_trace_cls=enterprise_trace_cls, + ), + ), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.save") as mock_save, + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry", side_effect=retry_error) as mock_retry, + pytest.raises(Retry), + ): + _run_task(file_info) + + enterprise_tracer.trace.assert_called_once_with({}) + mock_save.assert_not_called() + mock_retry.assert_called_once_with( + exc=pending_error, + countdown=process_trace_tasks.default_retry_delay, + ) + mock_delete.assert_not_called() + mock_incr.assert_not_called() + + +def test_process_trace_tasks_skips_enterprise_trace_when_retry_payload_was_already_dispatched(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + enterprise_trace_cls = MagicMock() + payload = json.dumps({"trace_info": {}, "trace_info_type": None, "_enterprise_trace_dispatched": True}) + + with ( + patch.dict( + sys.modules, + _install_trace_manager( + trace_instance, + enterprise_enabled=True, + enterprise_trace_cls=enterprise_trace_cls, + ), + ), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=payload), + patch("tasks.ops_trace_task.storage.save") as mock_save, + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + ): + _run_task(file_info) + + enterprise_trace_cls.assert_not_called() + trace_instance.trace.assert_called_once_with({}) + mock_save.assert_not_called() + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_not_called() + + +def test_process_trace_tasks_default_retry_window_covers_parent_span_context_ttl(): + assert process_trace_tasks.max_retries * process_trace_tasks.default_retry_delay >= 300 + + +def test_process_trace_tasks_deletes_payload_on_success(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + ): + _run_task(file_info) + + trace_instance.trace.assert_called_once_with({}) + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_not_called() + + +def test_process_trace_tasks_deletes_payload_and_counts_terminal_failure(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + trace_instance.trace.side_effect = RuntimeError("trace failed") + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + ): + _run_task(file_info) + + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id") + + +def test_process_trace_tasks_treats_retry_enqueue_failure_as_terminal_failure(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + retry_enqueue_error = RuntimeError("retry enqueue failed") + trace_instance.trace.side_effect = pending_error + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry", side_effect=retry_enqueue_error) as mock_retry, + ): + _run_task(file_info) + + mock_retry.assert_called_once_with( + exc=pending_error, + countdown=process_trace_tasks.default_retry_delay, + ) + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id") + + +def test_process_trace_tasks_deletes_payload_and_counts_exhausted_retryable_dispatch_failure(): + file_info = {"app_id": "app-id", "file_id": "file-id"} + trace_instance = MagicMock() + pending_error = _retryable_dispatch_error() + trace_instance.trace.side_effect = pending_error + + with ( + patch.dict(sys.modules, _install_trace_manager(trace_instance)), + patch("tasks.ops_trace_task.current_app", FakeCurrentApp()), + patch("tasks.ops_trace_task.storage.load", return_value=_make_payload()), + patch("tasks.ops_trace_task.storage.delete") as mock_delete, + patch("tasks.ops_trace_task.redis_client.incr") as mock_incr, + patch.object(process_trace_tasks, "retry") as mock_retry, + ): + _run_task(file_info, retries=process_trace_tasks.max_retries) + + mock_retry.assert_not_called() + mock_delete.assert_called_once_with("ops_trace/app-id/file-id.json") + mock_incr.assert_called_once_with(f"{OPS_TRACE_FAILED_KEY}_app-id") diff --git a/api/uv.lock b/api/uv.lock index 9a2bcf104d..70683ad5d2 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -616,16 +616,16 @@ wheels = [ [[package]] name = "boto3" -version = "1.43.3" +version = "1.43.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f2/50/ea184e159c4ac64fef816a72094fb8656eb071361a39ed22c0e3b15a35b4/boto3-1.43.3.tar.gz", hash = "sha256:7c7777862ffc898f05efa566032bbabfe226dbb810e35ec11125817f128bc5c5", size = 113111, upload-time = "2026-05-04T19:34:09.731Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/37/78c630d1308964aa9abf44951d9c4df776546ff37251ec2434944e205c4e/boto3-1.43.6.tar.gz", hash = "sha256:e6315effaf12b890b99956e6f8e2c3000a3f64e4ee91943cec3895ce9a836afb", size = 113153, upload-time = "2026-05-07T20:49:59.694Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/ad/8a6946a329f0127322108e537dc1c0d9f8eea4f1d1231702c073d2e85f46/boto3-1.43.3-py3-none-any.whl", hash = "sha256:fb9fe51849ef2a78198d582756fc06f14f7de27f73e0fa90275d6aa4171eb4d0", size = 140501, upload-time = "2026-05-04T19:34:07.991Z" }, + { url = "https://files.pythonhosted.org/packages/c8/e2/3c2eef44f55eafab256836d1d9479bd6a74f70c26cbfdc0639a0e23e4327/boto3-1.43.6-py3-none-any.whl", hash = "sha256:179601ec2992726a718053bf41e43c223ceba397d31ceab11f64d9c910d9fc3a", size = 140502, upload-time = "2026-05-07T20:49:57.8Z" }, ] [[package]] @@ -648,16 +648,16 @@ bedrock-runtime = [ [[package]] name = "botocore" -version = "1.43.3" +version = "1.43.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/74/ac/cd55f886e17b6b952dbc95b792d3645a73d58586a1400ababe54406073bd/botocore-1.43.3.tar.gz", hash = "sha256:eac6da0fffccf87888ebf4d89f0b2378218a707efa748cd955b838995e944695", size = 15308705, upload-time = "2026-05-04T19:33:56.28Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/a7/23d0f5028011455096a1eeac0ddf3cbe147b3e855e127342f8202552194d/botocore-1.43.6.tar.gz", hash = "sha256:b1e395b347356860398da42e61c808cf1e34b6fa7180cf2b9d87d986e1a06ba0", size = 15336070, upload-time = "2026-05-07T20:49:48.14Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/be/99/1d9e296edf244f47e0508032f20999f8fd40704dd3c5b601fed099424eb6/botocore-1.43.3-py3-none-any.whl", hash = "sha256:ec0769eb0f7c5034856bb406a92698dbc02a3d4be0f78a384747106b161d8ea3", size = 14989027, upload-time = "2026-05-04T19:33:50.81Z" }, + { url = "https://files.pythonhosted.org/packages/e5/c8/6f47223840e8d8cfa8c9f7c0ec1b77970417f257fc885169ff4f6326ce09/botocore-1.43.6-py3-none-any.whl", hash = "sha256:b6d1fdbc6f65a5fe0b7e947823aa37535d3f39f3ba4d21110fab1f55bbbcc04b", size = 15017094, upload-time = "2026-05-07T20:49:44.964Z" }, ] [[package]] @@ -1672,7 +1672,7 @@ requires-dist = [ { name = "aliyun-log-python-sdk", specifier = ">=0.9.44,<1.0.0" }, { name = "azure-identity", specifier = ">=1.25.3,<2.0.0" }, { name = "bleach", specifier = ">=6.3.0" }, - { name = "boto3", specifier = ">=1.43.3" }, + { name = "boto3", specifier = ">=1.43.6" }, { name = "celery", specifier = ">=5.6.3" }, { name = "croniter", specifier = ">=6.2.2" }, { name = "fastopenapi", extras = ["flask"], specifier = "~=0.7.0" }, @@ -1686,10 +1686,10 @@ requires-dist = [ { name = "gevent", specifier = ">=26.4.0" }, { name = "gevent-websocket", specifier = ">=0.10.1" }, { name = "gmpy2", specifier = ">=2.3.0" }, - { name = "google-api-python-client", specifier = ">=2.195.0" }, - { name = "google-cloud-aiplatform", specifier = ">=1.149.0,<2.0.0" }, + { name = "google-api-python-client", specifier = ">=2.196.0" }, + { name = "google-cloud-aiplatform", specifier = ">=1.151.0,<2.0.0" }, { name = "graphon", git = "https://github.com/QuantumGhost/graphon?branch=hitl-form-dev" }, - { name = "gunicorn", specifier = ">=25.3.0" }, + { name = "gunicorn", specifier = ">=26.0.0" }, { name = "httpx", extras = ["socks"], specifier = ">=0.28.1,<1.0.0" }, { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "json-repair", specifier = "~=0.59.4" }, @@ -1787,7 +1787,7 @@ storage = [ { name = "google-cloud-storage", specifier = ">=3.10.1" }, { name = "opendal", specifier = ">=0.46.0" }, { name = "oss2", specifier = ">=2.19.1" }, - { name = "supabase", specifier = ">=2.29.0" }, + { name = "supabase", specifier = ">=2.30.0" }, { name = "tos", specifier = ">=2.9.0" }, ] tools = [ @@ -2840,7 +2840,7 @@ grpc = [ [[package]] name = "google-api-python-client" -version = "2.195.0" +version = "2.196.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2849,9 +2849,9 @@ dependencies = [ { name = "httplib2" }, { name = "uritemplate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/69/07/08d759b9cb10f48af14b25262dd0d6685ca8cda6c1f9e8a8109f57457205/google_api_python_client-2.195.0.tar.gz", hash = "sha256:c72cf2661c3addf01c880ce60541e83e1df354644b874f7f9d8d5ed2070446ae", size = 14584819, upload-time = "2026-04-30T21:51:50.638Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/f3/34ef8aca7909675fe327f96c1ed927f0520e7acf68af19157e96acc05e76/google_api_python_client-2.196.0.tar.gz", hash = "sha256:9f335d38f6caaa2747bcf64335ed1a9a19047d53e86538eda6a1b17d37f1743d", size = 14628129, upload-time = "2026-05-06T23:47:35.655Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/21/b9/2c71095e31fff57668fec7c07ac897df065f15521d070e63229e13689590/google_api_python_client-2.195.0-py3-none-any.whl", hash = "sha256:753e62057f23049a89534bea0162b60fe391b85fb86d80bcdf884d05ec91c5bf", size = 15162418, upload-time = "2026-04-30T21:51:47.444Z" }, + { url = "https://files.pythonhosted.org/packages/99/c7/1817b4edf966d5afcac1c0781ca36d621bc0cb58104c4e7c2a475ab185f7/google_api_python_client-2.196.0-py3-none-any.whl", hash = "sha256:2591e9b47dcb17e4e62a09370aaee3bcf323af8f28ccecdabcd0a42a23ca4db5", size = 15206663, upload-time = "2026-05-06T23:47:32.886Z" }, ] [[package]] @@ -2887,7 +2887,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.149.0" +version = "1.151.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -2903,9 +2903,9 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/2c/fba4adc56f74c0ee0fbd91a39d414ca2c3588dd8b71f9be8a507015ca886/google_cloud_aiplatform-1.149.0.tar.gz", hash = "sha256:a4d73485bf1d727a9e1bbbd13d08d7031490686bbf7d125eb905c1a6c1559a35", size = 10451466, upload-time = "2026-04-27T23:11:54.513Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/f6/e2fbe175a011f5080da8c1f7d9169a6875a00ea2c7bee4193d952b097400/google_cloud_aiplatform-1.151.0.tar.gz", hash = "sha256:2f29b1853f790a7371a746c747bf1f664380b534254682441acd4b5ee26fafd2", size = 10617421, upload-time = "2026-05-07T21:56:52.91Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/a0/27719ba23967ef62e52a1d54e013e0fc174bdab8dd84fb300bab9bf0d4a3/google_cloud_aiplatform-1.149.0-py2.py3-none-any.whl", hash = "sha256:e6b5299fa5d303e971cb29a19f03fdbb7b1e3b9d2faa3a788ca933341fba2f2e", size = 8570410, upload-time = "2026-04-27T23:11:50.495Z" }, + { url = "https://files.pythonhosted.org/packages/f6/4a/cd35f8ba622d563b1335222284d2838aa789b953b40516b1b997e50fe5b6/google_cloud_aiplatform-1.151.0-py2.py3-none-any.whl", hash = "sha256:61372bb0923b14b8027f45b83393452df3a85bf4ea86fa48e08844fb5ec50049", size = 8732627, upload-time = "2026-05-07T21:56:49.014Z" }, ] [[package]] @@ -3058,8 +3058,8 @@ httpx = [ [[package]] name = "graphon" -version = "0.2.2" -source = { git = "https://github.com/QuantumGhost/graphon?branch=hitl-form-dev#91941bfcbc1fbd0107ada2423f8ffaa628957ae7" } +version = "0.3.1" +source = { git = "https://github.com/QuantumGhost/graphon?branch=hitl-form-dev#b8f72b598271032cd088a08cc68418a1a5f65479" } dependencies = [ { name = "charset-normalizer" }, { name = "httpx" }, @@ -3216,14 +3216,14 @@ wheels = [ [[package]] name = "gunicorn" -version = "25.3.0" +version = "26.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/f4/e78fa054248fab913e2eab0332c6c2cb07421fca1ce56d8fe43b6aef57a4/gunicorn-25.3.0.tar.gz", hash = "sha256:f74e1b2f9f76f6cd1ca01198968bd2dd65830edc24b6e8e4d78de8320e2fe889", size = 634883, upload-time = "2026-03-27T00:00:26.092Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/b7/a4a3f632f823e432ce6bc65f62961b7980c898c77f075a2f7118cb3846fe/gunicorn-26.0.0.tar.gz", hash = "sha256:ca9346f85e3a4aeeb64d491045c16b9a35647abd37ea15efe53080eb8b090baf", size = 727286, upload-time = "2026-05-05T06:38:25.529Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/43/c8/8aaf447698c4d59aa853fd318eed300b5c9e44459f242ab8ead6c9c09792/gunicorn-25.3.0-py3-none-any.whl", hash = "sha256:cacea387dab08cd6776501621c295a904fe8e3b7aae9a1a3cbb26f4e7ed54660", size = 208403, upload-time = "2026-03-27T00:00:27.386Z" }, + { url = "https://files.pythonhosted.org/packages/e6/40/9c2384fc2be4ad25dd4a49decd5ad9ea5a3639814c11bd40ab77cb9f0a14/gunicorn-26.0.0-py3-none-any.whl", hash = "sha256:40233d26a5f0d1872916188c276e21641155111c2853f0c2cd55260aec0d24fc", size = 212009, upload-time = "2026-05-05T06:38:23.007Z" }, ] [[package]] @@ -4615,32 +4615,32 @@ wheels = [ [[package]] name = "opentelemetry-exporter-otlp" -version = "1.41.0" +version = "1.41.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-exporter-otlp-proto-grpc" }, { name = "opentelemetry-exporter-otlp-proto-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/65/b7/845565a2ab5d22c1486bc7729a06b05cd0964c61539d766e1f107c9eea0c/opentelemetry_exporter_otlp-1.41.0.tar.gz", hash = "sha256:97ff847321f8d4c919032a67d20d3137fb7b34eac0c47f13f71112858927fc5b", size = 6152, upload-time = "2026-04-09T14:38:35.895Z" } +sdist = { url = "https://files.pythonhosted.org/packages/42/84/d55baf8e1a222f40282956083e67de9fa92d5fa451108df4839505fa2a24/opentelemetry_exporter_otlp-1.41.1.tar.gz", hash = "sha256:299a2f0541ca175df186f5ac58fd5db177ba1e9b72b0826049062f750d55b47f", size = 6152, upload-time = "2026-04-24T13:15:40.006Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/f2/f1076fff152858773f22cda146713f9ae3661795af6bacd411a76f2151ac/opentelemetry_exporter_otlp-1.41.0-py3-none-any.whl", hash = "sha256:443b6a45c990ae4c55e147f97049a86c5f5b704f3d78b48b44a073a886ec4d6e", size = 7022, upload-time = "2026-04-09T14:38:13.934Z" }, + { url = "https://files.pythonhosted.org/packages/6d/d5/ea4aa7dfc458fd537bd9519ea0e7226eef2a6212dfe952694984167daaba/opentelemetry_exporter_otlp-1.41.1-py3-none-any.whl", hash = "sha256:db276c5a80c02b063994e80950d00ca1bfddcf6520f608335b7dc2db0c0eb9c6", size = 7025, upload-time = "2026-04-24T13:15:17.839Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.41.0" +version = "1.41.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-proto" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8c/28/e8eca94966fe9a1465f6094dc5ddc5398473682180279c94020bc23b4906/opentelemetry_exporter_otlp_proto_common-1.41.0.tar.gz", hash = "sha256:966bbce537e9edb166154779a7c4f8ab6b8654a03a28024aeaf1a3eacb07d6ee", size = 20411, upload-time = "2026-04-09T14:38:36.572Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/fa/f9e3bd3c4d692b3ce9a2880a167d1f79681a1bea11f00d5bf76adc03e6ea/opentelemetry_exporter_otlp_proto_common-1.41.1.tar.gz", hash = "sha256:0e253156ea9c36b0bd3d2440c5c9ba7dd1f3fb64ba7a08fc85fbac536b56e1fb", size = 20409, upload-time = "2026-04-24T13:15:40.924Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/26/c4/78b9bf2d9c1d5e494f44932988d9d91c51a66b9a7b48adf99b62f7c65318/opentelemetry_exporter_otlp_proto_common-1.41.0-py3-none-any.whl", hash = "sha256:7a99177bf61f85f4f9ed2072f54d676364719c066f6d11f515acc6c745c7acf0", size = 18366, upload-time = "2026-04-09T14:38:15.135Z" }, + { url = "https://files.pythonhosted.org/packages/29/48/bce76d3ea772b609757e9bc844e02ab408a6446609bf74fb562062ba6b71/opentelemetry_exporter_otlp_proto_common-1.41.1-py3-none-any.whl", hash = "sha256:10da74dad6a49344b9b7b21b6182e3060373a235fde1528616d5f01f92e66aa9", size = 18366, upload-time = "2026-04-24T13:15:18.917Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.41.0" +version = "1.41.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "googleapis-common-protos" }, @@ -4651,14 +4651,14 @@ dependencies = [ { name = "opentelemetry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/46/d75a3f8c91915f2e58f61d0a2e4ada63891e7c7a37a20ff7949ba184a6b2/opentelemetry_exporter_otlp_proto_grpc-1.41.0.tar.gz", hash = "sha256:f704201251c6f65772b11bddea1c948000554459101bdbb0116e0a01b70592f6", size = 25754, upload-time = "2026-04-09T14:38:37.423Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/9b/e4503060b8695579dbaad187dc8cef4554188de68748c88060599b77489e/opentelemetry_exporter_otlp_proto_grpc-1.41.1.tar.gz", hash = "sha256:b05df8fa1333dc9a3fda36b676b96b5095ab6016d3f0c3296d430d629ba1443b", size = 25755, upload-time = "2026-04-24T13:15:41.93Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/81/f6/b09e2e0c9f0b5750cebc6eaf31527b910821453cef40a5a0fe93550422b2/opentelemetry_exporter_otlp_proto_grpc-1.41.0-py3-none-any.whl", hash = "sha256:3a1a86bd24806ccf136ec9737dbfa4c09b069f9130ff66b0acb014f9c5255fd1", size = 20299, upload-time = "2026-04-09T14:38:17.01Z" }, + { url = "https://files.pythonhosted.org/packages/ac/f2/c54f33c92443d087703e57e52e55f22f111373a5c4c4aa349ea60efe512e/opentelemetry_exporter_otlp_proto_grpc-1.41.1-py3-none-any.whl", hash = "sha256:537926dcef951136992479af1d9cd88f25e33d56c530e9f020ed57774dca2f94", size = 20297, upload-time = "2026-04-24T13:15:20.212Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.41.0" +version = "1.41.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "googleapis-common-protos" }, @@ -4669,9 +4669,9 @@ dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/19/63/d9f43cd75f3fabb7e01148c89cfa9491fc18f6580a6764c554ff7c953c46/opentelemetry_exporter_otlp_proto_http-1.41.0.tar.gz", hash = "sha256:dcd6e0686f56277db4eecbadd5262124e8f2cc739cadbc3fae3d08a12c976cf5", size = 24139, upload-time = "2026-04-09T14:38:38.128Z" } +sdist = { url = "https://files.pythonhosted.org/packages/33/5b/9d3c7f70cca10136ba82a81e738dee626c8e7fc61c6887ea9a58bf34c606/opentelemetry_exporter_otlp_proto_http-1.41.1.tar.gz", hash = "sha256:4747a9604c8550ab38c6fd6180e2fcb80de3267060bef2c306bad3cb443302bc", size = 24139, upload-time = "2026-04-24T13:15:42.977Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/64/b5/a214cd907eedc17699d1c2d602288ae17cb775526df04db3a3b3585329d2/opentelemetry_exporter_otlp_proto_http-1.41.0-py3-none-any.whl", hash = "sha256:a9c4ee69cce9c3f4d7ee736ad1b44e3c9654002c0816900abbafd9f3cf289751", size = 22673, upload-time = "2026-04-09T14:38:18.349Z" }, + { url = "https://files.pythonhosted.org/packages/ba/4d/ef07ff2fc630849f2080ae0ae73a61f67257905b7ac79066640bfa0c5739/opentelemetry_exporter_otlp_proto_http-1.41.1-py3-none-any.whl", hash = "sha256:1a21e8f49c7a946d935551e90947d6c3eb39236723c6624401da0f33d68edcb4", size = 22673, upload-time = "2026-04-24T13:15:21.313Z" }, ] [[package]] @@ -4829,14 +4829,14 @@ wheels = [ [[package]] name = "opentelemetry-proto" -version = "1.41.0" +version = "1.41.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e0/d9/08e3dc6156878713e8c811682bc76151f5fe1a3cb7f3abda3966fd56e71e/opentelemetry_proto-1.41.0.tar.gz", hash = "sha256:95d2e576f9fb1800473a3e4cfcca054295d06bdb869fda4dc9f4f779dc68f7b6", size = 45669, upload-time = "2026-04-09T14:38:45.978Z" } +sdist = { url = "https://files.pythonhosted.org/packages/99/e8/633c6d8a9c8840338b105907e55c32d3da1983abab5e52f899f72a82c3d1/opentelemetry_proto-1.41.1.tar.gz", hash = "sha256:4b9d2eb631237ea43b80e16c073af438554e32bc7e9e3f8ca4a9582f900020e5", size = 45670, upload-time = "2026-04-24T13:15:49.768Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/49/8c/65ef7a9383a363864772022e822b5d5c6988e6f9dabeebb9278f5b86ebc3/opentelemetry_proto-1.41.0-py3-none-any.whl", hash = "sha256:b970ab537309f9eed296be482c3e7cca05d8aca8165346e929f658dbe153b247", size = 72074, upload-time = "2026-04-09T14:38:29.38Z" }, + { url = "https://files.pythonhosted.org/packages/e4/1e/5cd77035e3e82070e2265a63a760f715aacd3cb16dddc7efee913f297fcc/opentelemetry_proto-1.41.1-py3-none-any.whl", hash = "sha256:0496713b804d127a4147e32849fbaf5683fac8ee98550e8e7679cd706c289720", size = 72076, upload-time = "2026-04-24T13:15:32.542Z" }, ] [[package]] @@ -5180,7 +5180,7 @@ wheels = [ [[package]] name = "postgrest" -version = "2.29.0" +version = "2.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "deprecation" }, @@ -5188,9 +5188,9 @@ dependencies = [ { name = "pydantic" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/52/98/f216b8b5c4d116ab6a2fb21339b5821da279ee773e163612418e1c56c012/postgrest-2.29.0.tar.gz", hash = "sha256:a87081858f627fcd57e8e7137004a1ef0adbdf0dbdfed1384e9ea1d7a9c525ec", size = 14217, upload-time = "2026-04-24T13:13:00.281Z" } +sdist = { url = "https://files.pythonhosted.org/packages/56/7c/54e7be05adc9fd6fd98dc572ddfc8982d45bec314a55711e37277d440698/postgrest-2.30.0.tar.gz", hash = "sha256:4f89eec56ce605ab6fbddd9b96d526a9bb44962796d44a5d85cb77640eb766c3", size = 14430, upload-time = "2026-05-06T17:35:21.559Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/0b/08b670a93a90d625c557b9e64b8a5fdeec80c3542d2d0265f0b4d6b16646/postgrest-2.29.0-py3-none-any.whl", hash = "sha256:3ee48e146f726272733d20e2b12de354cdb6cb9dd9cc3a61ed97ce69047aeb96", size = 22735, upload-time = "2026-04-24T13:12:58.405Z" }, + { url = "https://files.pythonhosted.org/packages/22/aa/ff2e09f99f95ea96fddeb373646bf907dd89a24fc00b5d38e5674ca7c9ca/postgrest-2.30.0-py3-none-any.whl", hash = "sha256:30631e7993da542419f4217cf3b60aa641084731ea15e66a18526a3a52e40a7d", size = 23108, upload-time = "2026-05-06T17:35:20.531Z" }, ] [[package]] @@ -6177,16 +6177,16 @@ wheels = [ [[package]] name = "realtime" -version = "2.29.0" +version = "2.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6e/f1/08c42a42653942fadfbef495d5b0239356140e7186cc528704956c5f06d4/realtime-2.29.0.tar.gz", hash = "sha256:8efe4a1b3a548a5fda09de701bd041fa0970c5a2fe7d13db0b9861ce11828be2", size = 18715, upload-time = "2026-04-24T13:13:02.315Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/a2/0328d49d3b5fb427068e9200e7de5b0d708d021a1ad98d004bc685d2529e/realtime-2.30.0.tar.gz", hash = "sha256:7aa593da52ed5f92c34ec4e50e32043afa62f219c94f717ad64a66ab0ef9f1ba", size = 18718, upload-time = "2026-05-06T17:35:23.925Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/48/f6375c0a24923beb988f0c71c052604c96641cf43c2d22b91ec1df86afa0/realtime-2.29.0-py3-none-any.whl", hash = "sha256:1a4891e6c82e88ac9d96ac715e435e086f6f8c7665212a8717346de829cbb509", size = 22374, upload-time = "2026-04-24T13:13:01.103Z" }, + { url = "https://files.pythonhosted.org/packages/b4/75/1b2cfc949595e22d8c05a2aa2cfc222921f7f94177d7e8a90542f3f73b33/realtime-2.30.0-py3-none-any.whl", hash = "sha256:7c93b63d2cf99aa1da4fa8826b03b00cd32f7b38abb27ff47b19eb5dcb5707c6", size = 22376, upload-time = "2026-05-06T17:35:22.568Z" }, ] [[package]] @@ -6668,7 +6668,7 @@ wheels = [ [[package]] name = "storage3" -version = "2.29.0" +version = "2.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "deprecation" }, @@ -6677,9 +6677,9 @@ dependencies = [ { name = "pyiceberg" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d7/be/771246434b5caf3c6187bfdc932eaede00bf5f2937b47475ab25209ede3e/storage3-2.29.0.tar.gz", hash = "sha256:b0cc2f6714655d725c998d2c5ae8c6fb4f56a513bd31e4f85770df557fe021e3", size = 20160, upload-time = "2026-04-24T13:13:04.626Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9b/b2/6df208d64630744704d00f2c07197170390d6b4d0098617740f6a7a4fa98/storage3-2.30.0.tar.gz", hash = "sha256:b74e3cac149f2c0553dcb5f4d55d8c35d420d88183a1a2df77727d482665972b", size = 20162, upload-time = "2026-05-06T17:35:25.71Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/c3/790c31866f52c13b26f108b45759bf50dafae3a0bafb4511fadc98ba7c33/storage3-2.29.0-py3-none-any.whl", hash = "sha256:043ef7ff27cc8b9da12be403cf78ee4586180edfcf62b227ff61e1bd79594b06", size = 28284, upload-time = "2026-04-24T13:13:03.338Z" }, + { url = "https://files.pythonhosted.org/packages/91/5c/bb8c8cc448cfae671c4ffee67f3651892ea59b341f27bed54666190eb8ef/storage3-2.30.0-py3-none-any.whl", hash = "sha256:2bd23a34011c018bd9c130d8a70a09ebd060ae80d946c6204a6fc08161ad728d", size = 28284, upload-time = "2026-05-06T17:35:24.659Z" }, ] [[package]] @@ -6705,7 +6705,7 @@ wheels = [ [[package]] name = "supabase" -version = "2.29.0" +version = "2.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -6716,37 +6716,37 @@ dependencies = [ { name = "supabase-functions" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/51/a0/2407d616fdf68e8632bbbfb063d1685c38377ac0199e8ca11deaea1f3bf0/supabase-2.29.0.tar.gz", hash = "sha256:a88c4a4eb50fbb903e2e962fbc7c27733b00589140139f9e837bc9fe30dd3615", size = 9689, upload-time = "2026-04-24T13:13:06.728Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/a6/d2b17021c2db1a9d219c383e0762ac03a62b25468e61ab126b6b561c2f21/supabase-2.30.0.tar.gz", hash = "sha256:efdba41d474038ed220736ba4e64946df56043057ad785c4c3499d27e459975c", size = 9689, upload-time = "2026-05-06T17:35:27.781Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/22/52/232f6bbf5326e04ae12e2ef04a24f011a0d7cab379a8b9698652bc8ff78f/supabase-2.29.0-py3-none-any.whl", hash = "sha256:16c3ec4b7094f6b92efc5cd3bb3f96826d3b6dd5d24fe15c89c81166efce88fe", size = 16633, upload-time = "2026-04-24T13:13:05.722Z" }, + { url = "https://files.pythonhosted.org/packages/f0/82/d213be7d0ce0bb18018744c0ee38ba0d6648d41dbc46ac8558cffe80541f/supabase-2.30.0-py3-none-any.whl", hash = "sha256:f9b259194554f7bfd2dca6c23261f2df588016ca18b18e774f4d85bc941edb03", size = 16634, upload-time = "2026-05-06T17:35:26.696Z" }, ] [[package]] name = "supabase-auth" -version = "2.29.0" +version = "2.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx", extra = ["http2"] }, { name = "pydantic" }, { name = "pyjwt", extra = ["crypto"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/51/7f/7ceeb4c7a2caa188062e934897f0e08e1af0a0e47e376c7645c26b4c39d8/supabase_auth-2.29.0.tar.gz", hash = "sha256:46efc6a3455a23957b846dc974303a844ba0413718cfa899425477ac977f95b3", size = 39154, upload-time = "2026-04-24T13:13:08.509Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/8a/48bbbe0b6703d0670b67e45b90d6a791fd01aace67443d286f760bf48895/supabase_auth-2.30.0.tar.gz", hash = "sha256:6138a53a306a95ed59c03d4e4975469dfc3343a0ade33cc4b37e4ef967ad83f8", size = 39135, upload-time = "2026-05-06T17:35:30.371Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f1/ac/3c35cf52281f940b9497cf17abfc5c2050ca49f342d60cfafe22dac3482b/supabase_auth-2.29.0-py3-none-any.whl", hash = "sha256:64de6ef8cae80f97d3aa8d5ca507d5427dda5c89885c0bcfe9f8b0263b6fb9a4", size = 48379, upload-time = "2026-04-24T13:13:07.417Z" }, + { url = "https://files.pythonhosted.org/packages/db/40/a99cb4373353bcbf302d962e51da9eac78b3b0f257eb0362c0852b1667f4/supabase_auth-2.30.0-py3-none-any.whl", hash = "sha256:e85e1f51ec0de2172c3a2a8514205f71731a9914f9a770ed199ac0cf054bc82c", size = 48352, upload-time = "2026-05-06T17:35:28.936Z" }, ] [[package]] name = "supabase-functions" -version = "2.29.0" +version = "2.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx", extra = ["http2"] }, { name = "strenum" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e5/19/1a1d22749f38f2a6cbca93a6f5a35c9f816c2c3c06bfaa077fa336e90537/supabase_functions-2.29.0.tar.gz", hash = "sha256:0f8a14a2ea9f12b1c208f61dc6f55e2f4b1121f81bf01c08f9b487d22888744d", size = 4683, upload-time = "2026-04-24T13:13:10.432Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/e6/5cd8559ec2bb332e6027840c1be292f9989c2fc7b47bf40800aec5586791/supabase_functions-2.30.0.tar.gz", hash = "sha256:025acfd25f1c000ba43d0f7b8e366b0d2e9dfc784b842528e21973eb33006113", size = 4683, upload-time = "2026-05-06T17:35:32.246Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/10/6f8ef0b408ade76b5a439afab588ce5849e9604a23040ca73cfe0b90cb9e/supabase_functions-2.29.0-py3-none-any.whl", hash = "sha256:6f08de52eec5820eae53616868b85e849e181beffaa5d05b8ea1708ceae5e48e", size = 8799, upload-time = "2026-04-24T13:13:09.214Z" }, + { url = "https://files.pythonhosted.org/packages/53/da/9dedab32775df04cc22ca72f194b78e895d940f195bed3e02882a65daa9b/supabase_functions-2.30.0-py3-none-any.whl", hash = "sha256:92419459f102767b954cd034856e4ded8e34c78660b32442d66c8b2899c68011", size = 8803, upload-time = "2026-05-06T17:35:31.342Z" }, ] [[package]] @@ -7620,11 +7620,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.6.3" +version = "2.7.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/0c/06f8b233b8fd13b9e5ee11424ef85419ba0d8ba0b3138bf360be2ff56953/urllib3-2.7.0.tar.gz", hash = "sha256:231e0ec3b63ceb14667c67be60f2f2c40a518cb38b03af60abc813da26505f4c", size = 433602, upload-time = "2026-05-07T16:13:18.596Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, + { url = "https://files.pythonhosted.org/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087, upload-time = "2026-05-07T16:13:17.151Z" }, ] [[package]] diff --git a/dev/pytest/pytest_full.sh b/dev/pytest/pytest_full.sh index 2989a74ad8..ca09aeb729 100755 --- a/dev/pytest/pytest_full.sh +++ b/dev/pytest/pytest_full.sh @@ -15,7 +15,7 @@ mkdir -p "${OPENDAL_FS_ROOT}" # Prepare env files like CI cp -n docker/.env.example docker/.env || true -cp -n docker/middleware.env.example docker/middleware.env || true +cp -n docker/envs/middleware.env.example docker/middleware.env || true cp -n api/tests/integration_tests/.env.example api/tests/integration_tests/.env || true # Expose service ports (same as CI) without leaving the repo dirty diff --git a/dev/setup b/dev/setup index 4236ff7fa7..1d2501a48b 100755 --- a/dev/setup +++ b/dev/setup @@ -8,7 +8,7 @@ API_ENV_EXAMPLE="$ROOT/api/.env.example" API_ENV="$ROOT/api/.env" WEB_ENV_EXAMPLE="$ROOT/web/.env.example" WEB_ENV="$ROOT/web/.env.local" -MIDDLEWARE_ENV_EXAMPLE="$ROOT/docker/middleware.env.example" +MIDDLEWARE_ENV_EXAMPLE="$ROOT/docker/envs/middleware.env.example" MIDDLEWARE_ENV="$ROOT/docker/middleware.env" # 1) Copy api/.env.example -> api/.env @@ -17,7 +17,7 @@ cp "$API_ENV_EXAMPLE" "$API_ENV" # 2) Copy web/.env.example -> web/.env.local cp "$WEB_ENV_EXAMPLE" "$WEB_ENV" -# 3) Copy docker/middleware.env.example -> docker/middleware.env +# 3) Copy docker/envs/middleware.env.example -> docker/middleware.env cp "$MIDDLEWARE_ENV_EXAMPLE" "$MIDDLEWARE_ENV" # 4) Install deps diff --git a/docker/.env.example b/docker/.env.example index 82bd837ffb..c708a40c15 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1,5 +1,6 @@ # ------------------------------------------------------------------ # Essential defaults for Docker Compose deployments. +# Only include variables required for services to start. # # For a default deployment, copy this file to .env and run: # docker compose up -d @@ -27,14 +28,16 @@ LANG=C.UTF-8 LC_ALL=C.UTF-8 PYTHONIOENCODING=utf-8 UV_CACHE_DIR=/tmp/.uv-cache -SECRET_KEY=sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U +# Leave empty to auto-generate a persistent key in the storage directory. +SECRET_KEY= INIT_PASSWORD= DEPLOY_ENV=PRODUCTION CHECK_UPDATE_URL=https://updates.dify.ai OPENAI_API_BASE=https://api.openai.com/v1 MIGRATION_ENABLED=true FILES_ACCESS_TIMEOUT=300 -ENABLE_COLLABORATION_MODE=false +# Remove `collaboration` from COMPOSE_PROFILES to stop the dedicated websocket service. +ENABLE_COLLABORATION_MODE=true # Logging and server workers LOG_LEVEL=INFO @@ -52,6 +55,9 @@ DIFY_PORT=5001 SERVER_WORKER_AMOUNT=1 SERVER_WORKER_CLASS=gevent SERVER_WORKER_CONNECTIONS=10 +API_WEBSOCKET_WORKER_CLASS=geventwebsocket.gunicorn.workers.GeventWebSocketWorker +API_WEBSOCKET_WORKER_CONNECTIONS=1000 +API_WEBSOCKET_GUNICORN_TIMEOUT=360 GUNICORN_TIMEOUT=360 CELERY_WORKER_CLASS= CELERY_WORKER_AMOUNT=4 @@ -246,6 +252,7 @@ NGINX_KEEPALIVE_TIMEOUT=65 NGINX_PROXY_READ_TIMEOUT=3600s NGINX_PROXY_SEND_TIMEOUT=3600s NGINX_ENABLE_CERTBOT_CHALLENGE=false +NGINX_SOCKET_IO_UPSTREAM=api_websocket:5001 EXPOSE_NGINX_PORT=80 EXPOSE_NGINX_SSL_PORT=443 -COMPOSE_PROFILES=${VECTOR_STORE:-weaviate},${DB_TYPE:-postgresql} +COMPOSE_PROFILES=${VECTOR_STORE:-weaviate},${DB_TYPE:-postgresql},collaboration diff --git a/docker/README.md b/docker/README.md index a2d9b2eeba..26b1dac9ac 100644 --- a/docker/README.md +++ b/docker/README.md @@ -87,7 +87,7 @@ The root `.env.example` file contains the essential startup settings. Optional a 1. **Server Configuration**: - `LOG_LEVEL`, `DEBUG`, `FLASK_DEBUG`: Logging and debug settings. - - `SECRET_KEY`: A key for encrypting session cookies and other sensitive data. + - `SECRET_KEY`: A key for signing sessions, JWTs, and file URLs. Leave it empty to let Dify generate a persistent key in the storage directory, or set a unique value yourself. 1. **Database Configuration**: diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 0f65c38098..72c9d4fd90 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -261,6 +261,31 @@ services: - ssrf_proxy_network - default + # WebSocket service for workflow collaboration. + api_websocket: + <<: *shared-api-worker-config + image: langgenius/dify-api:1.14.0 + profiles: + - collaboration + environment: + MODE: api + SERVER_WORKER_AMOUNT: 1 + SERVER_WORKER_CLASS: ${API_WEBSOCKET_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} + SERVER_WORKER_CONNECTIONS: ${API_WEBSOCKET_WORKER_CONNECTIONS:-1000} + GUNICORN_TIMEOUT: ${API_WEBSOCKET_GUNICORN_TIMEOUT:-360} + depends_on: + db_postgres: + condition: service_healthy + required: false + db_mysql: + condition: service_healthy + required: false + redis: + condition: service_started + networks: + - ssrf_proxy_network + - default + # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: @@ -661,6 +686,7 @@ services: NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false} + NGINX_SOCKET_IO_UPSTREAM: ${NGINX_SOCKET_IO_UPSTREAM:-api_websocket:5001} CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-} depends_on: - api diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 0f8458a58f..c1d75e01f4 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -267,6 +267,31 @@ services: - ssrf_proxy_network - default + # WebSocket service for workflow collaboration. + api_websocket: + <<: *shared-api-worker-config + image: langgenius/dify-api:1.14.0 + profiles: + - collaboration + environment: + MODE: api + SERVER_WORKER_AMOUNT: 1 + SERVER_WORKER_CLASS: ${API_WEBSOCKET_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} + SERVER_WORKER_CONNECTIONS: ${API_WEBSOCKET_WORKER_CONNECTIONS:-1000} + GUNICORN_TIMEOUT: ${API_WEBSOCKET_GUNICORN_TIMEOUT:-360} + depends_on: + db_postgres: + condition: service_healthy + required: false + db_mysql: + condition: service_healthy + required: false + redis: + condition: service_started + networks: + - ssrf_proxy_network + - default + # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: @@ -667,6 +692,7 @@ services: NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false} + NGINX_SOCKET_IO_UPSTREAM: ${NGINX_SOCKET_IO_UPSTREAM:-api_websocket:5001} CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-} depends_on: - api diff --git a/docker/envs/core-services/shared.env.example b/docker/envs/core-services/shared.env.example index 2a57f6954a..80cfe42c38 100644 --- a/docker/envs/core-services/shared.env.example +++ b/docker/envs/core-services/shared.env.example @@ -16,7 +16,8 @@ CHECK_UPDATE_URL=https://updates.dify.ai OPENAI_API_BASE=https://api.openai.com/v1 MIGRATION_ENABLED=true FILES_ACCESS_TIMEOUT=300 -ENABLE_COLLABORATION_MODE=false +# Remove `collaboration` from COMPOSE_PROFILES to stop the dedicated websocket service. +ENABLE_COLLABORATION_MODE=true CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1 CELERY_TASK_ANNOTATIONS=null AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net @@ -70,6 +71,8 @@ LOG_TZ=UTC DEBUG=false FLASK_DEBUG=false ENABLE_REQUEST_LOGGING=False +OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60 +OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5 WORKFLOW_LOG_CLEANUP_ENABLED=false WORKFLOW_LOG_RETENTION_DAYS=30 WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100 @@ -87,6 +90,9 @@ DIFY_PORT=5001 SERVER_WORKER_AMOUNT=1 SERVER_WORKER_CLASS=gevent SERVER_WORKER_CONNECTIONS=10 +API_WEBSOCKET_WORKER_CLASS=geventwebsocket.gunicorn.workers.GeventWebSocketWorker +API_WEBSOCKET_WORKER_CONNECTIONS=1000 +API_WEBSOCKET_GUNICORN_TIMEOUT=360 CELERY_SENTINEL_PASSWORD= S3_ACCESS_KEY= S3_SECRET_KEY= @@ -399,7 +405,7 @@ TABLESTORE_ENDPOINT=https://instance-name.cn-hangzhou.ots.aliyuncs.com TABLESTORE_INSTANCE_NAME=instance-name CLICKZETTA_USERNAME= CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance -COMPOSE_PROFILES=${VECTOR_STORE:-weaviate},${DB_TYPE:-postgresql} +COMPOSE_PROFILES=${VECTOR_STORE:-weaviate},${DB_TYPE:-postgresql},collaboration EXPOSE_NGINX_PORT=80 EXPOSE_NGINX_SSL_PORT=443 POSITION_TOOL_PINS= diff --git a/docker/envs/infrastructure/nginx.env.example b/docker/envs/infrastructure/nginx.env.example index fbe86680ba..fcb369a47d 100644 --- a/docker/envs/infrastructure/nginx.env.example +++ b/docker/envs/infrastructure/nginx.env.example @@ -15,3 +15,4 @@ NGINX_KEEPALIVE_TIMEOUT=65 NGINX_PROXY_READ_TIMEOUT=3600s NGINX_PROXY_SEND_TIMEOUT=3600s NGINX_ENABLE_CERTBOT_CHALLENGE=false +NGINX_SOCKET_IO_UPSTREAM=api_websocket:5001 diff --git a/docker/envs/security.env.example b/docker/envs/security.env.example index 787aef2706..d7556d91e5 100644 --- a/docker/envs/security.env.example +++ b/docker/envs/security.env.example @@ -36,5 +36,6 @@ TIDB_PUBLIC_KEY=dify TIDB_PRIVATE_KEY=dify VIKINGDB_ACCESS_KEY=your-ak VIKINGDB_SECRET_KEY=your-sk -SECRET_KEY=sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U +# Leave empty to auto-generate a persistent key in the storage directory. +SECRET_KEY= INIT_PASSWORD= diff --git a/docker/nginx/conf.d/default.conf.template b/docker/nginx/conf.d/default.conf.template index 94a748290f..64c720ca2b 100644 --- a/docker/nginx/conf.d/default.conf.template +++ b/docker/nginx/conf.d/default.conf.template @@ -15,7 +15,9 @@ server { } location /socket.io/ { - proxy_pass http://api:5001; + resolver 127.0.0.11 valid=30s ipv6=off; + set $socket_io_upstream ${NGINX_SOCKET_IO_UPSTREAM}; + proxy_pass http://$socket_io_upstream; include proxy.conf; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; diff --git a/e2e/AGENTS.md b/e2e/AGENTS.md index e56aab20a7..c05b5105be 100644 --- a/e2e/AGENTS.md +++ b/e2e/AGENTS.md @@ -31,7 +31,7 @@ pnpm -C e2e check `pnpm install` is resolved through the repository workspace and uses the shared root lockfile plus `pnpm-workspace.yaml`. -Use `pnpm check` as the default local verification step after editing E2E TypeScript, Cucumber support code, or feature glue. It runs formatting, linting, and type checks for this package. +Use `pnpm -C e2e check` as the default local verification step after editing E2E TypeScript, Cucumber support code, or feature glue. It runs formatting, linting, and type checks for this package. Common commands: @@ -68,8 +68,8 @@ flowchart TD C --> D["Cucumber loads config, steps, and support modules"] D --> E["BeforeAll bootstraps shared auth state via /install"] E --> F{"Which command is running?"} - F -->|`pnpm e2e`| G["Run config default tags: not @fresh and not @skip"] - F -->|`pnpm e2e:full*`| H["Override tags to not @skip"] + F -->|`pnpm -C e2e e2e`| G["Run config default tags: not @fresh and not @skip"] + F -->|`pnpm -C e2e e2e:full*`| H["Override tags to not @skip"] G --> I["Per-scenario BrowserContext from shared browser"] H --> I I --> J["Failure artifacts written to cucumber-report/artifacts"] @@ -99,7 +99,7 @@ Behavior depends on instance state: - uninitialized instance: completes install and stores authenticated state - initialized instance: signs in and reuses authenticated state -Because of that, the `@fresh` install scenario only runs in the `pnpm e2e:full*` flows. The default `pnpm e2e*` flows exclude `@fresh` via Cucumber config tags so they can be re-run against an already initialized instance. +Because of that, the `@fresh` install scenario only runs in the `pnpm -C e2e e2e:full*` flows. The default `pnpm -C e2e e2e*` flows exclude `@fresh` via Cucumber config tags so they can be re-run against an already initialized instance. Reset all persisted E2E state: @@ -126,7 +126,7 @@ pnpm -C e2e e2e:middleware:up Stop the full middleware stack: ```bash -pnpm e2e:middleware:down +pnpm -C e2e e2e:middleware:down ``` The middleware stack includes: @@ -141,15 +141,15 @@ The middleware stack includes: Fresh install verification: ```bash -pnpm e2e:full +pnpm -C e2e e2e:full ``` Run the Cucumber suite against an already running middleware stack: ```bash -pnpm e2e:middleware:up -pnpm e2e -pnpm e2e:middleware:down +pnpm -C e2e e2e:middleware:up +pnpm -C e2e e2e +pnpm -C e2e e2e:middleware:down ``` Artifacts and diagnostics: diff --git a/e2e/features/step-definitions/apps/share-app.steps.ts b/e2e/features/step-definitions/apps/share-app.steps.ts index d5742bdaa8..3ec038b065 100644 --- a/e2e/features/step-definitions/apps/share-app.steps.ts +++ b/e2e/features/step-definitions/apps/share-app.steps.ts @@ -40,7 +40,7 @@ Then('the shared app page should be accessible', async function (this: DifyWorld When('I run the shared workflow app', async function (this: DifyWorld) { const page = this.getPage() - const runButton = page.getByTestId('run-button') + const runButton = page.getByRole('button', { name: 'Execute' }) await expect(runButton).toBeEnabled({ timeout: 15_000 }) await runButton.click() diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 2326e92d2f..46277d3349 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -159,21 +159,11 @@ "count": 5 } }, - "web/app/account/(commonLayout)/delete-account/components/feed-back.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/account/(commonLayout)/delete-account/components/verify-email.tsx": { "react/set-state-in-effect": { "count": 1 } }, - "web/app/account/(commonLayout)/delete-account/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/account/oauth/authorize/layout.tsx": { "ts/no-explicit-any": { "count": 1 @@ -202,18 +192,10 @@ "count": 1 } }, - "web/app/components/app/annotation/add-annotation-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/annotation/batch-add-annotation-modal/index.tsx": { "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "react-refresh/only-export-components": { "count": 1 }, @@ -235,11 +217,6 @@ "count": 1 } }, - "web/app/components/app/annotation/edit-annotation-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/annotation/header-opts/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -262,9 +239,6 @@ "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 5 }, @@ -277,11 +251,6 @@ "count": 4 } }, - "web/app/components/app/app-publisher/version-info-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/configuration/base/var-highlight/index.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -293,9 +262,6 @@ } }, "web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -311,9 +277,6 @@ } }, "web/app/components/app/configuration/config-var/config-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 4 } @@ -337,9 +300,6 @@ } }, "web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react-hooks/exhaustive-deps": { "count": 1 }, @@ -356,9 +316,6 @@ } }, "web/app/components/app/configuration/config/automatic/get-automatic-res.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 4 }, @@ -387,9 +344,6 @@ } }, "web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 4 }, @@ -397,16 +351,6 @@ "count": 2 } }, - "web/app/components/app/configuration/configuration-view.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/app/configuration/dataset-config/card-item/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/configuration/dataset-config/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -418,9 +362,6 @@ } }, "web/app/components/app/configuration/dataset-config/params-config/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 1 } @@ -494,26 +435,10 @@ "count": 1 } }, - "web/app/components/app/create-app-modal/index.tsx": { - "react/set-state-in-effect": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/app/create-from-dsl-modal/dsl-confirm-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/create-from-dsl-modal/index.tsx": { "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "react-refresh/only-export-components": { "count": 1 }, @@ -521,11 +446,6 @@ "count": 2 } }, - "web/app/components/app/duplicate-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/log/filter.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -537,9 +457,6 @@ } }, "web/app/components/app/log/list.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 6 }, @@ -561,9 +478,6 @@ } }, "web/app/components/app/switch-app-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 1 } @@ -584,9 +498,6 @@ } }, "web/app/components/app/workflow-log/list.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 2 } @@ -881,11 +792,6 @@ "count": 3 } }, - "web/app/components/base/content-dialog/index.stories.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/base/date-and-time-picker/hooks.ts": { "react/no-unnecessary-use-prefix": { "count": 2 @@ -901,26 +807,6 @@ "count": 1 } }, - "web/app/components/base/dialog/index.stories.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, - "web/app/components/base/drawer-plus/index.stories.tsx": { - "react/component-hook-factories": { - "count": 1 - } - }, - "web/app/components/base/drawer-plus/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/base/emoji-picker/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/base/error-boundary/index.tsx": { "react-refresh/only-export-components": { "count": 3 @@ -942,11 +828,6 @@ "count": 1 } }, - "web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx": { "ts/no-explicit-any": { "count": 3 @@ -973,9 +854,6 @@ } }, "web/app/components/base/features/new-feature-panel/moderation/moderation-setting-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 2 } @@ -1031,11 +909,6 @@ "count": 3 } }, - "web/app/components/base/float-right-container/index.tsx": { - "no-restricted-imports": { - "count": 2 - } - }, "web/app/components/base/form/components/base/base-form.tsx": { "ts/no-explicit-any": { "count": 6 @@ -1051,14 +924,6 @@ "count": 1 } }, - "web/app/components/base/form/components/field/variable-or-constant-input.tsx": { - "no-console": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 2 - } - }, "web/app/components/base/form/components/field/variable-selector.tsx": { "no-console": { "count": 1 @@ -1464,11 +1329,6 @@ "count": 9 } }, - "web/app/components/base/markdown-blocks/form.tsx": { - "erasable-syntax-only/enums": { - "count": 3 - } - }, "web/app/components/base/markdown-blocks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 10 @@ -1552,16 +1412,6 @@ "count": 1 } }, - "web/app/components/base/modal-like-wrap/index.stories.tsx": { - "no-console": { - "count": 3 - } - }, - "web/app/components/base/modal/index.stories.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/base/new-audio-button/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -1605,7 +1455,12 @@ }, "web/app/components/base/prompt-editor/index.tsx": { "ts/no-explicit-any": { - "count": 4 + "count": 3 + } + }, + "web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx": { + "no-restricted-imports": { + "count": 1 } }, "web/app/components/base/prompt-editor/plugins/component-picker-block/menu.tsx": { @@ -1693,8 +1548,8 @@ } }, "web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx": { - "ts/no-explicit-any": { - "count": 2 + "no-restricted-imports": { + "count": 1 } }, "web/app/components/base/prompt-editor/plugins/update-block.tsx": { @@ -1848,11 +1703,6 @@ "count": 4 } }, - "web/app/components/billing/annotation-full/modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/billing/billing-page/__tests__/index.spec.tsx": { "ts/no-explicit-any": { "count": 4 @@ -1886,11 +1736,6 @@ "count": 4 } }, - "web/app/components/billing/upgrade-btn/index.tsx": { - "ts/no-explicit-any": { - "count": 3 - } - }, "web/app/components/datasets/common/image-previewer/index.tsx": { "no-irregular-whitespace": { "count": 1 @@ -1916,11 +1761,6 @@ "count": 1 } }, - "web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/dsl-confirm-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/hooks/use-dsl-import.ts": { "erasable-syntax-only/enums": { "count": 1 @@ -1929,9 +1769,6 @@ "web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx": { "no-barrel-files/no-barrel-files": { "count": 1 - }, - "no-restricted-imports": { - "count": 1 } }, "web/app/components/datasets/create-from-pipeline/list/template-card/details/types.ts": { @@ -1939,16 +1776,6 @@ "count": 1 } }, - "web/app/components/datasets/create-from-pipeline/list/template-card/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/create/file-preview/index.tsx": { "react/set-state-in-effect": { "count": 1 @@ -2003,11 +1830,6 @@ "count": 1 } }, - "web/app/components/datasets/create/stop-embedding-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/create/website/firecrawl/index.tsx": { "no-console": { "count": 1 @@ -2066,11 +1888,6 @@ "count": 4 } }, - "web/app/components/datasets/documents/components/rename-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/__tests__/index.spec.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -2141,11 +1958,6 @@ "count": 1 } }, - "web/app/components/datasets/documents/detail/completed/common/regeneration-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/documents/detail/completed/components/segment-list-content.tsx": { "ts/no-non-null-asserted-optional-chain": { "count": 1 @@ -2220,30 +2032,27 @@ "count": 1 } }, + "web/app/components/datasets/formatted-text/flavours/edit-slice.tsx": { + "no-restricted-imports": { + "count": 2 + } + }, + "web/app/components/datasets/formatted-text/flavours/preview-slice.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/datasets/formatted-text/flavours/type.ts": { "ts/no-empty-object-type": { "count": 1 } }, - "web/app/components/datasets/hit-testing/components/chunk-detail-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/datasets/hit-testing/components/result-item-external.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/hit-testing/components/score.tsx": { "unicorn/prefer-number-properties": { "count": 1 } }, "web/app/components/datasets/hit-testing/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/unsupported-syntax": { "count": 1 } @@ -2253,11 +2062,6 @@ "count": 2 } }, - "web/app/components/datasets/metadata/edit-metadata-batch/modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/metadata/hooks/use-edit-dataset-metadata.ts": { "react/set-state-in-effect": { "count": 1 @@ -2271,47 +2075,19 @@ "count": 2 } }, - "web/app/components/datasets/metadata/metadata-dataset/create-content.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/datasets/metadata/metadata-dataset/create-metadata-modal.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.tsx": { - "no-restricted-imports": { - "count": 2 - } - }, - "web/app/components/datasets/metadata/metadata-dataset/select-metadata-modal.tsx": { - "erasable-syntax-only/enums": { - "count": 1 - } - }, "web/app/components/datasets/metadata/types.ts": { "erasable-syntax-only/enums": { "count": 2 } }, - "web/app/components/datasets/rename-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/settings/chunk-structure/types.ts": { "erasable-syntax-only/enums": { "count": 1 } }, "web/app/components/develop/code.tsx": { - "ts/no-empty-object-type": { - "count": 1 - }, "ts/no-explicit-any": { - "count": 9 + "count": 7 } }, "web/app/components/develop/md.tsx": { @@ -2322,16 +2098,6 @@ "count": 2 } }, - "web/app/components/develop/secret-key/secret-key-generate.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/develop/secret-key/secret-key-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/explore/banner/banner-item.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -2353,11 +2119,6 @@ "count": 1 } }, - "web/app/components/explore/try-app/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/explore/try-app/tab.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -2437,16 +2198,6 @@ "count": 1 } }, - "web/app/components/header/account-about/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/header/account-setting/api-based-extension-page/modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/header/account-setting/data-source-page-new/card.tsx": { "ts/no-explicit-any": { "count": 2 @@ -2516,6 +2267,11 @@ "count": 4 } }, + "web/app/components/header/account-setting/model-provider-page/model-auth/authorized/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 6 @@ -2587,9 +2343,6 @@ } }, "web/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 1 }, @@ -2631,9 +2384,6 @@ "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "react-refresh/only-export-components": { "count": 1 } @@ -2649,9 +2399,6 @@ } }, "web/app/components/plugins/install-plugin/install-from-github/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 3 } @@ -2661,21 +2408,6 @@ "count": 1 } }, - "web/app/components/plugins/install-plugin/install-from-local-package/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/plugins/install-plugin/install-from-local-package/steps/uploading.tsx": { - "ts/no-explicit-any": { - "count": 2 - } - }, - "web/app/components/plugins/install-plugin/install-from-marketplace/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/marketplace/hooks.ts": { "@tanstack/query/exhaustive-deps": { "count": 1 @@ -2686,6 +2418,11 @@ "count": 1 } }, + "web/app/components/plugins/plugin-auth/authorized/index.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/plugins/plugin-auth/authorized/item.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2761,18 +2498,10 @@ } }, "web/app/components/plugins/plugin-detail-panel/endpoint-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 7 } }, - "web/app/components/plugins/plugin-detail-panel/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/plugin-detail-panel/model-list.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2788,15 +2517,7 @@ "count": 1 } }, - "web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 2 } @@ -2839,15 +2560,22 @@ "count": 7 } }, + "web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-base-form.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 2 } }, - "web/app/components/plugins/plugin-detail-panel/trigger/event-detail-drawer.tsx": { + "web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx": { "no-restricted-imports": { "count": 1 - }, + } + }, + "web/app/components/plugins/plugin-detail-panel/trigger/event-detail-drawer.tsx": { "ts/no-explicit-any": { "count": 5 } @@ -2857,11 +2585,6 @@ "count": 1 } }, - "web/app/components/plugins/plugin-mutation-model/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/plugin-page/context.ts": { "ts/no-explicit-any": { "count": 1 @@ -2877,21 +2600,11 @@ "count": 2 } }, - "web/app/components/plugins/plugin-page/plugin-info.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/reference-setting-modal/auto-update-setting/types.ts": { "erasable-syntax-only/enums": { "count": 2 } }, - "web/app/components/plugins/reference-setting-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/types.ts": { "erasable-syntax-only/enums": { "count": 7 @@ -2974,11 +2687,6 @@ "count": 4 } }, - "web/app/components/rag-pipeline/components/publish-as-knowledge-pipeline-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/rag-pipeline/components/rag-pipeline-children.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2994,16 +2702,6 @@ "count": 2 } }, - "web/app/components/rag-pipeline/components/update-dsl-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/rag-pipeline/components/version-mismatch-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/rag-pipeline/hooks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 9 @@ -3059,11 +2757,6 @@ "count": 1 } }, - "web/app/components/share/text-generation/info-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/share/text-generation/menu-dropdown.tsx": { "react/set-state-in-effect": { "count": 1 @@ -3102,20 +2795,7 @@ "count": 2 } }, - "web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/tools/edit-custom-collection-modal/get-schema.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/tools/edit-custom-collection-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "react/set-state-in-effect": { "count": 4 }, @@ -3124,9 +2804,6 @@ } }, "web/app/components/tools/edit-custom-collection-modal/test-api.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -3136,29 +2813,11 @@ "count": 1 } }, - "web/app/components/tools/mcp/detail/provider-detail.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/tools/mcp/mcp-server-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 5 - } - }, "web/app/components/tools/mcp/mcp-server-param-item.tsx": { "ts/no-explicit-any": { "count": 1 } }, - "web/app/components/tools/mcp/modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/tools/mcp/provider-card.tsx": { "ts/no-explicit-any": { "count": 3 @@ -3169,20 +2828,12 @@ "count": 1 } }, - "web/app/components/tools/provider/detail.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/tools/provider/empty.tsx": { "ts/no-explicit-any": { "count": 1 } }, "web/app/components/tools/setting/build-in/config-credentials.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 3 } @@ -3276,6 +2927,11 @@ "count": 1 } }, + "web/app/components/workflow/block-selector/main.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/workflow/block-selector/market-place-plugin/action.tsx": { "react/set-state-in-effect": { "count": 1 @@ -3291,6 +2947,11 @@ "count": 1 } }, + "web/app/components/workflow/block-selector/tool-picker.tsx": { + "no-restricted-imports": { + "count": 1 + } + }, "web/app/components/workflow/block-selector/tool/tool-list-flat-view/list.tsx": { "ts/no-explicit-any": { "count": 1 @@ -3854,16 +3515,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/http/components/authorization/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/http/components/curl-panel.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx": { "ts/no-explicit-any": { "count": 2 @@ -3989,11 +3640,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/knowledge-retrieval/components/dataset-item.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-item.tsx": { "ts/no-explicit-any": { "count": 1 @@ -4045,11 +3691,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx": { "ts/no-explicit-any": { "count": 3 @@ -4170,9 +3811,6 @@ } }, "web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/update.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -4200,30 +3838,9 @@ "count": 9 } }, - "web/app/components/workflow/nodes/question-classifier/components/class-item.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/workflow/nodes/question-classifier/components/class-list.tsx": { "react/set-state-in-effect": { "count": 1 - }, - "react/unsupported-syntax": { - "count": 2 - } - }, - "web/app/components/workflow/nodes/question-classifier/default.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/question-classifier/use-config.ts": { - "react/set-state-in-effect": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 2 } }, "web/app/components/workflow/nodes/question-classifier/use-single-run-form-params.ts": { @@ -4406,6 +4023,9 @@ } }, "web/app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/component.tsx": { + "no-restricted-imports": { + "count": 1 + }, "react/set-state-in-effect": { "count": 1 } @@ -4416,6 +4036,9 @@ } }, "web/app/components/workflow/operator/add-block.tsx": { + "no-restricted-imports": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -4477,9 +4100,6 @@ } }, "web/app/components/workflow/panel/debug-and-preview/conversation-variable-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 2 } @@ -4507,16 +4127,6 @@ "count": 4 } }, - "web/app/components/workflow/panel/version-history-panel/delete-confirm-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/workflow/panel/version-history-panel/restore-confirm-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/panel/workflow-preview.tsx": { "ts/no-explicit-any": { "count": 2 @@ -4652,9 +4262,6 @@ } }, "web/app/components/workflow/update-dsl-modal.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -4768,11 +4375,6 @@ "count": 1 } }, - "web/app/education-apply/expire-notice-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/education-apply/hooks.ts": { "react/set-state-in-effect": { "count": 5 @@ -4813,11 +4415,6 @@ "count": 1 } }, - "web/app/signin/one-more-step.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/signup/layout.tsx": { "ts/no-explicit-any": { "count": 1 diff --git a/packages/contracts/generated/enterprise/orpc.gen.ts b/packages/contracts/generated/enterprise/orpc.gen.ts index 73eb850001..6b9b76470a 100644 --- a/packages/contracts/generated/enterprise/orpc.gen.ts +++ b/packages/contracts/generated/enterprise/orpc.gen.ts @@ -7,63 +7,6 @@ import { zConsoleSsoOAuth2LoginResponse, zConsoleSsoOidcLoginResponse, zConsoleSsoSamlLoginResponse, - zEnterpriseAppDeployConsoleCancelRuntimeDeploymentBody, - zEnterpriseAppDeployConsoleCancelRuntimeDeploymentPath, - zEnterpriseAppDeployConsoleCancelRuntimeDeploymentResponse, - zEnterpriseAppDeployConsoleCreateAppInstanceBody, - zEnterpriseAppDeployConsoleCreateAppInstanceResponse, - zEnterpriseAppDeployConsoleCreateDeploymentBody, - zEnterpriseAppDeployConsoleCreateDeploymentPath, - zEnterpriseAppDeployConsoleCreateDeploymentResponse, - zEnterpriseAppDeployConsoleCreateDeveloperApiKeyBody, - zEnterpriseAppDeployConsoleCreateDeveloperApiKeyPath, - zEnterpriseAppDeployConsoleCreateDeveloperApiKeyResponse, - zEnterpriseAppDeployConsoleCreateReleaseBody, - zEnterpriseAppDeployConsoleCreateReleasePath, - zEnterpriseAppDeployConsoleCreateReleaseResponse, - zEnterpriseAppDeployConsoleDeleteAppInstancePath, - zEnterpriseAppDeployConsoleDeleteAppInstanceResponse, - zEnterpriseAppDeployConsoleDeleteDeveloperApiKeyPath, - zEnterpriseAppDeployConsoleDeleteDeveloperApiKeyResponse, - zEnterpriseAppDeployConsoleGetAppInstanceAccessPath, - zEnterpriseAppDeployConsoleGetAppInstanceAccessResponse, - zEnterpriseAppDeployConsoleGetAppInstanceOverviewPath, - zEnterpriseAppDeployConsoleGetAppInstanceOverviewResponse, - zEnterpriseAppDeployConsoleGetAppInstanceSettingsPath, - zEnterpriseAppDeployConsoleGetAppInstanceSettingsResponse, - zEnterpriseAppDeployConsoleGetEnvironmentAccessPolicyPath, - zEnterpriseAppDeployConsoleGetEnvironmentAccessPolicyResponse, - zEnterpriseAppDeployConsoleListAppInstancesQuery, - zEnterpriseAppDeployConsoleListAppInstancesResponse, - zEnterpriseAppDeployConsoleListDeploymentBindingOptionsPath, - zEnterpriseAppDeployConsoleListDeploymentBindingOptionsResponse, - zEnterpriseAppDeployConsoleListDeploymentEnvironmentOptionsResponse, - zEnterpriseAppDeployConsoleListReleasesPath, - zEnterpriseAppDeployConsoleListReleasesQuery, - zEnterpriseAppDeployConsoleListReleasesResponse, - zEnterpriseAppDeployConsoleListRuntimeInstancesPath, - zEnterpriseAppDeployConsoleListRuntimeInstancesResponse, - zEnterpriseAppDeployConsolePreviewReleaseBody, - zEnterpriseAppDeployConsolePreviewReleasePath, - zEnterpriseAppDeployConsolePreviewReleaseResponse, - zEnterpriseAppDeployConsoleSearchAccessSubjectsPath, - zEnterpriseAppDeployConsoleSearchAccessSubjectsQuery, - zEnterpriseAppDeployConsoleSearchAccessSubjectsResponse, - zEnterpriseAppDeployConsoleUndeployRuntimeInstanceBody, - zEnterpriseAppDeployConsoleUndeployRuntimeInstancePath, - zEnterpriseAppDeployConsoleUndeployRuntimeInstanceResponse, - zEnterpriseAppDeployConsoleUpdateAccessChannelsBody, - zEnterpriseAppDeployConsoleUpdateAccessChannelsPath, - zEnterpriseAppDeployConsoleUpdateAccessChannelsResponse, - zEnterpriseAppDeployConsoleUpdateAppInstanceBody, - zEnterpriseAppDeployConsoleUpdateAppInstancePath, - zEnterpriseAppDeployConsoleUpdateAppInstanceResponse, - zEnterpriseAppDeployConsoleUpdateDeveloperApiBody, - zEnterpriseAppDeployConsoleUpdateDeveloperApiPath, - zEnterpriseAppDeployConsoleUpdateDeveloperApiResponse, - zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyBody, - zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyPath, - zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyResponse, zWebAppAuthGetGroupSubjectsQuery, zWebAppAuthGetGroupSubjectsResponse, zWebAppAuthGetWebAppAccessModeQuery, @@ -78,344 +21,6 @@ import { zWebAppAuthUpdateWebAppWhitelistSubjectsResponse, } from './zod.gen' -export const listAppInstances = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_ListAppInstances', - path: '/enterprise/app-instances', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ query: zEnterpriseAppDeployConsoleListAppInstancesQuery.optional() })) - .output(zEnterpriseAppDeployConsoleListAppInstancesResponse) - -export const createAppInstance = oc - .route({ - inputStructure: 'detailed', - method: 'POST', - operationId: 'EnterpriseAppDeployConsole_CreateAppInstance', - path: '/enterprise/app-instances', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ body: zEnterpriseAppDeployConsoleCreateAppInstanceBody })) - .output(zEnterpriseAppDeployConsoleCreateAppInstanceResponse) - -export const deleteAppInstance = oc - .route({ - inputStructure: 'detailed', - method: 'DELETE', - operationId: 'EnterpriseAppDeployConsole_DeleteAppInstance', - path: '/enterprise/app-instances/{appInstanceId}', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ params: zEnterpriseAppDeployConsoleDeleteAppInstancePath })) - .output(zEnterpriseAppDeployConsoleDeleteAppInstanceResponse) - -export const updateAppInstance = oc - .route({ - inputStructure: 'detailed', - method: 'PATCH', - operationId: 'EnterpriseAppDeployConsole_UpdateAppInstance', - path: '/enterprise/app-instances/{appInstanceId}', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleUpdateAppInstanceBody, - params: zEnterpriseAppDeployConsoleUpdateAppInstancePath, - }), - ) - .output(zEnterpriseAppDeployConsoleUpdateAppInstanceResponse) - -export const getAppInstanceAccess = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_GetAppInstanceAccess', - path: '/enterprise/app-instances/{appInstanceId}/access', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ params: zEnterpriseAppDeployConsoleGetAppInstanceAccessPath })) - .output(zEnterpriseAppDeployConsoleGetAppInstanceAccessResponse) - -export const updateAccessChannels = oc - .route({ - inputStructure: 'detailed', - method: 'PATCH', - operationId: 'EnterpriseAppDeployConsole_UpdateAccessChannels', - path: '/enterprise/app-instances/{appInstanceId}/access-channels', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleUpdateAccessChannelsBody, - params: zEnterpriseAppDeployConsoleUpdateAccessChannelsPath, - }), - ) - .output(zEnterpriseAppDeployConsoleUpdateAccessChannelsResponse) - -export const searchAccessSubjects = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_SearchAccessSubjects', - path: '/enterprise/app-instances/{appInstanceId}/access-subjects:search', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - params: zEnterpriseAppDeployConsoleSearchAccessSubjectsPath, - query: zEnterpriseAppDeployConsoleSearchAccessSubjectsQuery.optional(), - }), - ) - .output(zEnterpriseAppDeployConsoleSearchAccessSubjectsResponse) - -export const createDeveloperApiKey = oc - .route({ - inputStructure: 'detailed', - method: 'POST', - operationId: 'EnterpriseAppDeployConsole_CreateDeveloperApiKey', - path: '/enterprise/app-instances/{appInstanceId}/api-keys', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleCreateDeveloperApiKeyBody, - params: zEnterpriseAppDeployConsoleCreateDeveloperApiKeyPath, - }), - ) - .output(zEnterpriseAppDeployConsoleCreateDeveloperApiKeyResponse) - -export const deleteDeveloperApiKey = oc - .route({ - inputStructure: 'detailed', - method: 'DELETE', - operationId: 'EnterpriseAppDeployConsole_DeleteDeveloperApiKey', - path: '/enterprise/app-instances/{appInstanceId}/api-keys/{apiKeyId}', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ params: zEnterpriseAppDeployConsoleDeleteDeveloperApiKeyPath })) - .output(zEnterpriseAppDeployConsoleDeleteDeveloperApiKeyResponse) - -export const listDeploymentBindingOptions = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_ListDeploymentBindingOptions', - path: '/enterprise/app-instances/{appInstanceId}/deployment-binding-options', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ params: zEnterpriseAppDeployConsoleListDeploymentBindingOptionsPath })) - .output(zEnterpriseAppDeployConsoleListDeploymentBindingOptionsResponse) - -export const createDeployment = oc - .route({ - inputStructure: 'detailed', - method: 'POST', - operationId: 'EnterpriseAppDeployConsole_CreateDeployment', - path: '/enterprise/app-instances/{appInstanceId}/deployments', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleCreateDeploymentBody, - params: zEnterpriseAppDeployConsoleCreateDeploymentPath, - }), - ) - .output(zEnterpriseAppDeployConsoleCreateDeploymentResponse) - -export const updateDeveloperApi = oc - .route({ - inputStructure: 'detailed', - method: 'PATCH', - operationId: 'EnterpriseAppDeployConsole_UpdateDeveloperApi', - path: '/enterprise/app-instances/{appInstanceId}/developer-api', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleUpdateDeveloperApiBody, - params: zEnterpriseAppDeployConsoleUpdateDeveloperApiPath, - }), - ) - .output(zEnterpriseAppDeployConsoleUpdateDeveloperApiResponse) - -export const getEnvironmentAccessPolicy = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_GetEnvironmentAccessPolicy', - path: '/enterprise/app-instances/{appInstanceId}/environments/{environmentId}/access-policy', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ params: zEnterpriseAppDeployConsoleGetEnvironmentAccessPolicyPath })) - .output(zEnterpriseAppDeployConsoleGetEnvironmentAccessPolicyResponse) - -export const updateEnvironmentAccessPolicy = oc - .route({ - inputStructure: 'detailed', - method: 'PUT', - operationId: 'EnterpriseAppDeployConsole_UpdateEnvironmentAccessPolicy', - path: '/enterprise/app-instances/{appInstanceId}/environments/{environmentId}/access-policy', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyBody, - params: zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyPath, - }), - ) - .output(zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyResponse) - -export const getAppInstanceOverview = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_GetAppInstanceOverview', - path: '/enterprise/app-instances/{appInstanceId}/overview', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ params: zEnterpriseAppDeployConsoleGetAppInstanceOverviewPath })) - .output(zEnterpriseAppDeployConsoleGetAppInstanceOverviewResponse) - -export const listReleases = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_ListReleases', - path: '/enterprise/app-instances/{appInstanceId}/releases', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - params: zEnterpriseAppDeployConsoleListReleasesPath, - query: zEnterpriseAppDeployConsoleListReleasesQuery.optional(), - }), - ) - .output(zEnterpriseAppDeployConsoleListReleasesResponse) - -export const createRelease = oc - .route({ - inputStructure: 'detailed', - method: 'POST', - operationId: 'EnterpriseAppDeployConsole_CreateRelease', - path: '/enterprise/app-instances/{appInstanceId}/releases', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleCreateReleaseBody, - params: zEnterpriseAppDeployConsoleCreateReleasePath, - }), - ) - .output(zEnterpriseAppDeployConsoleCreateReleaseResponse) - -export const previewRelease = oc - .route({ - inputStructure: 'detailed', - method: 'POST', - operationId: 'EnterpriseAppDeployConsole_PreviewRelease', - path: '/enterprise/app-instances/{appInstanceId}/releases:preview', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsolePreviewReleaseBody, - params: zEnterpriseAppDeployConsolePreviewReleasePath, - }), - ) - .output(zEnterpriseAppDeployConsolePreviewReleaseResponse) - -export const listRuntimeInstances = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_ListRuntimeInstances', - path: '/enterprise/app-instances/{appInstanceId}/runtime-instances', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ params: zEnterpriseAppDeployConsoleListRuntimeInstancesPath })) - .output(zEnterpriseAppDeployConsoleListRuntimeInstancesResponse) - -export const cancelRuntimeDeployment = oc - .route({ - inputStructure: 'detailed', - method: 'POST', - operationId: 'EnterpriseAppDeployConsole_CancelRuntimeDeployment', - path: '/enterprise/app-instances/{appInstanceId}/runtime-instances/{runtimeInstanceId}/deployment:cancel', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleCancelRuntimeDeploymentBody, - params: zEnterpriseAppDeployConsoleCancelRuntimeDeploymentPath, - }), - ) - .output(zEnterpriseAppDeployConsoleCancelRuntimeDeploymentResponse) - -export const undeployRuntimeInstance = oc - .route({ - inputStructure: 'detailed', - method: 'POST', - operationId: 'EnterpriseAppDeployConsole_UndeployRuntimeInstance', - path: '/enterprise/app-instances/{appInstanceId}/runtime-instances/{runtimeInstanceId}:undeploy', - tags: ['EnterpriseAppDeployConsole'], - }) - .input( - z.object({ - body: zEnterpriseAppDeployConsoleUndeployRuntimeInstanceBody, - params: zEnterpriseAppDeployConsoleUndeployRuntimeInstancePath, - }), - ) - .output(zEnterpriseAppDeployConsoleUndeployRuntimeInstanceResponse) - -export const getAppInstanceSettings = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_GetAppInstanceSettings', - path: '/enterprise/app-instances/{appInstanceId}/settings', - tags: ['EnterpriseAppDeployConsole'], - }) - .input(z.object({ params: zEnterpriseAppDeployConsoleGetAppInstanceSettingsPath })) - .output(zEnterpriseAppDeployConsoleGetAppInstanceSettingsResponse) - -export const listDeploymentEnvironmentOptions = oc - .route({ - inputStructure: 'detailed', - method: 'GET', - operationId: 'EnterpriseAppDeployConsole_ListDeploymentEnvironmentOptions', - path: '/enterprise/deployment-environment-options', - tags: ['EnterpriseAppDeployConsole'], - }) - .output(zEnterpriseAppDeployConsoleListDeploymentEnvironmentOptionsResponse) - -export const enterpriseAppDeployConsole = { - listAppInstances, - createAppInstance, - deleteAppInstance, - updateAppInstance, - getAppInstanceAccess, - updateAccessChannels, - searchAccessSubjects, - createDeveloperApiKey, - deleteDeveloperApiKey, - listDeploymentBindingOptions, - createDeployment, - updateDeveloperApi, - getEnvironmentAccessPolicy, - updateEnvironmentAccessPolicy, - getAppInstanceOverview, - listReleases, - createRelease, - previewRelease, - listRuntimeInstances, - cancelRuntimeDeployment, - undeployRuntimeInstance, - getAppInstanceSettings, - listDeploymentEnvironmentOptions, -} - export const oAuth2Login = oc .route({ inputStructure: 'detailed', @@ -528,7 +133,6 @@ export const webAppAuth = { } export const contract = { - enterpriseAppDeployConsole, consoleSso, webAppAuth, } diff --git a/packages/contracts/generated/enterprise/types.gen.ts b/packages/contracts/generated/enterprise/types.gen.ts index 56228f2738..b747c4baa8 100644 --- a/packages/contracts/generated/enterprise/types.gen.ts +++ b/packages/contracts/generated/enterprise/types.gen.ts @@ -4,46 +4,6 @@ export type ClientOptions = { baseUrl: `${string}://${string}` | (string & {}) } -export type AccessChannels = { - enabled?: boolean - webappRows?: Array - cli?: CliAccess -} - -export type AccessModeOption = { - mode?: string - label?: string - disabled?: boolean - selected?: boolean -} - -export type AccessPolicyDetail = { - accessMode?: string - subjects?: Array - options?: Array -} - -export type AccessStatus = { - accessChannelsEnabled?: boolean - webappUrl?: string - cliUrl?: string - developerApiEnabled?: boolean - apiKeyCount?: number -} - -export type AccessSubject = { - subjectId?: string - subjectType?: string -} - -export type AccessSubjectDisplay = { - id?: string - subjectType?: string - name?: string - avatarUrl?: string - memberCount?: string -} - export type Account = { id?: string email?: string @@ -70,104 +30,9 @@ export type AccountInWorkspace = { role?: string } -export type AckDeploymentReply = { - accepted?: boolean - newVersion?: string -} - -export type AckDeploymentReq = { - deploymentId?: string - instanceId?: string - expectedVersion?: string - status?: string - observedReleaseId?: string - lastError?: LastError -} - -export type AppInstanceBasicInfo = { +export type AddGroupAppsRequest = { id?: string - name?: string - description?: string - sourceAppId?: string - sourceAppName?: string - mode?: string - createdAt?: string -} - -export type AppInstanceCard = { - id?: string - name?: string - icon?: string - mode?: string - sourceAppName?: string - statuses?: Array - lastDeployedAt?: string -} - -export type AppRunnerBatchRuntimeArtifactReply = { - results?: Array -} - -export type AppRunnerBatchRuntimeArtifactRequest = { - artifacts?: Array -} - -export type AppRunnerBootstrapAssignment = { - appId?: string - environmentId?: string - workflowId?: string - instanceId?: string - workspaceId?: string - instanceVersion?: string - bindingSnapshotVersion?: string - executionTokenVersion?: string - executionToken?: string - releaseId?: string -} - -export type AppRunnerBootstrapReply = { - runnerId?: string - assignmentRevision?: string - assignments?: Array -} - -export type AppRunnerBootstrapRequest = { - runner?: AppRunnerRunnerInfo -} - -export type AppRunnerRunnerInfo = { - hostname?: string -} - -export type AppRunnerRuntimeArtifactReply = { - dslYaml?: string - bindingSnapshotVersion?: string - bindingSnapshot?: { - [key: string]: unknown - } -} - -export type AppRunnerRuntimeArtifactRequest = { - instanceId?: string - releaseId?: string - bindingSnapshotVersion?: string -} - -export type AppRunnerRuntimeArtifactResult = { - instanceId?: string - releaseId?: string - artifact?: AppRunnerRuntimeArtifactReply - errorCode?: string - errorMessage?: string -} - -export type AppRunnerTokenExchangeReply = { - accessToken?: string - expiresAt?: string -} - -export type AppRunnerTokenExchangeRequest = { - joinToken?: string + app_ids?: Array } export type AuthSettingsReply = { @@ -193,15 +58,6 @@ export type AuthSettingsReq = { ssoSettings?: SsoSettings } -export type BootstrapProgress = { - currentStep?: string - completedSteps?: Array - attemptCount?: number - lastAttemptAt?: string - lastErrorCode?: string - lastErrorMessage?: string -} - export type BrandingInfo = { enabled?: boolean applicationTitle?: string @@ -210,15 +66,6 @@ export type BrandingInfo = { favicon?: string } -export type CancelRuntimeDeploymentReply = { - status?: string -} - -export type CancelRuntimeDeploymentReq = { - appInstanceId?: string - runtimeInstanceId?: string -} - export type CheckPasswordStatusReply = { requirePasswordChange?: boolean changeReason?: number @@ -230,82 +77,10 @@ export type ClearDefaultWorkspaceReply = { [key: string]: unknown } -export type CliAccess = { - url?: string -} - -export type ConsoleEnvironment = { - id?: string - name?: string - runtime?: string - type?: string - status?: string -} - -export type ConsoleRelease = { - id?: string - name?: string - shortCommitId?: string - createdAt?: string -} - -export type ConsoleUser = { - id?: string - name?: string -} - -export type CreateAppInstanceReply = { - appInstanceId?: string - initialRelease?: ConsoleRelease -} - -export type CreateAppInstanceReq = { - sourceAppId?: string - name?: string - description?: string -} - export type CreateBearerTokenResponse = { token?: string } -export type CreateDeploymentReply = { - runtimeInstanceId?: string - deploymentId?: string - status?: string -} - -export type CreateDeploymentReq = { - appInstanceId?: string - environmentId?: string - releaseId?: string - bindings?: Array -} - -export type CreateDeveloperApiKeyReply = { - apiKey?: DeveloperApiKeyRow - token?: string -} - -export type CreateDeveloperApiKeyReq = { - appInstanceId?: string - environmentId?: string - name?: string -} - -export type CreateEnvironmentReply = { - environment?: Environment -} - -export type CreateEnvironmentReq = { - name?: string - description?: string - mode?: number - backend?: number - k8s?: K8sEnvironmentConfig - host?: HostEnvironmentConfig -} - export type CreateMemberReply = { id?: string password?: string @@ -329,12 +104,7 @@ export type CreateNewGroupsRes = { groups?: Array } -export type CreateReleaseReply = { - release?: ConsoleRelease -} - -export type CreateReleaseReq = { - appInstanceId?: string +export type CreateResourceGroupRequest = { name?: string description?: string } @@ -394,27 +164,10 @@ export type DashboardSsosamlLoginReply = { url?: string } -export type DeleteAppInstanceReply = { - [key: string]: unknown -} - -export type DeleteDeveloperApiKeyReply = { - [key: string]: unknown -} - -export type DeleteEnvironmentReply = { - [key: string]: unknown -} - export type DeleteGroupsRes = { message?: string } -export type DeleteGuard = { - canDelete?: boolean - disabledReason?: string -} - export type DeleteMemberReply = { account?: Account } @@ -431,70 +184,6 @@ export type DeleteWorkspaceReply = { [key: string]: unknown } -export type DeployedEnvironment = { - environmentId?: string - environmentName?: string -} - -export type DeploymentBindingOptionSlot = { - slot?: string - kind?: string - label?: string - required?: boolean - candidates?: Array - envVarCandidates?: Array -} - -export type DeploymentCredentialOption = { - credentialId?: string - displayName?: string - pluginId?: string - pluginName?: string - pluginVersion?: string -} - -export type DeploymentEnvVarOption = { - envVarId?: string - name?: string - valueType?: string - displayValue?: string -} - -export type DeploymentEnvironmentOption = { - id?: string - name?: string - type?: string - backend?: string - status?: string - managedBy?: string - deployable?: boolean - disabledReason?: string -} - -export type DeploymentRuntimeBinding = { - slot?: string - credentialId?: string - envVarId?: string -} - -export type DeploymentStatusRow = { - environment?: ConsoleEnvironment - release?: ConsoleRelease - status?: string -} - -export type DeveloperApiAccess = { - enabled?: boolean - apiKeys?: Array -} - -export type DeveloperApiKeyRow = { - id?: string - name?: string - environment?: ConsoleEnvironment - maskedKey?: string -} - export type EndpointReply = { mode?: number metricsEndpoint?: OtelExporterEndpoint @@ -507,55 +196,6 @@ export type EnterpriseSystemUserSettingReply = { enableEmailPasswordLogin?: boolean } -export type Environment = { - id?: string - name?: string - description?: string - mode?: number - namespace?: string - apiServer?: string - status?: number - statusMessage?: string - bootstrapProgress?: BootstrapProgress - managedBy?: string - createdAt?: string - updatedAt?: string - backend?: number - host?: string -} - -export type EnvironmentAccessRow = { - environment?: ConsoleEnvironment - currentRelease?: ConsoleRelease - accessMode?: string - accessModeLabel?: string - hint?: string -} - -export type EnvironmentFilter = { - id?: string - name?: string - kind?: string -} - -export type GetAppInstanceAccessReply = { - permissions?: Array - accessChannels?: AccessChannels - developerApi?: DeveloperApiAccess -} - -export type GetAppInstanceOverviewReply = { - instance?: AppInstanceBasicInfo - deployments?: Array - access?: AccessStatus -} - -export type GetAppInstanceSettingsReply = { - name?: string - description?: string - deleteGuard?: DeleteGuard -} - export type GetBearerTokenResponse = { maskedToken?: string } @@ -571,14 +211,6 @@ export type GetDefaultWorkspaceReply = { workspace?: Workspace } -export type GetEnvironmentAccessPolicyReply = { - policy?: AccessPolicyDetail -} - -export type GetEnvironmentReply = { - environment?: Environment -} - export type GetGroupSubjectsRes = { subjects?: Array } @@ -587,15 +219,6 @@ export type GetGroupsRes = { groups?: Array } -export type GetInstanceReply = { - instanceId?: string - status?: string - desiredReleaseId?: string - observedReleaseId?: string - currentDeploymentId?: string - version?: string -} - export type GetJoinedGroupsRes = { groups?: Array } @@ -652,16 +275,22 @@ export type GetWorkspaceReply = { workspace?: Workspace } +export type GroupAppItem = { + app_id?: string + app_name?: string + workspace_id?: string + workspace_name?: string + app_status?: number + token_usage?: string + rpm?: string + concurrency?: string +} + export type HealthzReply = { message?: string status?: string } -export type HostEnvironmentConfig = { - machineId?: string - joinTokenHash?: string -} - export type InfoConfigReply = { SSOEnforcedForSignin?: boolean SSOEnforcedForSigninProtocol?: string @@ -677,6 +306,11 @@ export type InfoConfigReply = { PluginInstallationPermission?: PluginInstallationPermissionInfo } +export type InnerAdmission = { + marker?: string + concurrencyGroupIds?: Array +} + export type InnerBatchGetWebAppAccessModesByIdReq = { appIds?: Array } @@ -698,42 +332,10 @@ export type InnerBatchIsUserAllowedToAccessWebAppRes = { } } -export type InnerCheckAppDeployAccessReply = { - allowed?: boolean - matchedPolicyId?: string - matchedScopeType?: string - reason?: string - cacheTtlSeconds?: number -} - -export type InnerCheckAppDeployAccessReq = { - appInstanceId?: string - environmentId?: string - principalType?: string - principalId?: string -} - export type InnerCleanAppRes = { message?: string } -export type InnerGetTokenRouteReply = { - environmentId?: string - namespace?: string - serviceName?: string - servicePort?: number - environmentStatus?: string - appId?: string - tenantId?: string - instanceId?: string - observedReleaseId?: string - instanceStatus?: string -} - -export type InnerGetTokenRouteReq = { - token?: string -} - export type InnerGetWebAppAccessModeByCodeRes = { accessMode?: string } @@ -742,10 +344,34 @@ export type InnerGetWebAppAccessModeByIdRes = { accessMode?: string } +export type InnerGroupConfig = { + id?: string + enabled?: boolean + membershipId?: string + limits?: Array +} + export type InnerIsUserAllowedToAccessWebAppRes = { result?: boolean } +export type InnerReleaseAdmissionRequest = { + admission?: InnerAdmission +} + +export type InnerReleaseAdmissionResponse = { + [key: string]: unknown +} + +export type InnerResolveResponse = { + appId?: string + groups?: Array + blocked?: boolean + blockGroupId?: string + blockReason?: string + admission?: InnerAdmission +} + export type InnerTryAddAccountToDefaultWorkspaceReply = { workspaceId?: string joined?: boolean @@ -770,20 +396,6 @@ export type JoinWorkspaceReq = { role?: string } -export type K8sEnvironmentConfig = { - namespace?: string - apiServer?: string - caBundle?: string - bearerToken?: string -} - -export type LastError = { - phase?: string - code?: string - message?: string - releaseId?: string -} - export type LicenseInfo = { uuid?: string expiredAt?: string @@ -798,28 +410,21 @@ export type LicenseStatus = { workspaces?: ResourceQuota } +export type LimitConfig = { + type?: number + threshold?: string + action?: number + reached?: boolean +} + export type LimitFields = { workspaceMembers?: number workspaces?: ResourceQuota } -export type ListAppInstancesReply = { - filters?: Array - data?: Array - pagination?: Pagination -} - -export type ListDeploymentBindingOptionsReply = { - slots?: Array -} - -export type ListDeploymentEnvironmentOptionsReply = { - environments?: Array -} - -export type ListEnvironmentsReply = { - data?: Array - pagination?: Pagination +export type ListGroupAppsResponse = { + items?: Array + total?: string } export type ListMembersReply = { @@ -827,13 +432,9 @@ export type ListMembersReply = { pagination?: Pagination } -export type ListReleasesReply = { - data?: Array - pagination?: Pagination -} - -export type ListRuntimeInstancesReply = { - data?: Array +export type ListResourceGroupsResponse = { + items?: Array + total?: string } export type ListSecretKeysReply = { @@ -981,31 +582,6 @@ export type PluginInstallationSettingsReply = { restrictToMarketplaceOnly?: boolean } -export type PreviewReleaseReply = { - release?: ConsoleRelease - bindings?: Array -} - -export type PreviewReleaseReq = { - appInstanceId?: string - releaseId?: string -} - -export type ReleaseRow = { - id?: string - name?: string - createdAt?: string - createdBy?: ConsoleUser - deployedTo?: Array -} - -export type ReleaseRuntimeBinding = { - kind?: string - label?: string - displayValue?: string - valueType?: string -} - export type ResetMemberPasswordReply = { id?: string password?: string @@ -1034,21 +610,35 @@ export type ResetUserPasswordReq = { id?: string } -export type ResolveCredentialsReply = { - resolved?: Array +export type ResourceGroupDetail = { + id?: string + name?: string + description?: string + enabled?: boolean + rpm_limit?: number + rpm_action?: number + concurrency_limit?: number + concurrency_action?: number + token_quota?: string + token_action?: number + created_at?: string + updated_at?: string } -export type ResolveCredentialsReq = { - instanceId?: string - deploymentId?: string - slots?: Array -} - -export type ResolvedCredential = { - slot?: string - credentialId?: string - envVarId?: string - value?: string +export type ResourceGroupItem = { + id?: string + name?: string + description?: string + enabled?: boolean + rpm_limit?: number + concurrency_limit?: number + token_quota?: string + token_usage?: string + app_count?: string + rpm_status?: number + conc_status?: number + created_at?: string + updated_at?: string } export type ResourceQuota = { @@ -1057,36 +647,6 @@ export type ResourceQuota = { enabled?: boolean } -export type RetryEnvironmentReply = { - environment?: Environment -} - -export type RetryEnvironmentReq = { - id?: string -} - -export type RuntimeEndpoints = { - run?: string - health?: string -} - -export type RuntimeInstanceDetail = { - deploymentName?: string - replicas?: number - runtimeMode?: string - runtimeNote?: string - endpoints?: RuntimeEndpoints - bindings?: Array -} - -export type RuntimeInstanceRow = { - id?: string - environment?: ConsoleEnvironment - status?: string - currentRelease?: ConsoleRelease - detail?: RuntimeInstanceDetail -} - export type SamlConfig = { idpSsoUrl?: string certificate?: string @@ -1119,8 +679,21 @@ export type ScimSettings = { lastSyncTime?: string } -export type SearchAccessSubjectsReply = { - data?: Array +export type SearchAppItem = { + app_id?: string + app_name?: string + workspace_id?: string + workspace_name?: string + app_status?: number + icon?: string + icon_type?: string + icon_background?: string + created_by_name?: string +} + +export type SearchAppsResponse = { + items?: Array + total?: string } export type SearchForWhilteListCandidatesRes = { @@ -1145,11 +718,6 @@ export type SetDefaultWorkspaceReq = { id?: string } -export type StatusCount = { - status?: string - count?: number -} - export type Subject = { subjectId?: string subjectType?: string @@ -1185,42 +753,10 @@ export type TestConnectionReply = { error?: string } -export type TestEnvironmentConnectionReply = { - ok?: boolean - reachableServerVersion?: string - namespaceExists?: boolean - missingPermissions?: Array - error?: string - probedAt?: string -} - -export type TestEnvironmentConnectionReq = { - id?: string -} - export type ToggleEndpointRequest = { enabled?: boolean } -export type UndeployRuntimeInstanceReply = { - deploymentId?: string - status?: string -} - -export type UndeployRuntimeInstanceReq = { - appInstanceId?: string - runtimeInstanceId?: string -} - -export type UpdateAccessChannelsReply = { - accessChannels?: AccessChannels -} - -export type UpdateAccessChannelsReq = { - appInstanceId?: string - enabled?: boolean -} - export type UpdateAccessModeReq = { appId?: string accessMode?: string @@ -1230,16 +766,6 @@ export type UpdateAccessModeRes = { message?: string } -export type UpdateAppInstanceReply = { - appInstanceId?: string -} - -export type UpdateAppInstanceReq = { - appInstanceId?: string - name?: string - description?: string -} - export type UpdateBrandingInfoReq = { enabled?: boolean applicationTitle?: string @@ -1248,36 +774,6 @@ export type UpdateBrandingInfoReq = { favicon?: string } -export type UpdateDeveloperApiReply = { - developerApi?: DeveloperApiAccess -} - -export type UpdateDeveloperApiReq = { - appInstanceId?: string - enabled?: boolean -} - -export type UpdateEnvironmentAccessPolicyReply = { - permission?: EnvironmentAccessRow -} - -export type UpdateEnvironmentAccessPolicyReq = { - appInstanceId?: string - environmentId?: string - accessMode?: string - subjects?: Array -} - -export type UpdateEnvironmentReply = { - environment?: Environment -} - -export type UpdateEnvironmentReq = { - id?: string - name?: string - description?: string -} - export type UpdateGroupSubjectsReq = { groupId?: string subjects?: Array @@ -1358,6 +854,19 @@ export type UpdatePluginInstallationSettingsRequest = { restrictToMarketplaceOnly?: boolean } +export type UpdateResourceGroupRequest = { + id?: string + name?: string + description?: string + enabled?: boolean + rpm_limit?: number + rpm_action?: number + concurrency_limit?: number + concurrency_action?: number + token_quota?: string + token_action?: number +} + export type UpdateUserReply = { account?: AccountDetail } @@ -1410,11 +919,6 @@ export type UpdateWorkspaceReq = { status?: string } -export type WebAppAccessRow = { - environment?: ConsoleEnvironment - url?: string -} - export type WebAppAuthInfo = { allowSso?: boolean allowEmailCodeLogin?: boolean @@ -1459,385 +963,6 @@ export type Pagination = { totalPages?: number } -export type EnterpriseAppDeployConsoleListAppInstancesData = { - body?: never - path?: never - query?: { - environmentId?: string - notDeployed?: boolean - query?: string - pageNumber?: number - resultsPerPage?: number - } - url: '/enterprise/app-instances' -} - -export type EnterpriseAppDeployConsoleListAppInstancesResponses = { - 200: ListAppInstancesReply -} - -export type EnterpriseAppDeployConsoleListAppInstancesResponse - = EnterpriseAppDeployConsoleListAppInstancesResponses[keyof EnterpriseAppDeployConsoleListAppInstancesResponses] - -export type EnterpriseAppDeployConsoleCreateAppInstanceData = { - body: CreateAppInstanceReq - path?: never - query?: never - url: '/enterprise/app-instances' -} - -export type EnterpriseAppDeployConsoleCreateAppInstanceResponses = { - 200: CreateAppInstanceReply -} - -export type EnterpriseAppDeployConsoleCreateAppInstanceResponse - = EnterpriseAppDeployConsoleCreateAppInstanceResponses[keyof EnterpriseAppDeployConsoleCreateAppInstanceResponses] - -export type EnterpriseAppDeployConsoleDeleteAppInstanceData = { - body?: never - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}' -} - -export type EnterpriseAppDeployConsoleDeleteAppInstanceResponses = { - 200: DeleteAppInstanceReply -} - -export type EnterpriseAppDeployConsoleDeleteAppInstanceResponse - = EnterpriseAppDeployConsoleDeleteAppInstanceResponses[keyof EnterpriseAppDeployConsoleDeleteAppInstanceResponses] - -export type EnterpriseAppDeployConsoleUpdateAppInstanceData = { - body: UpdateAppInstanceReq - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}' -} - -export type EnterpriseAppDeployConsoleUpdateAppInstanceResponses = { - 200: UpdateAppInstanceReply -} - -export type EnterpriseAppDeployConsoleUpdateAppInstanceResponse - = EnterpriseAppDeployConsoleUpdateAppInstanceResponses[keyof EnterpriseAppDeployConsoleUpdateAppInstanceResponses] - -export type EnterpriseAppDeployConsoleGetAppInstanceAccessData = { - body?: never - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/access' -} - -export type EnterpriseAppDeployConsoleGetAppInstanceAccessResponses = { - 200: GetAppInstanceAccessReply -} - -export type EnterpriseAppDeployConsoleGetAppInstanceAccessResponse - = EnterpriseAppDeployConsoleGetAppInstanceAccessResponses[keyof EnterpriseAppDeployConsoleGetAppInstanceAccessResponses] - -export type EnterpriseAppDeployConsoleUpdateAccessChannelsData = { - body: UpdateAccessChannelsReq - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/access-channels' -} - -export type EnterpriseAppDeployConsoleUpdateAccessChannelsResponses = { - 200: UpdateAccessChannelsReply -} - -export type EnterpriseAppDeployConsoleUpdateAccessChannelsResponse - = EnterpriseAppDeployConsoleUpdateAccessChannelsResponses[keyof EnterpriseAppDeployConsoleUpdateAccessChannelsResponses] - -export type EnterpriseAppDeployConsoleSearchAccessSubjectsData = { - body?: never - path: { - appInstanceId: string - } - query?: { - keyword?: string - subjectTypes?: Array - } - url: '/enterprise/app-instances/{appInstanceId}/access-subjects:search' -} - -export type EnterpriseAppDeployConsoleSearchAccessSubjectsResponses = { - 200: SearchAccessSubjectsReply -} - -export type EnterpriseAppDeployConsoleSearchAccessSubjectsResponse - = EnterpriseAppDeployConsoleSearchAccessSubjectsResponses[keyof EnterpriseAppDeployConsoleSearchAccessSubjectsResponses] - -export type EnterpriseAppDeployConsoleCreateDeveloperApiKeyData = { - body: CreateDeveloperApiKeyReq - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/api-keys' -} - -export type EnterpriseAppDeployConsoleCreateDeveloperApiKeyResponses = { - 200: CreateDeveloperApiKeyReply -} - -export type EnterpriseAppDeployConsoleCreateDeveloperApiKeyResponse - = EnterpriseAppDeployConsoleCreateDeveloperApiKeyResponses[keyof EnterpriseAppDeployConsoleCreateDeveloperApiKeyResponses] - -export type EnterpriseAppDeployConsoleDeleteDeveloperApiKeyData = { - body?: never - path: { - appInstanceId: string - apiKeyId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/api-keys/{apiKeyId}' -} - -export type EnterpriseAppDeployConsoleDeleteDeveloperApiKeyResponses = { - 200: DeleteDeveloperApiKeyReply -} - -export type EnterpriseAppDeployConsoleDeleteDeveloperApiKeyResponse - = EnterpriseAppDeployConsoleDeleteDeveloperApiKeyResponses[keyof EnterpriseAppDeployConsoleDeleteDeveloperApiKeyResponses] - -export type EnterpriseAppDeployConsoleListDeploymentBindingOptionsData = { - body?: never - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/deployment-binding-options' -} - -export type EnterpriseAppDeployConsoleListDeploymentBindingOptionsResponses = { - 200: ListDeploymentBindingOptionsReply -} - -export type EnterpriseAppDeployConsoleListDeploymentBindingOptionsResponse - = EnterpriseAppDeployConsoleListDeploymentBindingOptionsResponses[keyof EnterpriseAppDeployConsoleListDeploymentBindingOptionsResponses] - -export type EnterpriseAppDeployConsoleCreateDeploymentData = { - body: CreateDeploymentReq - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/deployments' -} - -export type EnterpriseAppDeployConsoleCreateDeploymentResponses = { - 200: CreateDeploymentReply -} - -export type EnterpriseAppDeployConsoleCreateDeploymentResponse - = EnterpriseAppDeployConsoleCreateDeploymentResponses[keyof EnterpriseAppDeployConsoleCreateDeploymentResponses] - -export type EnterpriseAppDeployConsoleUpdateDeveloperApiData = { - body: UpdateDeveloperApiReq - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/developer-api' -} - -export type EnterpriseAppDeployConsoleUpdateDeveloperApiResponses = { - 200: UpdateDeveloperApiReply -} - -export type EnterpriseAppDeployConsoleUpdateDeveloperApiResponse - = EnterpriseAppDeployConsoleUpdateDeveloperApiResponses[keyof EnterpriseAppDeployConsoleUpdateDeveloperApiResponses] - -export type EnterpriseAppDeployConsoleGetEnvironmentAccessPolicyData = { - body?: never - path: { - appInstanceId: string - environmentId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/environments/{environmentId}/access-policy' -} - -export type EnterpriseAppDeployConsoleGetEnvironmentAccessPolicyResponses = { - 200: GetEnvironmentAccessPolicyReply -} - -export type EnterpriseAppDeployConsoleGetEnvironmentAccessPolicyResponse - = EnterpriseAppDeployConsoleGetEnvironmentAccessPolicyResponses[keyof EnterpriseAppDeployConsoleGetEnvironmentAccessPolicyResponses] - -export type EnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyData = { - body: UpdateEnvironmentAccessPolicyReq - path: { - appInstanceId: string - environmentId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/environments/{environmentId}/access-policy' -} - -export type EnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyResponses = { - 200: UpdateEnvironmentAccessPolicyReply -} - -export type EnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyResponse - = EnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyResponses[keyof EnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyResponses] - -export type EnterpriseAppDeployConsoleGetAppInstanceOverviewData = { - body?: never - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/overview' -} - -export type EnterpriseAppDeployConsoleGetAppInstanceOverviewResponses = { - 200: GetAppInstanceOverviewReply -} - -export type EnterpriseAppDeployConsoleGetAppInstanceOverviewResponse - = EnterpriseAppDeployConsoleGetAppInstanceOverviewResponses[keyof EnterpriseAppDeployConsoleGetAppInstanceOverviewResponses] - -export type EnterpriseAppDeployConsoleListReleasesData = { - body?: never - path: { - appInstanceId: string - } - query?: { - pageNumber?: number - resultsPerPage?: number - } - url: '/enterprise/app-instances/{appInstanceId}/releases' -} - -export type EnterpriseAppDeployConsoleListReleasesResponses = { - 200: ListReleasesReply -} - -export type EnterpriseAppDeployConsoleListReleasesResponse - = EnterpriseAppDeployConsoleListReleasesResponses[keyof EnterpriseAppDeployConsoleListReleasesResponses] - -export type EnterpriseAppDeployConsoleCreateReleaseData = { - body: CreateReleaseReq - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/releases' -} - -export type EnterpriseAppDeployConsoleCreateReleaseResponses = { - 200: CreateReleaseReply -} - -export type EnterpriseAppDeployConsoleCreateReleaseResponse - = EnterpriseAppDeployConsoleCreateReleaseResponses[keyof EnterpriseAppDeployConsoleCreateReleaseResponses] - -export type EnterpriseAppDeployConsolePreviewReleaseData = { - body: PreviewReleaseReq - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/releases:preview' -} - -export type EnterpriseAppDeployConsolePreviewReleaseResponses = { - 200: PreviewReleaseReply -} - -export type EnterpriseAppDeployConsolePreviewReleaseResponse - = EnterpriseAppDeployConsolePreviewReleaseResponses[keyof EnterpriseAppDeployConsolePreviewReleaseResponses] - -export type EnterpriseAppDeployConsoleListRuntimeInstancesData = { - body?: never - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/runtime-instances' -} - -export type EnterpriseAppDeployConsoleListRuntimeInstancesResponses = { - 200: ListRuntimeInstancesReply -} - -export type EnterpriseAppDeployConsoleListRuntimeInstancesResponse - = EnterpriseAppDeployConsoleListRuntimeInstancesResponses[keyof EnterpriseAppDeployConsoleListRuntimeInstancesResponses] - -export type EnterpriseAppDeployConsoleCancelRuntimeDeploymentData = { - body: CancelRuntimeDeploymentReq - path: { - appInstanceId: string - runtimeInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/runtime-instances/{runtimeInstanceId}/deployment:cancel' -} - -export type EnterpriseAppDeployConsoleCancelRuntimeDeploymentResponses = { - 200: CancelRuntimeDeploymentReply -} - -export type EnterpriseAppDeployConsoleCancelRuntimeDeploymentResponse - = EnterpriseAppDeployConsoleCancelRuntimeDeploymentResponses[keyof EnterpriseAppDeployConsoleCancelRuntimeDeploymentResponses] - -export type EnterpriseAppDeployConsoleUndeployRuntimeInstanceData = { - body: UndeployRuntimeInstanceReq - path: { - appInstanceId: string - runtimeInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/runtime-instances/{runtimeInstanceId}:undeploy' -} - -export type EnterpriseAppDeployConsoleUndeployRuntimeInstanceResponses = { - 200: UndeployRuntimeInstanceReply -} - -export type EnterpriseAppDeployConsoleUndeployRuntimeInstanceResponse - = EnterpriseAppDeployConsoleUndeployRuntimeInstanceResponses[keyof EnterpriseAppDeployConsoleUndeployRuntimeInstanceResponses] - -export type EnterpriseAppDeployConsoleGetAppInstanceSettingsData = { - body?: never - path: { - appInstanceId: string - } - query?: never - url: '/enterprise/app-instances/{appInstanceId}/settings' -} - -export type EnterpriseAppDeployConsoleGetAppInstanceSettingsResponses = { - 200: GetAppInstanceSettingsReply -} - -export type EnterpriseAppDeployConsoleGetAppInstanceSettingsResponse - = EnterpriseAppDeployConsoleGetAppInstanceSettingsResponses[keyof EnterpriseAppDeployConsoleGetAppInstanceSettingsResponses] - -export type EnterpriseAppDeployConsoleListDeploymentEnvironmentOptionsData = { - body?: never - path?: never - query?: never - url: '/enterprise/deployment-environment-options' -} - -export type EnterpriseAppDeployConsoleListDeploymentEnvironmentOptionsResponses = { - 200: ListDeploymentEnvironmentOptionsReply -} - -export type EnterpriseAppDeployConsoleListDeploymentEnvironmentOptionsResponse - = EnterpriseAppDeployConsoleListDeploymentEnvironmentOptionsResponses[keyof EnterpriseAppDeployConsoleListDeploymentEnvironmentOptionsResponses] - export type ConsoleSsoOAuth2LoginData = { body?: never path?: never diff --git a/packages/contracts/generated/enterprise/zod.gen.ts b/packages/contracts/generated/enterprise/zod.gen.ts index 1e7e3d44ae..cef500a906 100644 --- a/packages/contracts/generated/enterprise/zod.gen.ts +++ b/packages/contracts/generated/enterprise/zod.gen.ts @@ -2,44 +2,6 @@ import * as z from 'zod' -export const zAccessModeOption = z.object({ - mode: z.string().optional(), - label: z.string().optional(), - disabled: z.boolean().optional(), - selected: z.boolean().optional(), -}) - -export const zAccessStatus = z.object({ - accessChannelsEnabled: z.boolean().optional(), - webappUrl: z.string().optional(), - cliUrl: z.string().optional(), - developerApiEnabled: z.boolean().optional(), - apiKeyCount: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), -}) - -export const zAccessSubject = z.object({ - subjectId: z.string().optional(), - subjectType: z.string().optional(), -}) - -export const zAccessSubjectDisplay = z.object({ - id: z.string().optional(), - subjectType: z.string().optional(), - name: z.string().optional(), - avatarUrl: z.string().optional(), - memberCount: z.string().optional(), -}) - -export const zAccessPolicyDetail = z.object({ - accessMode: z.string().optional(), - subjects: z.array(zAccessSubjectDisplay).optional(), - options: z.array(zAccessModeOption).optional(), -}) - /** * Account represents a basic user account */ @@ -75,101 +37,9 @@ export const zAccountDetail = z.object({ groups: z.array(zAccountDetailGroup).optional(), }) -export const zAckDeploymentReply = z.object({ - accepted: z.boolean().optional(), - newVersion: z.string().optional(), -}) - -export const zAppInstanceBasicInfo = z.object({ +export const zAddGroupAppsRequest = z.object({ id: z.string().optional(), - name: z.string().optional(), - description: z.string().optional(), - sourceAppId: z.string().optional(), - sourceAppName: z.string().optional(), - mode: z.string().optional(), - createdAt: z.iso.datetime().optional(), -}) - -export const zAppRunnerBootstrapAssignment = z.object({ - appId: z.string().optional(), - environmentId: z.string().optional(), - workflowId: z.string().optional(), - instanceId: z.string().optional(), - workspaceId: z.string().optional(), - instanceVersion: z.string().optional(), - bindingSnapshotVersion: z.string().optional(), - executionTokenVersion: z.string().optional(), - executionToken: z.string().optional(), - releaseId: z.string().optional(), -}) - -export const zAppRunnerBootstrapReply = z.object({ - runnerId: z.string().optional(), - assignmentRevision: z.string().optional(), - assignments: z.array(zAppRunnerBootstrapAssignment).optional(), -}) - -export const zAppRunnerRunnerInfo = z.object({ - hostname: z.string().optional(), -}) - -export const zAppRunnerBootstrapRequest = z.object({ - runner: zAppRunnerRunnerInfo.optional(), -}) - -export const zAppRunnerRuntimeArtifactReply = z.object({ - dslYaml: z.string().optional(), - bindingSnapshotVersion: z.string().optional(), - bindingSnapshot: z.record(z.string(), z.unknown()).optional(), -}) - -export const zAppRunnerRuntimeArtifactRequest = z.object({ - instanceId: z.string().optional(), - releaseId: z.string().optional(), - bindingSnapshotVersion: z.string().optional(), -}) - -export const zAppRunnerBatchRuntimeArtifactRequest = z.object({ - artifacts: z.array(zAppRunnerRuntimeArtifactRequest).optional(), -}) - -export const zAppRunnerRuntimeArtifactResult = z.object({ - instanceId: z.string().optional(), - releaseId: z.string().optional(), - artifact: zAppRunnerRuntimeArtifactReply.optional(), - errorCode: z.string().optional(), - errorMessage: z.string().optional(), -}) - -export const zAppRunnerBatchRuntimeArtifactReply = z.object({ - results: z.array(zAppRunnerRuntimeArtifactResult).optional(), -}) - -export const zAppRunnerTokenExchangeReply = z.object({ - accessToken: z.string().optional(), - expiresAt: z.iso.datetime().optional(), -}) - -export const zAppRunnerTokenExchangeRequest = z.object({ - joinToken: z.string().optional(), -}) - -/** - * BootstrapProgress is step-list-agnostic. Reconcilers emit step names as - * strings owned by each executor (e.g. "connectivity", "namespace"), so adding - * or removing steps does not break the API. - */ -export const zBootstrapProgress = z.object({ - currentStep: z.string().optional(), - completedSteps: z.array(z.string()).optional(), - attemptCount: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), - lastAttemptAt: z.iso.datetime().optional(), - lastErrorCode: z.string().optional(), - lastErrorMessage: z.string().optional(), + app_ids: z.array(z.string()).optional(), }) export const zBrandingInfo = z.object({ @@ -180,15 +50,6 @@ export const zBrandingInfo = z.object({ favicon: z.string().optional(), }) -export const zCancelRuntimeDeploymentReply = z.object({ - status: z.string().optional(), -}) - -export const zCancelRuntimeDeploymentReq = z.object({ - appInstanceId: z.string().optional(), - runtimeInstanceId: z.string().optional(), -}) - export const zCheckPasswordStatusReply = z.object({ requirePasswordChange: z.boolean().optional(), changeReason: z.int().optional(), @@ -202,57 +63,10 @@ export const zCheckPasswordStatusReply = z.object({ export const zClearDefaultWorkspaceReply = z.record(z.string(), z.unknown()) -export const zCliAccess = z.object({ - url: z.string().optional(), -}) - -export const zConsoleEnvironment = z.object({ - id: z.string().optional(), - name: z.string().optional(), - runtime: z.string().optional(), - type: z.string().optional(), - status: z.string().optional(), -}) - -export const zConsoleRelease = z.object({ - id: z.string().optional(), - name: z.string().optional(), - shortCommitId: z.string().optional(), - createdAt: z.iso.datetime().optional(), -}) - -export const zConsoleUser = z.object({ - id: z.string().optional(), - name: z.string().optional(), -}) - -export const zCreateAppInstanceReply = z.object({ - appInstanceId: z.string().optional(), - initialRelease: zConsoleRelease.optional(), -}) - -export const zCreateAppInstanceReq = z.object({ - sourceAppId: z.string().optional(), - name: z.string().optional(), - description: z.string().optional(), -}) - export const zCreateBearerTokenResponse = z.object({ token: z.string().optional(), }) -export const zCreateDeploymentReply = z.object({ - runtimeInstanceId: z.string().optional(), - deploymentId: z.string().optional(), - status: z.string().optional(), -}) - -export const zCreateDeveloperApiKeyReq = z.object({ - appInstanceId: z.string().optional(), - environmentId: z.string().optional(), - name: z.string().optional(), -}) - export const zCreateMemberReply = z.object({ id: z.string().optional(), password: z.string().optional(), @@ -275,12 +89,7 @@ export const zCreateNewGroupsReq = z.object({ groups: z.array(zCreateNewGroupsReqGroup).optional(), }) -export const zCreateReleaseReply = z.object({ - release: zConsoleRelease.optional(), -}) - -export const zCreateReleaseReq = z.object({ - appInstanceId: z.string().optional(), +export const zCreateResourceGroupRequest = z.object({ name: z.string().optional(), description: z.string().optional(), }) @@ -342,21 +151,10 @@ export const zDashboardSsosamlLoginReply = z.object({ url: z.string().optional(), }) -export const zDeleteAppInstanceReply = z.record(z.string(), z.unknown()) - -export const zDeleteDeveloperApiKeyReply = z.record(z.string(), z.unknown()) - -export const zDeleteEnvironmentReply = z.record(z.string(), z.unknown()) - export const zDeleteGroupsRes = z.object({ message: z.string().optional(), }) -export const zDeleteGuard = z.object({ - canDelete: z.boolean().optional(), - disabledReason: z.string().optional(), -}) - export const zDeleteMemberReply = z.object({ account: zAccount.optional(), }) @@ -371,82 +169,6 @@ export const zDeleteUserReply = z.object({ export const zDeleteWorkspaceReply = z.record(z.string(), z.unknown()) -export const zDeployedEnvironment = z.object({ - environmentId: z.string().optional(), - environmentName: z.string().optional(), -}) - -export const zDeploymentCredentialOption = z.object({ - credentialId: z.string().optional(), - displayName: z.string().optional(), - pluginId: z.string().optional(), - pluginName: z.string().optional(), - pluginVersion: z.string().optional(), -}) - -export const zDeploymentEnvVarOption = z.object({ - envVarId: z.string().optional(), - name: z.string().optional(), - valueType: z.string().optional(), - displayValue: z.string().optional(), -}) - -export const zDeploymentBindingOptionSlot = z.object({ - slot: z.string().optional(), - kind: z.string().optional(), - label: z.string().optional(), - required: z.boolean().optional(), - candidates: z.array(zDeploymentCredentialOption).optional(), - envVarCandidates: z.array(zDeploymentEnvVarOption).optional(), -}) - -export const zDeploymentEnvironmentOption = z.object({ - id: z.string().optional(), - name: z.string().optional(), - type: z.string().optional(), - backend: z.string().optional(), - status: z.string().optional(), - managedBy: z.string().optional(), - deployable: z.boolean().optional(), - disabledReason: z.string().optional(), -}) - -export const zDeploymentRuntimeBinding = z.object({ - slot: z.string().optional(), - credentialId: z.string().optional(), - envVarId: z.string().optional(), -}) - -export const zCreateDeploymentReq = z.object({ - appInstanceId: z.string().optional(), - environmentId: z.string().optional(), - releaseId: z.string().optional(), - bindings: z.array(zDeploymentRuntimeBinding).optional(), -}) - -export const zDeploymentStatusRow = z.object({ - environment: zConsoleEnvironment.optional(), - release: zConsoleRelease.optional(), - status: z.string().optional(), -}) - -export const zDeveloperApiKeyRow = z.object({ - id: z.string().optional(), - name: z.string().optional(), - environment: zConsoleEnvironment.optional(), - maskedKey: z.string().optional(), -}) - -export const zCreateDeveloperApiKeyReply = z.object({ - apiKey: zDeveloperApiKeyRow.optional(), - token: z.string().optional(), -}) - -export const zDeveloperApiAccess = z.object({ - enabled: z.boolean().optional(), - apiKeys: z.array(zDeveloperApiKeyRow).optional(), -}) - /** * System user setting messages */ @@ -456,53 +178,6 @@ export const zEnterpriseSystemUserSettingReply = z.object({ enableEmailPasswordLogin: z.boolean().optional(), }) -export const zEnvironment = z.object({ - id: z.string().optional(), - name: z.string().optional(), - description: z.string().optional(), - mode: z.int().optional(), - namespace: z.string().optional(), - apiServer: z.string().optional(), - status: z.int().optional(), - statusMessage: z.string().optional(), - bootstrapProgress: zBootstrapProgress.optional(), - managedBy: z.string().optional(), - createdAt: z.iso.datetime().optional(), - updatedAt: z.iso.datetime().optional(), - backend: z.int().optional(), - host: z.string().optional(), -}) - -export const zCreateEnvironmentReply = z.object({ - environment: zEnvironment.optional(), -}) - -export const zEnvironmentAccessRow = z.object({ - environment: zConsoleEnvironment.optional(), - currentRelease: zConsoleRelease.optional(), - accessMode: z.string().optional(), - accessModeLabel: z.string().optional(), - hint: z.string().optional(), -}) - -export const zEnvironmentFilter = z.object({ - id: z.string().optional(), - name: z.string().optional(), - kind: z.string().optional(), -}) - -export const zGetAppInstanceOverviewReply = z.object({ - instance: zAppInstanceBasicInfo.optional(), - deployments: z.array(zDeploymentStatusRow).optional(), - access: zAccessStatus.optional(), -}) - -export const zGetAppInstanceSettingsReply = z.object({ - name: z.string().optional(), - description: z.string().optional(), - deleteGuard: zDeleteGuard.optional(), -}) - export const zGetBearerTokenResponse = z.object({ maskedToken: z.string().optional(), }) @@ -513,23 +188,6 @@ export const zGetClusterInfoReply = z.object({ verifyMode: z.string().optional(), }) -export const zGetEnvironmentAccessPolicyReply = z.object({ - policy: zAccessPolicyDetail.optional(), -}) - -export const zGetEnvironmentReply = z.object({ - environment: zEnvironment.optional(), -}) - -export const zGetInstanceReply = z.object({ - instanceId: z.string().optional(), - status: z.string().optional(), - desiredReleaseId: z.string().optional(), - observedReleaseId: z.string().optional(), - currentDeploymentId: z.string().optional(), - version: z.string().optional(), -}) - export const zGetLicenseStatusReply = z.object({ status: z.string().optional(), }) @@ -565,14 +223,25 @@ export const zGetWebAppWhitelistSubjectsResMember = z.object({ avatar: z.string().optional(), }) +export const zGroupAppItem = z.object({ + app_id: z.string().optional(), + app_name: z.string().optional(), + workspace_id: z.string().optional(), + workspace_name: z.string().optional(), + app_status: z.int().optional(), + token_usage: z.string().optional(), + rpm: z.string().optional(), + concurrency: z.string().optional(), +}) + export const zHealthzReply = z.object({ message: z.string().optional(), status: z.string().optional(), }) -export const zHostEnvironmentConfig = z.object({ - machineId: z.string().optional(), - joinTokenHash: z.string().optional(), +export const zInnerAdmission = z.object({ + marker: z.string().optional(), + concurrencyGroupIds: z.array(z.string()).optional(), }) export const zInnerBatchGetWebAppAccessModesByIdReq = z.object({ @@ -592,50 +261,10 @@ export const zInnerBatchIsUserAllowedToAccessWebAppRes = z.object({ permissions: z.record(z.string(), z.boolean()).optional(), }) -export const zInnerCheckAppDeployAccessReply = z.object({ - allowed: z.boolean().optional(), - matchedPolicyId: z.string().optional(), - matchedScopeType: z.string().optional(), - reason: z.string().optional(), - cacheTtlSeconds: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), -}) - -export const zInnerCheckAppDeployAccessReq = z.object({ - appInstanceId: z.string().optional(), - environmentId: z.string().optional(), - principalType: z.string().optional(), - principalId: z.string().optional(), -}) - export const zInnerCleanAppRes = z.object({ message: z.string().optional(), }) -export const zInnerGetTokenRouteReply = z.object({ - environmentId: z.string().optional(), - namespace: z.string().optional(), - serviceName: z.string().optional(), - servicePort: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), - environmentStatus: z.string().optional(), - appId: z.string().optional(), - tenantId: z.string().optional(), - instanceId: z.string().optional(), - observedReleaseId: z.string().optional(), - instanceStatus: z.string().optional(), -}) - -export const zInnerGetTokenRouteReq = z.object({ - token: z.string().optional(), -}) - export const zInnerGetWebAppAccessModeByCodeRes = z.object({ accessMode: z.string().optional(), }) @@ -648,6 +277,12 @@ export const zInnerIsUserAllowedToAccessWebAppRes = z.object({ result: z.boolean().optional(), }) +export const zInnerReleaseAdmissionRequest = z.object({ + admission: zInnerAdmission.optional(), +}) + +export const zInnerReleaseAdmissionResponse = z.record(z.string(), z.unknown()) + export const zInnerTryAddAccountToDefaultWorkspaceReply = z.object({ workspaceId: z.string().optional(), joined: z.boolean().optional(), @@ -678,48 +313,32 @@ export const zJoinWorkspaceReq = z.object({ role: z.string().optional(), }) -export const zK8sEnvironmentConfig = z.object({ - namespace: z.string().optional(), - apiServer: z.string().optional(), - caBundle: z.string().optional(), - bearerToken: z.string().optional(), +export const zLimitConfig = z.object({ + type: z.int().optional(), + threshold: z.string().optional(), + action: z.int().optional(), + reached: z.boolean().optional(), }) -/** - * Field-level validation only; target (api_server) and RBAC validation happen - * in the bootstrap reconciler. - */ -export const zCreateEnvironmentReq = z.object({ - name: z.string().optional(), - description: z.string().optional(), - mode: z.int().optional(), - backend: z.int().optional(), - k8s: zK8sEnvironmentConfig.optional(), - host: zHostEnvironmentConfig.optional(), +export const zInnerGroupConfig = z.object({ + id: z.string().optional(), + enabled: z.boolean().optional(), + membershipId: z.string().optional(), + limits: z.array(zLimitConfig).optional(), }) -export const zLastError = z.object({ - phase: z.string().optional(), - code: z.string().optional(), - message: z.string().optional(), - releaseId: z.string().optional(), +export const zInnerResolveResponse = z.object({ + appId: z.string().optional(), + groups: z.array(zInnerGroupConfig).optional(), + blocked: z.boolean().optional(), + blockGroupId: z.string().optional(), + blockReason: z.string().optional(), + admission: zInnerAdmission.optional(), }) -export const zAckDeploymentReq = z.object({ - deploymentId: z.string().optional(), - instanceId: z.string().optional(), - expectedVersion: z.string().optional(), - status: z.string().optional(), - observedReleaseId: z.string().optional(), - lastError: zLastError.optional(), -}) - -export const zListDeploymentBindingOptionsReply = z.object({ - slots: z.array(zDeploymentBindingOptionSlot).optional(), -}) - -export const zListDeploymentEnvironmentOptionsReply = z.object({ - environments: z.array(zDeploymentEnvironmentOption).optional(), +export const zListGroupAppsResponse = z.object({ + items: z.array(zGroupAppItem).optional(), + total: z.string().optional(), }) export const zLoginTypesReply = z.object({ @@ -871,31 +490,6 @@ export const zPluginInstallationSettingsReply = z.object({ restrictToMarketplaceOnly: z.boolean().optional(), }) -export const zPreviewReleaseReq = z.object({ - appInstanceId: z.string().optional(), - releaseId: z.string().optional(), -}) - -export const zReleaseRow = z.object({ - id: z.string().optional(), - name: z.string().optional(), - createdAt: z.iso.datetime().optional(), - createdBy: zConsoleUser.optional(), - deployedTo: z.array(zDeployedEnvironment).optional(), -}) - -export const zReleaseRuntimeBinding = z.object({ - kind: z.string().optional(), - label: z.string().optional(), - displayValue: z.string().optional(), - valueType: z.string().optional(), -}) - -export const zPreviewReleaseReply = z.object({ - release: zConsoleRelease.optional(), - bindings: z.array(zReleaseRuntimeBinding).optional(), -}) - export const zResetMemberPasswordReply = z.object({ id: z.string().optional(), password: z.string().optional(), @@ -930,26 +524,56 @@ export const zResetUserPasswordReq = z.object({ id: z.string().optional(), }) -export const zResolveCredentialsReq = z.object({ - instanceId: z.string().optional(), - deploymentId: z.string().optional(), - slots: z.array(z.string()).optional(), +export const zResourceGroupDetail = z.object({ + id: z.string().optional(), + name: z.string().optional(), + description: z.string().optional(), + enabled: z.boolean().optional(), + rpm_limit: z + .int() + .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) + .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) + .optional(), + rpm_action: z.int().optional(), + concurrency_limit: z + .int() + .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) + .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) + .optional(), + concurrency_action: z.int().optional(), + token_quota: z.string().optional(), + token_action: z.int().optional(), + created_at: z.string().optional(), + updated_at: z.string().optional(), }) -/** - * Exactly one of credential_id / env_var_id is populated; model/plugin slots - * carry credential_id (pool A), env_var slots carry env_var_id (pool B). - * See design §4.1. - */ -export const zResolvedCredential = z.object({ - slot: z.string().optional(), - credentialId: z.string().optional(), - envVarId: z.string().optional(), - value: z.string().optional(), +export const zResourceGroupItem = z.object({ + id: z.string().optional(), + name: z.string().optional(), + description: z.string().optional(), + enabled: z.boolean().optional(), + rpm_limit: z + .int() + .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) + .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) + .optional(), + concurrency_limit: z + .int() + .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) + .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) + .optional(), + token_quota: z.string().optional(), + token_usage: z.string().optional(), + app_count: z.string().optional(), + rpm_status: z.int().optional(), + conc_status: z.int().optional(), + created_at: z.string().optional(), + updated_at: z.string().optional(), }) -export const zResolveCredentialsReply = z.object({ - resolved: z.array(zResolvedCredential).optional(), +export const zListResourceGroupsResponse = z.object({ + items: z.array(zResourceGroupItem).optional(), + total: z.string().optional(), }) /** @@ -1002,44 +626,6 @@ export const zGetLicenseReply = z.object({ license: zLicenseInfo.optional(), }) -export const zRetryEnvironmentReply = z.object({ - environment: zEnvironment.optional(), -}) - -export const zRetryEnvironmentReq = z.object({ - id: z.string().optional(), -}) - -export const zRuntimeEndpoints = z.object({ - run: z.string().optional(), - health: z.string().optional(), -}) - -export const zRuntimeInstanceDetail = z.object({ - deploymentName: z.string().optional(), - replicas: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), - runtimeMode: z.string().optional(), - runtimeNote: z.string().optional(), - endpoints: zRuntimeEndpoints.optional(), - bindings: z.array(zReleaseRuntimeBinding).optional(), -}) - -export const zRuntimeInstanceRow = z.object({ - id: z.string().optional(), - environment: zConsoleEnvironment.optional(), - status: z.string().optional(), - currentRelease: zConsoleRelease.optional(), - detail: zRuntimeInstanceDetail.optional(), -}) - -export const zListRuntimeInstancesReply = z.object({ - data: z.array(zRuntimeInstanceRow).optional(), -}) - /** * SSO Configuration messages */ @@ -1102,8 +688,21 @@ export const zScimSettings = z.object({ lastSyncTime: z.iso.datetime().optional(), }) -export const zSearchAccessSubjectsReply = z.object({ - data: z.array(zAccessSubjectDisplay).optional(), +export const zSearchAppItem = z.object({ + app_id: z.string().optional(), + app_name: z.string().optional(), + workspace_id: z.string().optional(), + workspace_name: z.string().optional(), + app_status: z.int().optional(), + icon: z.string().optional(), + icon_type: z.string().optional(), + icon_background: z.string().optional(), + created_by_name: z.string().optional(), +}) + +export const zSearchAppsResponse = z.object({ + items: z.array(zSearchAppItem).optional(), + total: z.string().optional(), }) export const zSecretKey = z.object({ @@ -1122,25 +721,6 @@ export const zSetDefaultWorkspaceReq = z.object({ id: z.string().optional(), }) -export const zStatusCount = z.object({ - status: z.string().optional(), - count: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), -}) - -export const zAppInstanceCard = z.object({ - id: z.string().optional(), - name: z.string().optional(), - icon: z.string().optional(), - mode: z.string().optional(), - sourceAppName: z.string().optional(), - statuses: z.array(zStatusCount).optional(), - lastDeployedAt: z.iso.datetime().optional(), -}) - export const zSubjectAccountData = z.object({ id: z.string().optional(), name: z.string().optional(), @@ -1214,38 +794,10 @@ export const zTestConnectionReply = z.object({ error: z.string().optional(), }) -export const zTestEnvironmentConnectionReply = z.object({ - ok: z.boolean().optional(), - reachableServerVersion: z.string().optional(), - namespaceExists: z.boolean().optional(), - missingPermissions: z.array(z.string()).optional(), - error: z.string().optional(), - probedAt: z.iso.datetime().optional(), -}) - -export const zTestEnvironmentConnectionReq = z.object({ - id: z.string().optional(), -}) - export const zToggleEndpointRequest = z.object({ enabled: z.boolean().optional(), }) -export const zUndeployRuntimeInstanceReply = z.object({ - deploymentId: z.string().optional(), - status: z.string().optional(), -}) - -export const zUndeployRuntimeInstanceReq = z.object({ - appInstanceId: z.string().optional(), - runtimeInstanceId: z.string().optional(), -}) - -export const zUpdateAccessChannelsReq = z.object({ - appInstanceId: z.string().optional(), - enabled: z.boolean().optional(), -}) - export const zUpdateAccessModeReq = z.object({ appId: z.string().optional(), accessMode: z.string().optional(), @@ -1255,16 +807,6 @@ export const zUpdateAccessModeRes = z.object({ message: z.string().optional(), }) -export const zUpdateAppInstanceReply = z.object({ - appInstanceId: z.string().optional(), -}) - -export const zUpdateAppInstanceReq = z.object({ - appInstanceId: z.string().optional(), - name: z.string().optional(), - description: z.string().optional(), -}) - export const zUpdateBrandingInfoReq = z.object({ enabled: z.boolean().optional(), applicationTitle: z.string().optional(), @@ -1273,36 +815,6 @@ export const zUpdateBrandingInfoReq = z.object({ favicon: z.string().optional(), }) -export const zUpdateDeveloperApiReply = z.object({ - developerApi: zDeveloperApiAccess.optional(), -}) - -export const zUpdateDeveloperApiReq = z.object({ - appInstanceId: z.string().optional(), - enabled: z.boolean().optional(), -}) - -export const zUpdateEnvironmentAccessPolicyReply = z.object({ - permission: zEnvironmentAccessRow.optional(), -}) - -export const zUpdateEnvironmentAccessPolicyReq = z.object({ - appInstanceId: z.string().optional(), - environmentId: z.string().optional(), - accessMode: z.string().optional(), - subjects: z.array(zAccessSubject).optional(), -}) - -export const zUpdateEnvironmentReply = z.object({ - environment: zEnvironment.optional(), -}) - -export const zUpdateEnvironmentReq = z.object({ - id: z.string().optional(), - name: z.string().optional(), - description: z.string().optional(), -}) - export const zUpdateGroupSubjectsReq = z.object({ groupId: z.string().optional(), subjects: z.array(zSubject).optional(), @@ -1386,6 +898,27 @@ export const zUpdatePluginInstallationSettingsRequest = z.object({ restrictToMarketplaceOnly: z.boolean().optional(), }) +export const zUpdateResourceGroupRequest = z.object({ + id: z.string().optional(), + name: z.string().optional(), + description: z.string().optional(), + enabled: z.boolean().optional(), + rpm_limit: z + .int() + .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) + .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) + .optional(), + rpm_action: z.int().optional(), + concurrency_limit: z + .int() + .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) + .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) + .optional(), + concurrency_action: z.int().optional(), + token_quota: z.string().optional(), + token_action: z.int().optional(), +}) + export const zUpdateUserReply = z.object({ account: zAccountDetail.optional(), }) @@ -1430,27 +963,6 @@ export const zUpdateWorkspaceReq = z.object({ status: z.string().optional(), }) -export const zWebAppAccessRow = z.object({ - environment: zConsoleEnvironment.optional(), - url: z.string().optional(), -}) - -export const zAccessChannels = z.object({ - enabled: z.boolean().optional(), - webappRows: z.array(zWebAppAccessRow).optional(), - cli: zCliAccess.optional(), -}) - -export const zGetAppInstanceAccessReply = z.object({ - permissions: z.array(zEnvironmentAccessRow).optional(), - accessChannels: zAccessChannels.optional(), - developerApi: zDeveloperApiAccess.optional(), -}) - -export const zUpdateAccessChannelsReply = z.object({ - accessChannels: zAccessChannels.optional(), -}) - export const zWebAppAuthInfo = z.object({ allowSso: z.boolean().optional(), allowEmailCodeLogin: z.boolean().optional(), @@ -1572,27 +1084,11 @@ export const zPagination = z.object({ .optional(), }) -export const zListAppInstancesReply = z.object({ - filters: z.array(zEnvironmentFilter).optional(), - data: z.array(zAppInstanceCard).optional(), - pagination: zPagination.optional(), -}) - -export const zListEnvironmentsReply = z.object({ - data: z.array(zEnvironment).optional(), - pagination: zPagination.optional(), -}) - export const zListMembersReply = z.object({ data: z.array(zAccountDetail).optional(), pagination: zPagination.optional(), }) -export const zListReleasesReply = z.object({ - data: z.array(zReleaseRow).optional(), - pagination: zPagination.optional(), -}) - export const zListSecretKeysReply = z.object({ data: z.array(zSecretKey).optional(), pagination: zPagination.optional(), @@ -1608,271 +1104,6 @@ export const zListWorkspacesReply = z.object({ pagination: zPagination.optional(), }) -export const zEnterpriseAppDeployConsoleListAppInstancesQuery = z.object({ - environmentId: z.string().optional(), - notDeployed: z.boolean().optional(), - query: z.string().optional(), - pageNumber: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), - resultsPerPage: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleListAppInstancesResponse = zListAppInstancesReply - -export const zEnterpriseAppDeployConsoleCreateAppInstanceBody = zCreateAppInstanceReq - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleCreateAppInstanceResponse = zCreateAppInstanceReply - -export const zEnterpriseAppDeployConsoleDeleteAppInstancePath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleDeleteAppInstanceResponse = zDeleteAppInstanceReply - -export const zEnterpriseAppDeployConsoleUpdateAppInstanceBody = zUpdateAppInstanceReq - -export const zEnterpriseAppDeployConsoleUpdateAppInstancePath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleUpdateAppInstanceResponse = zUpdateAppInstanceReply - -export const zEnterpriseAppDeployConsoleGetAppInstanceAccessPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleGetAppInstanceAccessResponse = zGetAppInstanceAccessReply - -export const zEnterpriseAppDeployConsoleUpdateAccessChannelsBody = zUpdateAccessChannelsReq - -export const zEnterpriseAppDeployConsoleUpdateAccessChannelsPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleUpdateAccessChannelsResponse = zUpdateAccessChannelsReply - -export const zEnterpriseAppDeployConsoleSearchAccessSubjectsPath = z.object({ - appInstanceId: z.string(), -}) - -export const zEnterpriseAppDeployConsoleSearchAccessSubjectsQuery = z.object({ - keyword: z.string().optional(), - subjectTypes: z.array(z.string()).optional(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleSearchAccessSubjectsResponse = zSearchAccessSubjectsReply - -export const zEnterpriseAppDeployConsoleCreateDeveloperApiKeyBody = zCreateDeveloperApiKeyReq - -export const zEnterpriseAppDeployConsoleCreateDeveloperApiKeyPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleCreateDeveloperApiKeyResponse = zCreateDeveloperApiKeyReply - -export const zEnterpriseAppDeployConsoleDeleteDeveloperApiKeyPath = z.object({ - appInstanceId: z.string(), - apiKeyId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleDeleteDeveloperApiKeyResponse = zDeleteDeveloperApiKeyReply - -export const zEnterpriseAppDeployConsoleListDeploymentBindingOptionsPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleListDeploymentBindingOptionsResponse - = zListDeploymentBindingOptionsReply - -export const zEnterpriseAppDeployConsoleCreateDeploymentBody = zCreateDeploymentReq - -export const zEnterpriseAppDeployConsoleCreateDeploymentPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleCreateDeploymentResponse = zCreateDeploymentReply - -export const zEnterpriseAppDeployConsoleUpdateDeveloperApiBody = zUpdateDeveloperApiReq - -export const zEnterpriseAppDeployConsoleUpdateDeveloperApiPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleUpdateDeveloperApiResponse = zUpdateDeveloperApiReply - -export const zEnterpriseAppDeployConsoleGetEnvironmentAccessPolicyPath = z.object({ - appInstanceId: z.string(), - environmentId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleGetEnvironmentAccessPolicyResponse - = zGetEnvironmentAccessPolicyReply - -export const zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyBody - = zUpdateEnvironmentAccessPolicyReq - -export const zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyPath = z.object({ - appInstanceId: z.string(), - environmentId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleUpdateEnvironmentAccessPolicyResponse - = zUpdateEnvironmentAccessPolicyReply - -export const zEnterpriseAppDeployConsoleGetAppInstanceOverviewPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleGetAppInstanceOverviewResponse - = zGetAppInstanceOverviewReply - -export const zEnterpriseAppDeployConsoleListReleasesPath = z.object({ - appInstanceId: z.string(), -}) - -export const zEnterpriseAppDeployConsoleListReleasesQuery = z.object({ - pageNumber: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), - resultsPerPage: z - .int() - .min(-2147483648, { error: 'Invalid value: Expected int32 to be >= -2147483648' }) - .max(2147483647, { error: 'Invalid value: Expected int32 to be <= 2147483647' }) - .optional(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleListReleasesResponse = zListReleasesReply - -export const zEnterpriseAppDeployConsoleCreateReleaseBody = zCreateReleaseReq - -export const zEnterpriseAppDeployConsoleCreateReleasePath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleCreateReleaseResponse = zCreateReleaseReply - -export const zEnterpriseAppDeployConsolePreviewReleaseBody = zPreviewReleaseReq - -export const zEnterpriseAppDeployConsolePreviewReleasePath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsolePreviewReleaseResponse = zPreviewReleaseReply - -export const zEnterpriseAppDeployConsoleListRuntimeInstancesPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleListRuntimeInstancesResponse = zListRuntimeInstancesReply - -export const zEnterpriseAppDeployConsoleCancelRuntimeDeploymentBody = zCancelRuntimeDeploymentReq - -export const zEnterpriseAppDeployConsoleCancelRuntimeDeploymentPath = z.object({ - appInstanceId: z.string(), - runtimeInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleCancelRuntimeDeploymentResponse - = zCancelRuntimeDeploymentReply - -export const zEnterpriseAppDeployConsoleUndeployRuntimeInstanceBody = zUndeployRuntimeInstanceReq - -export const zEnterpriseAppDeployConsoleUndeployRuntimeInstancePath = z.object({ - appInstanceId: z.string(), - runtimeInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleUndeployRuntimeInstanceResponse - = zUndeployRuntimeInstanceReply - -export const zEnterpriseAppDeployConsoleGetAppInstanceSettingsPath = z.object({ - appInstanceId: z.string(), -}) - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleGetAppInstanceSettingsResponse - = zGetAppInstanceSettingsReply - -/** - * OK - */ -export const zEnterpriseAppDeployConsoleListDeploymentEnvironmentOptionsResponse - = zListDeploymentEnvironmentOptionsReply - /** * OK */ diff --git a/packages/dify-ui/AGENTS.md b/packages/dify-ui/AGENTS.md index bdc2160702..6eadd200f0 100644 --- a/packages/dify-ui/AGENTS.md +++ b/packages/dify-ui/AGENTS.md @@ -9,6 +9,7 @@ Shared design tokens, the `cn()` utility, CSS-first Tailwind styles, and headles - No imports from `web/`. No dependencies on next / i18next / ky / jotai / zustand. - One component per folder: `src//index.tsx`, optional `index.stories.tsx` and `__tests__/index.spec.tsx`. Add a matching `./` subpath to `package.json#exports`. - Props pattern: `Omit & VariantProps & { /* custom */ }`. +- Use plain `Omit<...>` only for non-union Base UI props. When a prop changes the valid shape of related props (for example `value` / `defaultValue`, `multiple` / `value`, or `clearable` / `onChange`), model that relationship with an explicit discriminated union or a distributive helper instead of flattening the props. - When a component accepts a prop typed from a shared internal module, `export type` it from that component so consumers import it from the component subpath. ## Overlay Primitive Selection: Tooltip vs PreviewCard vs Popover @@ -75,7 +76,7 @@ Composition rules: - Keep Base UI primitive semantics visible in the public API. Export compound parts such as `ComboboxInputGroup`, `ComboboxInput`, `ComboboxContent`, `ComboboxList`, `ComboboxItem`, and `ComboboxItemIndicator` instead of wrapping them into one business component. - For `Combobox` multiple selection, follow the official chips pattern: `ComboboxInputGroup` contains `ComboboxChips`, `ComboboxValue` renders `ComboboxChip` items, and `ComboboxInput` remains inside the chips row. Chips should wrap and let the input group grow vertically instead of forcing horizontal overflow. -- Content primitives must own their Base UI `Portal` and use `z-1002` on `Positioner`, matching the overlay contract in `README.md`. +- Content primitives must own their Base UI `Portal` and use `z-50` on `Positioner`, matching the overlay contract in `README.md`. Toast owns `z-60`. - Use `w-(--anchor-width)` with viewport-aware max-width for `Autocomplete` and `Combobox` popups. Do not add `min-w-(--anchor-width)` when it would defeat available-width clamping. [Autocomplete docs]: https://base-ui.com/react/components/autocomplete.md#usage-guidelines diff --git a/packages/dify-ui/README.md b/packages/dify-ui/README.md index c78faede89..010fb3e56d 100644 --- a/packages/dify-ui/README.md +++ b/packages/dify-ui/README.md @@ -84,18 +84,18 @@ Equivalent: any root element with `isolation: isolate` in CSS. Without it, overl Every overlay primitive uses a single, shared z-index. Do **not** override it at call sites. -| Layer | z-index | Where | -| ------------------------------------------------------------------------------------------------------------------- | -------- | -------------------------------------------------------------------------- | -| Overlays (Dialog, AlertDialog, Autocomplete, Combobox, Drawer, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop | -| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. | +| Layer | z-index | Where | +| ------------------------------------------------------------------------------------------------------------------- | ------- | -------------------------------------------------------------------------- | +| Overlays (Dialog, AlertDialog, Autocomplete, Combobox, Drawer, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-50` | Positioner / Backdrop | +| Toast viewport | `z-60` | One layer above overlays so notifications are never hidden under a dialog. | -Rationale: during Dify's migration from legacy `base/modal` / `base/dialog` / `base/drawer` / `base/drawer-plus` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins. +Rationale: Dify UI owns the normal application overlay layer. Overlay primitives share `z-50` and **rely on DOM order** for stacking — the portal mounted later wins. Toast owns `z-60` so notifications remain visible above dialogs, popovers, and other portalled surfaces without falling back to `z-9999`. -See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for the Dify-web migration history. Once the legacy overlays are gone, the values in this table can drop back to `z-50` / `z-51`. +See `[web/docs/overlay.md](../../web/docs/overlay.md)` for the web app overlay best practices. ### Rules -- Never add `z-1003` / `z-9999` / etc. overrides on primitives from this package. If something is getting clipped, the **parent** overlay (typically a legacy one) is the problem and should be migrated. +- Never add ad hoc `z-*` overrides on primitives from this package. If something is getting clipped, fix the parent overlay structure instead of raising the child primitive. - Never create an extra manual portal on top of our primitives — use the exported content / portal parts such as `DialogContent`, `PopoverContent`, and `DrawerPortal`. Base UI handles focus management, scroll-locking, and dismissal. - When a primitive needs additional presentation chrome (e.g. a custom backdrop), add it **inside** the exported component, not at call sites. diff --git a/packages/dify-ui/package.json b/packages/dify-ui/package.json index 894e92bfd6..96c512f89c 100644 --- a/packages/dify-ui/package.json +++ b/packages/dify-ui/package.json @@ -77,6 +77,14 @@ "types": "./src/switch/index.tsx", "import": "./src/switch/index.tsx" }, + "./tabs": { + "types": "./src/tabs/index.tsx", + "import": "./src/tabs/index.tsx" + }, + "./toggle-group": { + "types": "./src/toggle-group/index.tsx", + "import": "./src/toggle-group/index.tsx" + }, "./toast": { "types": "./src/toast/index.tsx", "import": "./src/toast/index.tsx" diff --git a/packages/dify-ui/src/alert-dialog/index.tsx b/packages/dify-ui/src/alert-dialog/index.tsx index 7b432c87dc..81299ef932 100644 --- a/packages/dify-ui/src/alert-dialog/index.tsx +++ b/packages/dify-ui/src/alert-dialog/index.tsx @@ -29,14 +29,14 @@ export function AlertDialogContent({ { await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-side', 'bottom') await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-align', 'start') - await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveClass('z-1002') + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveClass('z-50') await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('rounded-xl') await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('w-(--anchor-width)') await expect.element(screen.getByRole('listbox', { name: 'autocomplete list' })).toHaveClass('scroll-py-1') diff --git a/packages/dify-ui/src/autocomplete/index.tsx b/packages/dify-ui/src/autocomplete/index.tsx index 16c4b19673..4c8893b376 100644 --- a/packages/dify-ui/src/autocomplete/index.tsx +++ b/packages/dify-ui/src/autocomplete/index.tsx @@ -261,7 +261,7 @@ export function AutocompleteContent({ align={align} sideOffset={sideOffset} alignOffset={alignOffset} - className={cn('z-1002 outline-hidden', className)} + className={cn('z-50 outline-hidden', className)} {...positionerProps} > { await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-side', 'bottom') await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-align', 'start') - await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveClass('z-1002') + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveClass('z-50') await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('rounded-xl') await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('w-(--anchor-width)') await expect.element(screen.getByRole('listbox', { name: 'combobox list' })).toHaveClass('scroll-py-1') diff --git a/packages/dify-ui/src/combobox/index.tsx b/packages/dify-ui/src/combobox/index.tsx index c4f03241f6..eb43b911c7 100644 --- a/packages/dify-ui/src/combobox/index.tsx +++ b/packages/dify-ui/src/combobox/index.tsx @@ -323,7 +323,7 @@ export function ComboboxContent({ align={align} sideOffset={sideOffset} alignOffset={alignOffset} - className={cn('z-1002 outline-hidden', className)} + className={cn('z-50 outline-hidden', className)} {...positionerProps} > { expect(screen.container).not.toContainElement(dialog) await expect.element(dialog).toHaveTextContent('Workspace controls') await expect.element(screen.getByText('Configure the current workspace.')).toBeInTheDocument() - await expect.element(screen.getByTestId('drawer-backdrop')).toHaveClass('z-1002') + await expect.element(screen.getByTestId('drawer-backdrop')).toHaveClass('z-50') asHTMLElement(screen.getByRole('button', { name: 'Close drawer' }).element()).click() diff --git a/packages/dify-ui/src/drawer/index.tsx b/packages/dify-ui/src/drawer/index.tsx index c63bc8174e..a2ad6dcdaf 100644 --- a/packages/dify-ui/src/drawer/index.tsx +++ b/packages/dify-ui/src/drawer/index.tsx @@ -32,7 +32,7 @@ export function DrawerBackdrop({ return ( ) @@ -60,7 +60,7 @@ export function DrawerPopup({ return ( element as HTMLElement + +describe('Tabs wrappers', () => { + it('renders Base UI tabs with accessible roles', async () => { + const screen = await render( + + + JavaScript + Python + + JS panel + Python panel + , + ) + + await expect.element(screen.getByRole('tablist')).toBeInTheDocument() + await expect.element(screen.getByRole('tab', { name: 'JavaScript' })).toHaveAttribute('aria-selected', 'true') + await expect.element(screen.getByRole('tab', { name: 'Python' })).toHaveAttribute('aria-selected', 'false') + await expect.element(screen.getByText('JS panel')).toBeInTheDocument() + }) + + it('keeps tabs styling minimal by default', async () => { + const screen = await render( + + + First + Second + + , + ) + + await expect.element(screen.getByRole('tablist')).toHaveClass( + 'flex', + ) + await expect.element(screen.getByRole('tab', { name: 'First' })).toHaveClass( + 'touch-manipulation', + 'focus-visible:outline-hidden', + ) + }) + + it('calls onValueChange while leaving controlled value to the caller', async () => { + const onValueChange = vi.fn() + const screen = await render( + + + JavaScript + Python + + , + ) + + asHTMLElement(screen.getByRole('tab', { name: 'Python' }).element()).click() + + expect(onValueChange).toHaveBeenCalledWith('py', expect.anything()) + await expect.element(screen.getByRole('tab', { name: 'JavaScript' })).toHaveAttribute('aria-selected', 'true') + }) + + it('forwards className to composable parts', async () => { + const screen = await render( + + + First + + Panel + , + ) + + await expect.element(screen.getByRole('tablist')).toHaveClass('custom-list') + await expect.element(screen.getByRole('tab', { name: 'First' })).toHaveClass('custom-tab') + expect(screen.getByText('Panel').element()).toHaveClass('custom-panel') + }) +}) diff --git a/packages/dify-ui/src/tabs/index.stories.tsx b/packages/dify-ui/src/tabs/index.stories.tsx new file mode 100644 index 0000000000..dd1e79a1ce --- /dev/null +++ b/packages/dify-ui/src/tabs/index.stories.tsx @@ -0,0 +1,51 @@ +import type { Meta, StoryObj } from '@storybook/react-vite' +import { + Tabs, + TabsList, + TabsPanel, + TabsTab, +} from '.' + +const meta = { + title: 'Base/UI/Tabs', + component: Tabs, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Composable tabs built on Base UI. Use this when a tab controls a corresponding tab panel.', + }, + }, + }, + tags: ['autodocs'], +} satisfies Meta + +export default meta +type Story = StoryObj + +export const Basic: Story = { + render: () => ( + + + + Overview + + + Activity + + + + Overview panel + + + Activity panel + + + ), +} diff --git a/packages/dify-ui/src/tabs/index.tsx b/packages/dify-ui/src/tabs/index.tsx new file mode 100644 index 0000000000..ddc5891b89 --- /dev/null +++ b/packages/dify-ui/src/tabs/index.tsx @@ -0,0 +1,59 @@ +'use client' + +import type { Tabs as BaseTabsNS } from '@base-ui/react/tabs' +import { Tabs as BaseTabs } from '@base-ui/react/tabs' +import { cn } from '../cn' + +export type TabsProps = BaseTabsNS.Root.Props + +export const Tabs = BaseTabs.Root + +export type TabsListProps = Omit & { + className?: string +} + +export function TabsList({ + className, + ...props +}: TabsListProps) { + return ( + + ) +} + +export type TabsTabProps = Omit & { + className?: string +} + +export function TabsTab({ + className, + ...props +}: TabsTabProps) { + return ( + + ) +} + +export type TabsPanelProps = Omit & { + className?: string +} + +export function TabsPanel({ + className, + ...props +}: TabsPanelProps) { + return ( + + ) +} + +export const TabsIndicator = BaseTabs.Indicator diff --git a/packages/dify-ui/src/toast/__tests__/index.spec.tsx b/packages/dify-ui/src/toast/__tests__/index.spec.tsx index e02f6828ac..68ba746f4f 100644 --- a/packages/dify-ui/src/toast/__tests__/index.spec.tsx +++ b/packages/dify-ui/src/toast/__tests__/index.spec.tsx @@ -39,7 +39,7 @@ describe('@langgenius/dify-ui/toast', () => { await expect.element(screen.getByText('Saved')).toBeInTheDocument() await expect.element(screen.getByText('Your changes are available now.')).toBeInTheDocument() await expect.element(screen.getByRole('region', { name: 'Notifications' })).toHaveAttribute('aria-live', 'polite') - await expect.element(screen.getByRole('region', { name: 'Notifications' })).toHaveClass('z-1003') + await expect.element(screen.getByRole('region', { name: 'Notifications' })).toHaveClass('z-60') expect(screen.getByRole('region', { name: 'Notifications' }).element().firstElementChild).toHaveClass('top-4') expect(screen.getByRole('dialog').element()).not.toHaveClass('outline-hidden') expect(document.body.querySelector('[aria-hidden="true"].i-ri-checkbox-circle-fill')).toBeInTheDocument() diff --git a/packages/dify-ui/src/toast/index.tsx b/packages/dify-ui/src/toast/index.tsx index a479621563..7d4e867faf 100644 --- a/packages/dify-ui/src/toast/index.tsx +++ b/packages/dify-ui/src/toast/index.tsx @@ -222,7 +222,7 @@ function ToastViewport() {
element as HTMLElement + +describe('ToggleGroup wrappers', () => { + it('renders a segmented control with Base UI pressed state', async () => { + const screen = await render( + + One + Two + , + ) + + await expect.element(screen.getByRole('group')).toHaveClass( + 'bg-components-segmented-control-bg-normal', + 'p-0.5', + 'rounded-[10px]', + ) + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveAttribute('aria-pressed', 'true') + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveClass( + 'data-pressed:bg-components-segmented-control-item-active-bg', + 'data-pressed:text-text-accent-light-mode-only', + ) + }) + + it('uses single selection by default', async () => { + const screen = await render( + + One + Two + , + ) + + asHTMLElement(screen.getByRole('button', { name: 'Two' }).element()).click() + + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveAttribute('aria-pressed', 'false') + await expect.element(screen.getByRole('button', { name: 'Two' })).toHaveAttribute('aria-pressed', 'true') + }) + + it('calls onValueChange while leaving controlled value to the caller', async () => { + const onValueChange = vi.fn() + const screen = await render( + + One + Two + , + ) + + asHTMLElement(screen.getByRole('button', { name: 'Two' }).element()).click() + + expect(onValueChange).toHaveBeenCalledWith(['two'], expect.anything()) + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveAttribute('aria-pressed', 'true') + }) + + it('forwards disabled and className to composable parts', async () => { + const screen = await render( + + One + + Two + , + ) + + await expect.element(screen.getByRole('group')).toHaveClass('custom-group') + await expect.element(screen.getByRole('button', { name: 'One' })).toHaveClass('custom-item') + await expect.element(screen.getByRole('button', { name: 'Two' })).toBeDisabled() + await expect.element(screen.getByTestId('divider')).toHaveClass('custom-divider') + }) +}) diff --git a/packages/dify-ui/src/toggle-group/index.stories.tsx b/packages/dify-ui/src/toggle-group/index.stories.tsx new file mode 100644 index 0000000000..960957b7ab --- /dev/null +++ b/packages/dify-ui/src/toggle-group/index.stories.tsx @@ -0,0 +1,177 @@ +import type { Meta, StoryObj } from '@storybook/react-vite' +import type { ReactNode } from 'react' +import { + ToggleGroup, + ToggleGroupDivider, + ToggleGroupItem, +} from '.' + +const meta = { + title: 'Base/UI/ToggleGroup', + component: ToggleGroup, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Segmented control built on Base UI ToggleGroup and Toggle. Use this for mode, filter, and view selection that does not need tabpanel semantics.', + }, + }, + }, + tags: ['autodocs'], +} satisfies Meta + +export default meta +type Story = StoryObj + +type SegmentedControlProps = { + defaultValue: string + values: string[] + iconOnly?: boolean + noPadding?: boolean +} + +const Icon = () => ( +