diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 21c46d75bc9..e0a1d4dccae 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -102,11 +102,11 @@ describe('ComponentName', () => { }) }) - // Props tests (REQUIRED) + // Props tests (REQUIRED when props change observable behavior) describe('Props', () => { - it('should apply custom className', () => { - render() - expect(screen.getByRole('button')).toHaveClass('custom') + it('should disable the action when disabled', () => { + render() + expect(screen.getByRole('button')).toBeDisabled() }) }) @@ -220,6 +220,7 @@ Every test should clearly separate: ### 2. Black-Box Testing - Test observable behavior, not implementation details +- Test product contracts, not cosmetic implementation. Do not add or expand unit tests only to lock pure style classes, spacing, colors, backgrounds, or layout micro-adjustments. Cover visual-only fixes with browser/manual verification, screenshots, or E2E/visual checks when risk justifies it. Add unit tests only when the change affects user-observable behavior, accessibility semantics, state, data flow, routing, or a stable component API contract. - Use semantic queries (`getByRole` with accessible `name`, `getByLabelText`, `getByPlaceholderText`, `getByText`, and scoped `within(...)`) - Treat `getByTestId` as a last resort. If a control cannot be found by role/name, label, landmark, or dialog scope, fix the component accessibility first instead of adding or relying on `data-testid`. - Remove production `data-testid` attributes when semantic selectors can cover the behavior. Keep them only for non-visual mocked boundaries, editor/browser shims such as Monaco, canvas/chart output, or third-party widgets with no accessible DOM in the test environment. @@ -273,7 +274,7 @@ it('should disable input when isReadOnly is true') ### Always Required (All Components) 1. **Rendering**: Component renders without crashing -1. **Props**: Required props, optional props, default values +1. **Props**: Required props, optional props, default values that change observable behavior. Do not test pass-through styling props such as `className` unless they are an explicit, stable component API whose absence would break a real integration contract. 1. **Edge Cases**: null, undefined, empty values, boundary conditions ### Conditional (When Present) diff --git a/.agents/skills/how-to-write-component/SKILL.md b/.agents/skills/how-to-write-component/SKILL.md index 8a480c8fd09..4a5a73088ec 100644 --- a/.agents/skills/how-to-write-component/SKILL.md +++ b/.agents/skills/how-to-write-component/SKILL.md @@ -1,65 +1,59 @@ --- name: how-to-write-component -description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around abstraction choices, props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling. +description: Use when writing, refactoring, or reviewing React/TypeScript components in Dify web, especially decisions about component ownership, props/types, URL/query state, Jotai state, async state, generated API contracts, queries/mutations, overlays, effects, navigation, performance, and empty states. --- # How To Write A Component -Use this as the decision guide for React/TypeScript component structure. Existing code is reference material, not automatic precedent; when it conflicts with these rules, adapt the approach instead of reproducing the violation. +Use this as the component decision guide for Dify web. Existing code is reference material, not automatic precedent; if touched code violates these rules, adapt it and fix equivalent patterns in the same feature branch. + +## First Decisions + +| Question | Default | Promote or extract only when | +| --- | --- | --- | +| Where should code live? | Keep it local to the feature workflow, route, or owner. | Multiple verticals need the same stable primitive. | +| Who owns state, data, and handlers? | The lowest component that uses them. | A parent coordinates shared loading, errors, empty UI, selection, submission, navigation, or one consistent snapshot. | +| Should this become Jotai state? | Keep synchronous UI/form state in component or DOM state. | Siblings need one source of truth, the value drives atoms, or scoped workflow state must survive hidden/unmounted steps. | +| Should URL state enter Jotai? | Let Next.js route params and `nuqs` own URL state and updates. | Query atoms or shared derived atoms need a read-only bridge hydrated at the route/surface boundary. | +| Should this query/mutation become an atom? | Use TanStack Query hooks at the lowest owner. | It reads atom state, feeds derived atoms, or participates in shared Jotai workflow orchestration. | +| Should this be a helper/wrapper? | Prefer direct readable code at the use site. | The name captures a stable domain rule or the wrapper owns real behavior, validation, state, error handling, or semantics. | +| Is an Effect needed? | No. Derive during render or handle the user action in the event handler. | It synchronizes with an external system such as browser APIs, subscriptions, timers, analytics, or imperative DOM/non-React widgets. | ## Core Defaults -- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit. -- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them. -- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature. -- Prefer local code and purpose-named helpers over catch-all utility modules; do not group workflow-specific defaults, validation, payload shaping, or metadata merging in a generic utils file just because they share a DTO. -- Keep source/default selection, validation, and payload shaping close to the workflow that owns the behavior. Do not extract a shared helper just because two flows read the same DTO when their priority order, fallback behavior, or submit semantics differ. -- Prefer direct, readable conditionals at the use site for small branch-specific decisions, especially form source selection and request payload assembly. Extract only when the helper name captures a stable domain rule and removes repeated complexity without hiding flow-specific behavior. -- When fixing an invalid pattern, scan the touched feature or branch for equivalent patterns and fix them together. -- Follow Dify's CSS-first Tailwind v4 contract from `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md`. Prefer design-system tokens, utilities, and radius mappings over generic Tailwind guidance. +- Search before adding UI, hooks, helpers, query utilities, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit. +- Follow Dify's CSS-first Tailwind v4 contract from `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md`. Prefer design-system tokens, utilities, and radius mappings over generic Tailwind choices. +- Group feature code by workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them. +- Keep source/default selection, validation, dirty checks, and payload shaping close to the workflow that owns submit behavior. Do not hide flow-specific priority order, fallback behavior, or submit semantics in generic utilities. +- Prefer direct conditionals for small branch-specific decisions, especially form source selection and request payload assembly. +- Loading states for page sections, cards, lists, tables, forms, and drawers should be skeletons scoped to the content being loaded. Use spinners only for small inline busy indicators. -## Feature Workflow Layout +## Layout And Ownership -- State-heavy wizards, drawers, modals, and secondary workflows work best as a small feature surface with route/entry files, a single feature-local state file, and feature-local UI. -- Keep `ui/` shallow with owner files that map to the workflow's real composition boundaries and major visual regions. -- Owner files contain the section components, field components, skeletons, and one-off helper components that belong to their visual region. -- Folders represent groups of related files with a shared owner and a stable reason to change together. -- The entry file handles route integration, provider wiring, close behavior, and feature surface mounting. The composition owner handles high-level workflow branching, and the closest visual owner handles section branching. +- State-heavy wizards, drawers, modals, and secondary workflows can be a small feature surface: an entry file, one feature-local state file when Jotai is actually needed, and shallow `ui/` owners that match real visual regions. +- The entry file handles route integration, provider wiring, close behavior, and surface mounting. The composition owner handles high-level workflow branching. The closest visual owner handles section branching. +- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data; TanStack Query deduplicates and shares cache. +- Pass stable domain identity across boundaries. Do not forward derived presentation state when the receiver can derive it from its own data source. +- A component that owns a visual surface should also own data access, loading, empty, and error states for content rendered inside it unless a parent truly coordinates that state. +- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in Query and API flow. +- Do not replace prop drilling with one large view-model hook threaded through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it. +- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child, menu, or row own the action. -## Ownership +## Feature-Scoped Jotai -- Put local state, queries, mutations, handlers, and derived UI data in the lowest component that uses them. Extract a purpose-built owner component only when the logic has no natural home. -- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data. Do not hoist a query only because it is duplicated; TanStack Query handles deduplication and cache sharing. -- Hoist state, queries, or callbacks to a parent only when the parent consumes the data, coordinates shared loading/error/empty UI, needs one consistent snapshot, or owns a workflow spanning children. -- Pass stable domain identity across boundaries; avoid forwarding derived presentation state when the receiver can derive it from its own data source. A component that owns a visual surface should also own the data access, loading, empty, and error states for content rendered inside it unless a parent truly coordinates that state. -- Loading states for visual surfaces should use skeleton placeholders scoped to the content that is actually loading, with shape, density, and dimensions close to the final UI. Avoid generic loading text or centered spinners for page sections, cards, lists, tables, forms, and drawers; reserve spinners for small inline busy indicators such as an in-progress status icon. -- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow. -- Do not replace prop drilling with one top-level hook that returns a large view model and then thread that object through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it, or use feature-scoped Jotai atoms for simple shared form/UI state when siblings need the same source of truth. -- When using feature-scoped Jotai state for a form, drawer, or other secondary surface, scope the store to that surface instance when stale cross-instance state is possible. Initialize stable config at the owning boundary, then let descendants read only the atoms or purpose-named hooks they actually need. -- For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. The lowest-owner rule still applies to independent visual surfaces that do not participate in shared state. -- For repeated row/menu action surfaces that need reset, hydrate the stable identity at the surface entry and scope only the primitives that truly need per-instance reset, such as open flags, drafts, or selected local options. -- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action. -- Prefer uncontrolled DOM state and CSS variables before adding controlled props. - -## Feature-Scoped Jotai State - -- A module's feature-local state lives in one state file for Jotai-backed features: primitive atoms, query atoms, derived atoms, write-only action atoms, mutation atoms, submission orchestration, provider exports, and optional scope configuration. -- Keep state local when one component owns it, even inside Jotai-backed features. Dialog open flags, menu/popover visibility, confirmation visibility, form/input drafts, row-local pending flags, and in-flight refs usually belong in component state. -- Promote UI state to an atom only when siblings need the same source of truth, the value drives a query or mutation atom, a parent workflow coordinates the state, or the state intentionally persists across hidden or unmounted descendants within a scoped surface. -- Reflect atom-backed surface-wide locks or invariants in every affected trigger. If only one row, menu, or dialog should be disabled, keep the pending or lock state local to that row, menu, or dialog. -- Atom order in the state file follows the dependency graph: types/constants, editable primitives, query atoms, query-data derived atoms, readiness/business derived atoms, write actions, mutation atoms, submission orchestration, provider exports. -- Derived atom names read as business facts. Write atom names read as user or workflow commands. -- UI components read and write the exact atom they use with `useAtomValue` or `useSetAtom`. Repeated workflow semantics live in named derived atoms or write atoms. -- Non-query derived atoms return a narrow value with a clear domain name; avoid pass-through aliases or bundling unrelated UI facts. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract. -- Write-only atoms own synchronous state transitions that update multiple primitives, reset dependent state, or advance the workflow. Async work with loading, error, caching, retry, or stale-result concerns should be modeled as query or mutation atoms, with write atoms only changing the inputs that drive them. -- Avoid feature hooks that aggregate form values, query results, derived state, and commands for sibling components. Prefer named derived atoms and write atoms so UI components read the exact shared fact or command they need. -- When a form library owns validation, keep submit orchestration in feature state when post-submit result or error state is shared by the surface. Avoid duplicating validation gates or request shaping in UI hooks. -- `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state when atoms are the feature's local state surface. -- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query and mutation atoms keep shared cache behavior through the shared QueryClient. -- Do not put `atomWithQuery`, `atomWithInfiniteQuery`, `atomWithMutation`, or broad derived orchestration atoms in a `ScopeProvider` just to reset a surface. Scoped derived atoms implicitly scope their dependencies, which can duplicate query client access and break shared invalidation. Leave query/mutation atoms unscoped; let them read scoped primitive inputs. -- Scope providers should list resettable primitive atoms and explicit hydration tuples. If a derived atom must be scoped, confirm that every dependency it implicitly scopes is meant to be private to that surface. -- Keep independent dialog lifecycles separate. Avoid a single discriminated "current action dialog" atom when edit, delete, and other dialogs have their own open state, loading guard, or reset behavior. -- Route-derived stable identities that do not need instance reset or scoped isolation can be hydrated at the route or layout boundary into a feature route atom. Use scoped atoms only when stale cross-instance state or per-surface reset semantics are needed. +- A Jotai-backed feature has one feature-local state file for shared primitive atoms, query atoms, derived atoms, write-only actions, mutation atoms, submission orchestration, provider exports, and optional scope configuration. +- Keep component-owned synchronous UI state local even inside Jotai features: dialog open flags, menus/popovers, confirmations, field drafts, and selected local options usually belong in component state. +- Use uncontrolled `@langgenius/dify-ui/form` and `@langgenius/dify-ui/field` controls for edit/create forms whose fields are read only at submit time. Initialize query-backed defaults with `defaultValue` and keyed remounts. +- Promote form state to atoms only when another component must react to in-progress values, a draft must survive unmount/remount in the scoped workflow, or multiple steps share the same editable draft before submit. +- Treat `useParams`, route args, and `nuqs` query state as framework-owned state. When atom logic needs those values, hydrate primitive atoms at the route or surface boundary, such as with `useHydrateAtoms(..., { dangerouslyForceHydrate: true })`; keep URL updates in the route/query-state APIs instead of write atoms. +- For async work tied to atom state, use `atomWithQuery` or `atomWithMutation`; write atoms should update only the inputs that drive those atoms. This applies to pure frontend async work as well as network requests, so do not hand-roll loading/error/in-flight state with `useState` or `useRef` for atom-orchestrated async behavior. For component-owned remote work, use `useQuery` or `useMutation` directly. +- Row-local async state belongs to the row owner unless it participates in a shared Jotai workflow or needs atom-scoped reset semantics. +- Leave query and mutation atoms unscoped so they keep shared QueryClient cache and invalidation behavior. Scope resettable primitives and explicit hydration tuples; scope a derived atom only when every dependency should be private to that surface. +- For scoped primitives that are always hydrated by `ScopeProvider`, prefer `atomWithLazy(() => { throw new Error(...) })` when consumers should see a non-null type. +- Order state files by dependency graph: types/constants, primitives, query atoms, query-data derived atoms, business/readiness derived atoms, write actions, mutation atoms, submission orchestration, provider exports. +- Name derived atoms as business facts and write atoms as user or workflow commands. Components should read or write the exact atom they need with `useAtomValue` or `useSetAtom`. +- Menu/dialog `open` state usually stays local, but a scoped atom is acceptable when a composed menu plus secondary surface would otherwise pass confusing `open`/`onClose` props through unrelated layers. Scope that primitive with the surface instance so reset behavior stays local. +- Keep independent dialog lifecycles separate. Avoid one discriminated "current action dialog" atom when dialogs have separate open state, loading guards, or reset behavior. ## Components, Props, And Types @@ -68,80 +62,56 @@ Use this as the decision guide for React/TypeScript component structure. Existin - Prefer named exports. Use default exports only where the framework requires them, such as Next.js route files. - Type simple one-off props inline. Use a named `Props` type only when reused, exported, complex, or clearer. - Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers and one-off UI extensions beside the component that needs them. -- Do not create type aliases that only rename another type. Use an alias only when it encodes a real UI concept, refinement, or reusable local contract. -- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially persistent IDs and route params. Normalize framework or route params at the boundary. -- Keep fallback and invariant checks at the lowest component that already handles that state; avoid defensive fallbacks that mask impossible states. -- Do not extract fallback helpers whose only behavior is hiding missing display data. The component that renders the surface owns the empty, disabled, hidden, or placeholder state. +- Do not create type aliases that only rename another type. Use aliases only for real UI concepts, refinements, or reusable local contracts. +- Name values by their domain role and backend API contract, especially persistent IDs and route params. Normalize framework or route params at the boundary. +- Put fallback and invariant checks in the lowest component that already handles that state. Do not extract helpers whose only behavior is hiding missing display data. -## Generated API Contracts +## Generated API And Nullable Data - Treat generated contracts as authoritative at API, query, mutation, cache, and service boundaries. For enterprise APIs, use `packages/contracts/generated/enterprise/*`. -- Do not hand-write local request/response/reply/page/cache-data types that mirror generated DTOs. Import or infer the generated type. -- Do not widen generated fields or enums for compatibility. Normalize legacy input at the boundary, then return the generated field type. -- Do not repair generated or API-returned contract fields in components unless the API contract or product requirement says they need normalization. Treat enums, statuses, and presence flags as exact contract values. -- Use generated enum objects and union types directly in props, comparisons, status logic, and i18n keys. Do not add local enum constants or parallel frontend enum/status layers unless they model real product state not represented by the API. Presentation-only tone maps should be keyed by the generated enum. -- Normalize or coerce only at a real boundary, such as user-entered forms, search, URL/query params, file names, DOM IDs, or legacy adapters. Preserve user-entered values when whitespace or formatting can be meaningful. -- Do not coerce nullable or optional API strings to `''` in query, derived model, or payload-building code. Keep `undefined` or `null` until the final boundary that requires a string. -- Do not use `value || undefined` for mutation payload fields where an empty string means "clear this value". Trim or normalize at the form boundary, then preserve `''` when the API contract treats it as an intentional update. -- Local UI models are fine for presentation, form state, select options, or guarded required-field refinements. Name them as UI concepts, not generated DTO mirrors. -- Required-value refinements are allowed only after same-branch filtering or early return. Prefer nullable-tolerant props for render-only data. -- When a component needs a stricter shape than a generated DTO, refine once at the API/query-to-UI boundary into a purpose-named UI type instead of hiding missing fields with generic fallback or coercion helpers. - -## Nullable API Data - -- Prefer nullable-tolerant call boundaries. Pass API-returned types through for render-only rows, and let the component render fallback, disabled state, or nothing. -- Narrow only where a real value is required, such as mutation params, route hrefs, select values, or query input. Build that target model with `flatMap`, a local loop, or an early return so the required value is captured in the same branch. -- If design says a field is the display value, use that field. Only the final component should decide whether a nullable display value renders a placeholder, hides content, or disables an action. -- Do not wrap required arrays or fields in null-fallback helpers. Use empty collection fallbacks only for not-yet-loaded query data or genuinely nullable collections at the owning render boundary. -- Do not drop rows only to satisfy props or React keys; use a stable fallback key when possible. -- Use conditional spreads or explicit pushes for conditional array items instead of `undefined` placeholders followed by a narrowing filter. -- Avoid truthiness type guards, `filter(Boolean)`, `filter(item => item.id)`, and `!` after those filters. -- Use type guards only for meaningful domain or runtime validation, such as enum membership, object shape, or a reusable business invariant. +- Do not hand-write DTO mirrors, widen generated fields/enums, or add parallel frontend enum/status layers unless they model product state not represented by the API. +- Use generated enum objects and union types directly in props, comparisons, status logic, and i18n keys. Presentation-only tone maps should be keyed by generated enums. +- Normalize or coerce only at real boundaries: user-entered forms, search, URL/query params, file names, DOM IDs, or legacy adapters. +- Do not coerce nullable or optional API strings to `''` in query, derived model, or payload-building code. Keep `null` or `undefined` until the final boundary requiring a string. +- Do not use `value || undefined` for mutation fields where `''` means "clear this value". Trim or normalize at the form boundary, then preserve intentional empty strings. +- Prefer nullable-tolerant render props for API-returned rows. Narrow only where a real value is required, such as mutation params, route hrefs, select values, query input, or required React keys. +- Build required values in the same branch that proves them, using `flatMap`, a local loop, or an early return. Avoid truthiness guards, `filter(Boolean)`, `filter(item => item.id)`, and `!` after filters. +- Use conditional spreads or explicit pushes for conditional array items instead of `undefined` placeholders followed by narrowing filters. +- Empty collection fallbacks are for not-yet-loaded query data or genuinely nullable collections at the owning render boundary, not for hiding required API fields. ## Queries And Mutations -- Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape. -- Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`. -- In `atomWithQuery` and `atomWithInfiniteQuery`, return generated `queryOptions()` or `infiniteOptions()` directly. Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into that call instead of spreading generated options into a hand-built object. -- In `atomWithMutation`, return generated `mutationOptions()` directly when using generated clients. Put request shaping and submit orchestration in write atoms; do not rebuild mutation option objects just to pass through the generated mutation function. -- For custom query functions that do not come from generated clients, wrap the options object with TanStack `queryOptions(...)` so query atoms still return a query options contract. -- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it. -- Keep feature hooks for real orchestration, workflow state, or shared domain behavior. -- For TanStack cache data, use generated or query-derived types; do not create local wrappers for `getQueryData` or `getQueriesData`. -- For generated oRPC `queryOptions()` / `infiniteOptions()`, keep returning the generated options directly. When required input is missing, use a whole-input branch such as `input: condition ? validInput : skipToken` together with `enabled: Boolean(condition)` so no request runs and no fake payload is built. -- Do not put `skipToken` inside a nested placeholder payload, such as `{ params: { appInstanceId: skipToken } }`. Do not create hand-written "missing queryOptions" objects or coerce required IDs to `''`. -- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows. -- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules. -- Component or atom mutation callbacks can handle local UI feedback such as toasts, closing dialogs, or navigation. They should not replace shared invalidation or add local cache patches for shared server state. +- Keep `web/contract/*` as the API shape source of truth and follow the `{ params, query?, body? }` input shape. +- Consume generated queries with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`. +- Consume owner-local mutations with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))` when pending/error state is not consumed by feature atoms. +- In `atomWithQuery`, `atomWithInfiniteQuery`, and `atomWithMutation`, return generated `queryOptions()`, `infiniteOptions()`, or `mutationOptions()` directly. Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into the generated call instead of spreading options into a hand-built object. +- For generated oRPC options with missing required input, branch the whole input with `input: condition ? validInput : skipToken` and `enabled: Boolean(condition)`. Never place `skipToken` inside a nested placeholder payload or coerce required IDs to `''`. +- When prefetch and render use the same request, extract local query options or a query-options atom so `prefetchQuery` and `useQuery`/`atomWithQuery` share the exact options. +- For custom query or mutation functions, wrap options with TanStack `queryOptions(...)` or `mutationOptions(...)`. +- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename generated options. Keep feature hooks for real orchestration, workflow state, or shared domain behavior. +- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`. Component or atom callbacks may handle local toasts, closing dialogs, and navigation, but should not replace shared invalidation or patch shared server state locally. +- For overlays that may open heavier secondary content, prefetch from the trigger/menu open event with `queryClient.prefetchQuery(queryOptions)` when `onOpenChange` is available. Do not mount hidden subscribers just to warm cache. - Do not use deprecated `useInvalid` or `useReset`. - Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`. -## Component Boundaries +## Boundaries And Overlays -- Use the first level below a page or tab to organize independent page sections when it adds real structure. This layer is layout/semantic first, not automatically the data owner. -- Treat component names, semantic roles, and user- or design-marked visual regions as boundary constraints. Do not expand a child component's responsibility just because its data is useful nearby; keep adjacent UI as a sibling owner or introduce a correctly named broader owner. -- Split deeper components by the data and state each layer actually needs. Each component should access only necessary data, and ownership should stay at the lowest consumer. +- Use the first level below a page or tab to organize independent page sections when it adds structure. This layer is layout/semantic first, not automatically the data owner. +- Treat component names, semantic roles, and user- or design-marked visual regions as boundary constraints. Keep adjacent UI as a sibling owner or introduce a correctly named broader owner. - Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary. -- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow. -- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment. -- When a dialog, dropdown, or popover component already accepts controlled `open` state, mount the surface unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset for reset behavior instead of `{open && }` wrappers. -- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible. -- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, children-as-pass-through composition, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook, forwards props, or passes trigger/content through to one child, move the logic into that child or make the wrapper own a real surface. +- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component when hidden content would obscure the parent. +- Preserve composability by separating behavior ownership from placement ownership: an action can own trigger/open/menu content while the caller owns slots, offsets, and alignment. +- When a dialog, dropdown, or popover accepts controlled `open`, mount it unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset instead of `{open && }` wrappers. +- When opening a dialog from a menu item, keep the menu and dialog as sibling surfaces. Let the menu command open the dialog, and mount the dialog outside menu popup content. +- For dialogs and alert dialogs, keep the root responsible for `open` wiring and put query/mutation hooks inside the content component when work should mount only after the overlay opens. +- Prefer uncontrolled overlay roots when the library can own open state. Use `onOpenChange` for side effects and CSS/data selectors for open-state styling. +- Avoid wrapper DOM unless it provides layout, semantics, accessibility, state ownership, or library integration. Avoid shallow wrappers, hook-to-props adapters, layout-only render props, children pass-through wrappers, and prop renaming unless they add real behavior or a real boundary. -## You Might Not Need An Effect - -- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration. -- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive. -- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known. -- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render. -- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary. -- For forms initialized from query data, prefer keyed remounts or surface-entry hydration of form/field atoms over an Effect that copies query data into form state. -- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components. -- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow. - -## Navigation And Performance +## Effects, Navigation, And Performance +- Use Effects only to synchronize with external systems. Do not use Effects to transform props/state for rendering, handle user actions, copy state, reset state from props, or fetch data. +- For forms initialized from query data, prefer keyed remounts or surface-entry atom hydration over Effects that copy query data into form state. +- Prefer framework data APIs or TanStack Query for data fetching. - Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission. -- Before reaching for `memo`, first try moving changing state down to the smallest component that actually uses it so unrelated sibling trees stay untouched. -- If changing state must wrap other content, lift the unchanged content up and pass it as `children` so the stateful wrapper can update without React visiting that subtree. +- Before using `memo`, move changing state down to the smallest component that uses it. If state must wrap stable content, lift the stable content up and pass it as `children`. - Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason. diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index e4c686adb31..623703c6d35 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -105,6 +105,14 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: vp run knip + - name: Web dead code check production + if: steps.changed-files.outputs.any_changed == 'true' + run: vp run knip:production + + - name: Web production unused declarations check + if: steps.changed-files.outputs.any_changed == 'true' + run: vp run knip:production-unused-check + ts-common-style: name: TS Common runs-on: depot-ubuntu-24.04 diff --git a/README.md b/README.md index b6cbb0e1261..35bbb897f1a 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly ## Using Dify - **Cloud
** - We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. + We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. If you run into issues with Dify Cloud, [contact our Cloud support team](mailto:cloud@dify.ai?subject=%5BGitHub%5DDify%20Cloud%20Support). - **Self-hosting Dify Community Edition
** Quickly get Dify running in your environment with this [starter guide](#quick-start). diff --git a/api/.env.example b/api/.env.example index 3aa107130f9..48d8707d1ad 100644 --- a/api/.env.example +++ b/api/.env.example @@ -36,6 +36,9 @@ FILES_ACCESS_TIMEOUT=300 # Collaboration mode toggle ENABLE_COLLABORATION_MODE=true +# Learn app feature toggle +ENABLE_LEARN_APP=true + # Access token expiration time in minutes ACCESS_TOKEN_EXPIRE_MINUTES=60 diff --git a/api/commands/account.py b/api/commands/account.py index 0d99ce7a0fa..dfd57d43142 100644 --- a/api/commands/account.py +++ b/api/commands/account.py @@ -25,7 +25,7 @@ def reset_password(email, new_password, password_confirm): return normalized_email = email.strip().lower() - account = AccountService.get_account_by_email_with_case_fallback(email.strip()) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip()) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) @@ -67,7 +67,7 @@ def reset_email(email, new_email, email_confirm): return normalized_new_email = new_email.strip().lower() - account = AccountService.get_account_by_email_with_case_fallback(email.strip()) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip()) if not account: click.echo(click.style(f"Account not found for email: {email}", fg="red")) @@ -133,8 +133,9 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No password=new_password, language=language, create_workspace_required=False, + session=db.session, ) - TenantService.create_owner_tenant_if_not_exist(account, name) + TenantService.create_owner_tenant_if_not_exist(account, name, session=db.session) click.echo( click.style( diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index dc8c840da9c..f664274ba75 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1073,6 +1073,12 @@ class MailConfig(BaseSettings): default=None, ) + +class HomepageConfig(BaseSettings): + """ + Configuration for homepage feature toggles exposed through system features. + """ + ENABLE_TRIAL_APP: bool = Field( description="Enable trial app", default=False, @@ -1083,6 +1089,11 @@ class MailConfig(BaseSettings): default=False, ) + ENABLE_LEARN_APP: bool = Field( + description="Enable Learn App", + default=True, + ) + class RagEtlConfig(BaseSettings): """ @@ -1489,6 +1500,7 @@ class FeatureConfig( EndpointConfig, FileAccessConfig, FileUploadConfig, + HomepageConfig, HttpConfig, InnerAPIConfig, IndexingConfig, diff --git a/api/controllers/console/agent/composer.py b/api/controllers/console/agent/composer.py index 975586c635c..b54cf4b6daf 100644 --- a/api/controllers/console/agent/composer.py +++ b/api/controllers/console/agent/composer.py @@ -28,9 +28,9 @@ from libs.login import login_required from models.model import App, AppMode from services.agent.composer_service import AgentComposerService from services.agent.composer_validator import ComposerConfigValidator -from services.entities.agent_entities import ComposerSavePayload +from services.entities.agent_entities import ComposerSavePayload, WorkflowComposerCopyFromRosterPayload -register_schema_models(console_ns, ComposerSavePayload) +register_schema_models(console_ns, ComposerSavePayload, WorkflowComposerCopyFromRosterPayload) register_response_schema_models( console_ns, AgentAppComposerResponse, @@ -91,6 +91,38 @@ class WorkflowAgentComposerApi(Resource): ) +@console_ns.route("/apps//workflows/draft/nodes//agent-composer/copy-from-roster") +class WorkflowAgentComposerCopyFromRosterApi(Resource): + @console_ns.expect(console_ns.models[WorkflowComposerCopyFromRosterPayload.__name__]) + @console_ns.response( + 200, + "Workflow roster agent copied to inline agent", + console_ns.models[WorkflowAgentComposerResponse.__name__], + ) + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT) + @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str): + payload = WorkflowComposerCopyFromRosterPayload.model_validate(console_ns.payload or {}) + return dump_response( + WorkflowAgentComposerResponse, + AgentComposerService.copy_workflow_composer_from_roster( + tenant_id=tenant_id, + app_id=app_model.id, + node_id=node_id, + account_id=account_id, + source_agent_id=payload.source_agent_id, + source_snapshot_id=payload.source_snapshot_id, + idempotency_key=payload.idempotency_key, + ), + ) + + @console_ns.route("/apps//workflows/draft/nodes//agent-composer/validate") class WorkflowAgentComposerValidateApi(Resource): @console_ns.expect(console_ns.models[ComposerSavePayload.__name__]) diff --git a/api/controllers/console/app/agent_app_feature.py b/api/controllers/console/app/agent_app_feature.py index d155dae6ac3..358e552beb0 100644 --- a/api/controllers/console/app/agent_app_feature.py +++ b/api/controllers/console/app/agent_app_feature.py @@ -91,7 +91,10 @@ class AgentAppFeatureConfigResource(Resource): args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {}) new_app_model_config = AgentAppFeatureConfigService.update_features( - app_model=app_model, account=current_user, config=args.model_dump(exclude_none=True), session=db.session + app_model=app_model, + account=current_user, + config=args.model_dump(exclude_none=True), + session=db.session, ) app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index edf3a98af8c..48fb4aedc63 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -19,6 +19,7 @@ from controllers.console.wraps import ( rbac_permission_required, setup_required, ) +from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.annotation_fields import ( Annotation, @@ -388,7 +389,9 @@ class AnnotationUpdateDeleteApi(Resource): update_args["answer"] = args.answer if args.question is not None: update_args["question"] = args.question - annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id)) + annotation = AppAnnotationService.update_app_annotation_directly( + update_args, str(app_id), str(annotation_id), db.session + ) return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json") @setup_required @@ -398,7 +401,7 @@ class AnnotationUpdateDeleteApi(Resource): @rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT) @console_ns.response(204, "Annotation deleted successfully") def delete(self, app_id: UUID, annotation_id: UUID): - AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id)) + AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id), db.session) return "", 204 diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 43b41903f60..b66c97c274c 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -30,6 +30,7 @@ from controllers.console.wraps import ( setup_required, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode @@ -142,6 +143,7 @@ class ChatMessageTextApi(Resource): response = AudioService.transcript_tts( app_model=app_model, + session=db.session, text=payload.text, voice=payload.voice, message_id=payload.message_id, diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 726bd94cd7e..195a41f2888 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -341,8 +341,8 @@ class MessageFeedbackExportApi(Resource): try: export_data = FeedbackService.export_feedbacks( - db.session(), - app_id=app_model.id, + app_model.id, + session=db.session(), from_source=args.from_source, rating=args.rating, has_comment=args.has_comment, diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py index a9bf85ed36c..c70f00dcfa1 100644 --- a/api/controllers/console/app/workflow_comment.py +++ b/api/controllers/console/app/workflow_comment.py @@ -17,6 +17,7 @@ from controllers.console.wraps import ( with_current_tenant_id, with_current_user, ) +from extensions.ext_database import db from fields.base import ResponseModel from fields.member_fields import AccountWithRole from libs.helper import build_avatar_url, dump_response, to_timestamp @@ -489,7 +490,7 @@ class WorkflowCommentMentionUsersApi(Resource): current_tenant = current_user.current_tenant # need the tenant object here if current_tenant is None: raise ValueError("current tenant is required") - members = TenantService.get_tenant_members(current_tenant) + members = TenantService.get_tenant_members(current_tenant, session=db.session) users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) response = WorkflowCommentMentionUsersPayload(users=users) return response.model_dump(mode="json"), 200 diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index c9142d85ede..b6045685b55 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -89,7 +89,9 @@ class ActivateCheckApi(Resource): workspaceId = args.workspace_id token = args.token - invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token) + invitation = RegisterService.get_invitation_with_case_fallback( + workspaceId, args.email, token, session=db.session + ) if invitation: data = invitation.get("data", {}) tenant = invitation.get("tenant", None) @@ -137,7 +139,9 @@ class ActivateApi(Resource): args = ActivatePayload.model_validate(console_ns.payload) normalized_request_email = args.email.lower() if args.email else None - invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token) + invitation = RegisterService.get_invitation_with_case_fallback( + args.workspace_id, args.email, args.token, session=db.session + ) if invitation is None: raise AlreadyActivateError() @@ -184,6 +188,6 @@ class ActivateApi(Resource): account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() - TenantService.switch_tenant(account, tenant.id) + TenantService.switch_tenant(account, tenant.id, session=db.session) return {"result": "success"} diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 1de206c73db..11fab84a831 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models +from extensions.ext_database import db from fields.base import ResponseModel from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService @@ -58,7 +59,7 @@ class ApiKeyAuthDataSource(Resource): @account_initialization_required @with_current_tenant_id def get(self, current_tenant_id: str): - data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id) + data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(db.session(), current_tenant_id) if data_source_api_key_bindings: return { "sources": [ @@ -92,7 +93,7 @@ class ApiKeyAuthDataSourceBinding(Resource): data = payload.model_dump() ApiKeyAuthService.validate_api_key_auth_args(data) try: - ApiKeyAuthService.create_provider_auth(current_tenant_id, data) + ApiKeyAuthService.create_provider_auth(db.session(), current_tenant_id, data) except Exception as e: raise ApiKeyAuthFailedError(str(e)) return {"result": "success"}, 200 @@ -109,6 +110,6 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @with_current_tenant_id def delete(self, current_tenant_id: str, binding_id: UUID): # The role of the current user in the table must be admin or owner - ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id)) + ApiKeyAuthService.delete_provider_auth(db.session(), current_tenant_id, str(binding_id)) return "", 204 diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 912eb26574c..ba4fc1275d9 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -15,6 +15,7 @@ from controllers.console.auth.error import ( InvalidTokenError, PasswordMismatchError, ) +from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import EmailStr, extract_remote_ip from libs.helper import timezone as validate_timezone_string @@ -100,7 +101,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - account = AccountService.get_account_by_email_with_case_fallback(args.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -175,7 +176,7 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: raise EmailAlreadyInUseError() @@ -186,7 +187,7 @@ class EmailRegisterResetApi(Resource): timezone=args.timezone, language=args.language, ) - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(normalized_email) return {"result": "success", "data": token_pair.model_dump()} @@ -205,6 +206,7 @@ class EmailRegisterResetApi(Resource): password=password, interface_language=get_valid_language(language), timezone=timezone, + session=db.session, ) except AccountRegisterError: raise AccountInFreezeError() diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d82f63c11db..8df9600070c 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -82,7 +82,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - account = AccountService.get_account_by_email_with_case_fallback(args.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email) token = AccountService.send_reset_password_email( account=account, @@ -180,7 +180,7 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: account = db.session.merge(account) @@ -198,10 +198,10 @@ class ForgotPasswordResetApi(Resource): # Create workspace if needed if ( - not TenantService.get_join_tenants(account) + not TenantService.get_join_tenants(account, session=db.session) and FeatureService.get_system_features().is_allow_create_workspace ): - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session) TenantService.create_tenant_member(tenant, account, db.session, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 053f313ba53..81f9ee4bae4 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -119,7 +119,9 @@ class LoginApi(Resource): invite_token = args.invite_token invitation_data: InvitationDetailDict | None = None if invite_token: - invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token) + invitation_data = RegisterService.get_invitation_with_case_fallback( + None, request_email, invite_token, session=db.session + ) if invitation_data is None: invite_token = None @@ -145,7 +147,7 @@ class LoginApi(Resource): _log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS) raise AuthenticationFailedError() from exc # SELF_HOSTED only have one workspace - tenants = TenantService.get_join_tenants(account) + tenants = TenantService.get_join_tenants(account, session=db.session) if len(tenants) == 0: system_features = FeatureService.get_system_features() @@ -157,7 +159,7 @@ class LoginApi(Resource): "data": "workspace not found, please contact system admin to invite you to join in a workspace", } - token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(normalized_email) # Create response with cookies instead of returning tokens in body @@ -291,7 +293,7 @@ class EmailCodeLoginApi(Resource): _log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE) raise AccountInFreezeError() if account: - tenants = TenantService.get_join_tenants(account) + tenants = TenantService.get_join_tenants(account, session=db.session) if not tenants: workspaces = FeatureService.get_system_features().license.workspaces if not workspaces.is_available(): @@ -299,7 +301,7 @@ class EmailCodeLoginApi(Resource): if not FeatureService.get_system_features().is_allow_create_workspace: raise NotAllowedCreateWorkspace() else: - new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session) TenantService.create_tenant_member(new_tenant, account, db.session, role="owner") account.current_tenant = new_tenant tenant_was_created.send(new_tenant) @@ -311,6 +313,7 @@ class EmailCodeLoginApi(Resource): name=user_email, interface_language=get_valid_language(language), timezone=args.timezone, + session=db.session, ) except WorkSpaceNotAllowedCreateError: raise NotAllowedCreateWorkspace() @@ -319,7 +322,7 @@ class EmailCodeLoginApi(Resource): raise AccountInFreezeError() except WorkspacesLimitExceededError: raise WorkspacesLimitExceeded() - token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) + token_pair = AccountService.login(account, session=db.session, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(user_email) # Create response with cookies instead of returning tokens in body @@ -343,7 +346,7 @@ class RefreshTokenApi(Resource): return {"result": "fail", "message": "No refresh token provided"}, 401 try: - new_token_pair = AccountService.refresh_token(refresh_token) + new_token_pair = AccountService.refresh_token(refresh_token, session=db.session) # Create response with new cookies response = make_response({"result": "success"}) @@ -358,22 +361,22 @@ class RefreshTokenApi(Resource): def _get_account_with_case_fallback(email: str): - account = AccountService.get_user_through_email(email) + account = AccountService.get_user_through_email(email, session=db.session) if account or email == email.lower(): return account - return AccountService.get_user_through_email(email.lower()) + return AccountService.get_user_through_email(email.lower(), session=db.session) def _authenticate_account_with_case_fallback( original_email: str, normalized_email: str, password: str, invite_token: str | None ): try: - return AccountService.authenticate(original_email, password, invite_token) + return AccountService.authenticate(original_email, password, invite_token, session=db.session) except services.errors.account.AccountPasswordError: if original_email == normalized_email: raise - return AccountService.authenticate(normalized_email, password, invite_token) + return AccountService.authenticate(normalized_email, password, invite_token, session=db.session) def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None: diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 78d1583fde9..65f3a5addde 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -195,7 +195,7 @@ class OAuthCallback(Resource): db.session.commit() try: - TenantService.create_owner_tenant_if_not_exist(account) + TenantService.create_owner_tenant_if_not_exist(account, session=db.session) except Unauthorized: return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.") except WorkSpaceNotAllowedCreateError: @@ -206,6 +206,7 @@ class OAuthCallback(Resource): token_pair = AccountService.login( account=account, + session=db.session, ip_address=extract_remote_ip(request), ) @@ -224,7 +225,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - account = AccountService.get_account_by_email_with_case_fallback(user_info.email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, user_info.email) return account @@ -240,12 +241,12 @@ def _generate_account( oauth_new_user = False if account: - tenants = TenantService.get_join_tenants(account) + tenants = TenantService.get_join_tenants(account, session=db.session) if not tenants: if not FeatureService.get_system_features().is_allow_create_workspace: raise WorkSpaceNotAllowedCreateError() else: - new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session) TenantService.create_tenant_member(new_tenant, account, db.session, role="owner") account.current_tenant = new_tenant tenant_was_created.send(new_tenant) @@ -272,9 +273,10 @@ def _generate_account( provider=provider, language=interface_language, timezone=timezone, + session=db.session, ) # Link account - AccountService.link_account_integrate(provider, user_info.id, account) + AccountService.link_account_integrate(provider, user_info.id, account, session=db.session) return account, oauth_new_user diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 7195fe066fd..ebb490cd9e8 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -17,6 +17,7 @@ from controllers.console.wraps import ( with_current_tenant_id, with_current_user, ) +from extensions.ext_database import db from fields.dataset_fields import ( DatasetMetadataBuiltInFieldsResponse, DatasetMetadataListResponse, @@ -65,7 +66,9 @@ class DatasetMetadataCreateApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.create_metadata(dataset_id_str, metadata_args, current_user, current_tenant_id) + metadata = MetadataService.create_metadata( + db.session(), dataset_id_str, metadata_args, current_user, current_tenant_id + ) return dump_response(DatasetMetadataResponse, metadata), 201 @setup_required @@ -81,7 +84,7 @@ class DatasetMetadataCreateApi(Resource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - metadata = MetadataService.get_dataset_metadatas(dataset) + metadata = MetadataService.get_dataset_metadatas(db.session(), dataset) return dump_response(DatasetMetadataListResponse, metadata), 200 @@ -108,7 +111,7 @@ class DatasetMetadataApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) metadata = MetadataService.update_metadata_name( - dataset_id_str, metadata_id_str, name, current_user, current_tenant_id + db.session(), dataset_id_str, metadata_id_str, name, current_user, current_tenant_id ) return dump_response(DatasetMetadataResponse, metadata), 200 @@ -127,7 +130,7 @@ class DatasetMetadataApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - MetadataService.delete_metadata(dataset_id_str, metadata_id_str) + MetadataService.delete_metadata(db.session(), dataset_id_str, metadata_id_str) # Frontend callers only await success and invalidate metadata caches; no response body is consumed. return "", 204 @@ -166,9 +169,9 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): match action: case "enable": - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db.session(), dataset) case "disable": - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db.session(), dataset) # Frontend callers only await success and invalidate metadata caches; no response body is consumed. return "", 204 @@ -195,7 +198,7 @@ class DocumentMetadataEditApi(Resource): metadata_args = MetadataOperationData.model_validate(console_ns.payload or {}) - MetadataService.update_documents_metadata(dataset, metadata_args, current_user) + MetadataService.update_documents_metadata(db.session(), dataset, metadata_args, current_user) # Frontend callers only await success and invalidate caches; no response body is consumed. return "", 204 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 756dfe84f6c..c2104ccfc61 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -20,6 +20,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from models.model import InstalledApp from services.audio_service import AudioService @@ -99,7 +100,13 @@ class ChatTextApi(InstalledAppResource): text = payload.text voice = payload.voice - response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) + response = AudioService.transcript_tts( + app_model=app_model, + session=db.session, + text=text, + voice=voice, + message_id=message_id, + ) return response except services.errors.app_model_config.AppModelConfigBrokenError: logger.exception("App model config broken.") diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index c1fa1378ffa..fd5b003b523 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -175,7 +175,7 @@ class InstalledAppsListApi(Resource): if current_user.current_tenant is None: raise ValueError("current_user.current_tenant must not be None") - current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) + current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant, session=db.session) installed_app_list: list[dict[str, Any]] = [] for installed_app, app_model in installed_apps: installed_app_list.append( diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 1c82f6ba9d4..2be550b2f28 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -27,7 +27,11 @@ from controllers.console.wraps import with_current_user from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse -from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse +from fields.message_fields import ( + ExploreMessageInfiniteScrollPagination, + ExploreMessageListItem, + SuggestedQuestionsResponse, +) from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from models import Account @@ -56,7 +60,7 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor register_response_schema_models( console_ns, GeneratedAppResponse, - MessageInfiniteScrollPagination, + ExploreMessageInfiniteScrollPagination, ResultResponse, SuggestedQuestionsResponse, ) @@ -68,7 +72,7 @@ register_response_schema_models( ) class MessageListApi(InstalledAppResource): @console_ns.doc(params=query_params_from_model(MessageListQuery)) - @console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPagination.__name__]) + @console_ns.response(200, "Success", console_ns.models[ExploreMessageInfiniteScrollPagination.__name__]) @with_current_user def get(self, current_user: Account, installed_app: InstalledApp): app_model = installed_app.app @@ -88,9 +92,9 @@ class MessageListApi(InstalledAppResource): str(args.first_id) if args.first_id else None, args.limit, ) - adapter = TypeAdapter(MessageListItem) + adapter = TypeAdapter(ExploreMessageListItem) items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data] - return MessageInfiniteScrollPagination( + return ExploreMessageInfiniteScrollPagination( limit=pagination.limit, has_more=pagination.has_more, data=items, diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index ad98dd303fb..6aef9129780 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -419,7 +419,13 @@ class TrialChatTextApi(TrialAppResource): app_id = app_model.id user_id = current_user.id - response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) + response = AudioService.transcript_tts( + app_model=app_model, + session=db.session, + text=text, + voice=voice, + message_id=message_id, + ) RecommendedAppService.add_trial_app_record(db.session, app_id, user_id) return response except services.errors.app_model_config.AppModelConfigBrokenError: diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index f086bf18622..27f6bcc36dc 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -50,7 +50,7 @@ def get_init_status() -> InitStatusResponse: @only_edition_self_hosted def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse: """Validate initialization password.""" - tenant_count = TenantService.get_tenant_count() + tenant_count = TenantService.get_tenant_count(session=db.session) if tenant_count > 0: raise AlreadySetupError() diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 279e4ec502d..3b5c1bbe18f 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -79,7 +79,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse: if get_setup_status(): raise AlreadySetupError() - tenant_count = TenantService.get_tenant_count() + tenant_count = TenantService.get_tenant_count(session=db.session) if tenant_count > 0: raise AlreadySetupError() @@ -94,6 +94,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse: password=payload.password, ip_address=extract_remote_ip(request), language=payload.language, + session=db.session, ) return SetupResponse(result="success") diff --git a/api/controllers/console/socketio/workflow.py b/api/controllers/console/socketio/workflow.py index b4f03593fd7..99e56df3cb8 100644 --- a/api/controllers/console/socketio/workflow.py +++ b/api/controllers/console/socketio/workflow.py @@ -4,6 +4,7 @@ from typing import cast from flask import Request as FlaskRequest +from extensions.ext_database import db from extensions.ext_socketio import sio from libs.passport import PassportService from libs.token import extract_access_token @@ -43,7 +44,7 @@ def socket_connect(sid, environ, auth): return False with sio.app.app_context(): - user = AccountService.load_logged_in_account(account_id=user_id) + user = AccountService.load_logged_in_account(account_id=user_id, session=db.session) if not user: logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid) return False diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 0ac26168bc5..c13c8aa162f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -328,7 +328,7 @@ class AccountNameApi(Resource): def post(self, current_user: Account): payload = console_ns.payload or {} args = AccountNamePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, name=args.name) + updated_account = AccountService.update_account(current_user, session=db.session, name=args.name) return _serialize_account(updated_account) @@ -374,7 +374,7 @@ class AccountAvatarApi(Resource): payload = console_ns.payload or {} args = AccountAvatarPayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, avatar=args.avatar) + updated_account = AccountService.update_account(current_user, session=db.session, avatar=args.avatar) return _serialize_account(updated_account) @@ -391,7 +391,9 @@ class AccountInterfaceLanguageApi(Resource): payload = console_ns.payload or {} args = AccountInterfaceLanguagePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) + updated_account = AccountService.update_account( + current_user, session=db.session, interface_language=args.interface_language + ) return _serialize_account(updated_account) @@ -408,7 +410,9 @@ class AccountInterfaceThemeApi(Resource): payload = console_ns.payload or {} args = AccountInterfaceThemePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) + updated_account = AccountService.update_account( + current_user, session=db.session, interface_theme=args.interface_theme + ) return _serialize_account(updated_account) @@ -425,7 +429,7 @@ class AccountTimezoneApi(Resource): payload = console_ns.payload or {} args = AccountTimezonePayload.model_validate(payload) - updated_account = AccountService.update_account(current_user, timezone=args.timezone) + updated_account = AccountService.update_account(current_user, session=db.session, timezone=args.timezone) return _serialize_account(updated_account) @@ -443,7 +447,8 @@ class AccountPasswordApi(Resource): args = AccountPasswordPayload.model_validate(payload) try: - AccountService.update_account_password(current_user, args.password, args.new_password) + assert args.password is not None + AccountService.update_account_password(current_user, args.password, args.new_password, session=db.session) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() @@ -731,7 +736,7 @@ class ChangeEmailResetApi(Resource): if AccountService.is_account_in_freeze(normalized_new_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(normalized_new_email): + if not AccountService.check_email_unique(normalized_new_email, session=db.session): raise EmailAlreadyInUseError() reset_data = AccountService.get_change_email_data(args.token) @@ -755,7 +760,9 @@ class ChangeEmailResetApi(Resource): # legitimately verified token. AccountService.revoke_change_email_token(args.token) - updated_account = AccountService.update_account_email(current_user, email=normalized_new_email) + updated_account = AccountService.update_account_email( + current_user, email=normalized_new_email, session=db.session + ) AccountService.send_change_email_completed_notify_email( email=normalized_new_email, @@ -775,6 +782,6 @@ class CheckEmailUnique(Resource): normalized_email = args.email.lower() if AccountService.is_account_in_freeze(normalized_email): raise AccountInFreezeError() - if not AccountService.check_email_unique(normalized_email): + if not AccountService.check_email_unique(normalized_email, session=db.session): raise EmailAlreadyInUseError() return {"result": "success"} diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index b3230d77e69..3a2e3c92359 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -131,7 +131,7 @@ def _normalize_invitee_emails(emails: list[str]) -> list[str]: def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int: new_member_count = 0 for email in emails: - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: new_member_count += 1 continue @@ -186,7 +186,7 @@ class MemberListApi(Resource): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") - members = TenantService.get_tenant_members(current_user.current_tenant) + members = TenantService.get_tenant_members(current_user.current_tenant, session=db.session) if dify_config.RBAC_ENABLED: member_ids = [member.id for member in members] member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get( @@ -273,6 +273,7 @@ class MemberInviteEmailApi(Resource): language=interface_language, role=invitee_role, inviter=inviter, + session=db.session, ) encoded_invitee_email = parse.quote(invitee_email) invitation_results.append( @@ -317,7 +318,9 @@ class MemberCancelInviteApi(Resource): abort(404) else: try: - TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) + TenantService.remove_member_from_tenant( + current_user.current_tenant, member, current_user, session=db.session + ) except services.errors.account.CannotOperateSelfError as e: return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: @@ -360,7 +363,9 @@ class MemberUpdateRoleApi(Resource): try: assert member is not None, "Member not found" - TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) + TenantService.update_member_role( + current_user.current_tenant, member, new_role, current_user, session=db.session + ) except services.errors.account.CannotOperateSelfError as e: return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: @@ -387,7 +392,7 @@ class DatasetOperatorMemberListApi(Resource): def get(self, current_user: Account): if not current_user.current_tenant: raise ValueError("No current tenant") - members = TenantService.get_dataset_operator_members(current_user.current_tenant) + members = TenantService.get_dataset_operator_members(current_user.current_tenant, session=db.session) member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True) response = AccountWithRoleList(accounts=member_models) return response.model_dump(mode="json"), 200 @@ -413,7 +418,7 @@ class SendOwnerTransferEmailApi(Resource): # check if the current user is the owner of the workspace if not current_user.current_tenant: raise ValueError("No current tenant") - if not TenantService.is_owner(current_user, current_user.current_tenant): + if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session): raise NotOwnerError() if args.language is not None and args.language == "zh-Hans": @@ -448,7 +453,7 @@ class OwnerTransferCheckApi(Resource): # check if the current user is the owner of the workspace if not current_user.current_tenant: raise ValueError("No current tenant") - if not TenantService.is_owner(current_user, current_user.current_tenant): + if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session): raise NotOwnerError() user_email = current_user.email @@ -494,7 +499,7 @@ class OwnerTransfer(Resource): # check if the current user is the owner of the workspace if not current_user.current_tenant: raise ValueError("No current tenant") - if not TenantService.is_owner(current_user, current_user.current_tenant): + if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session): raise NotOwnerError() if current_user.id == str(member_id): @@ -516,12 +521,14 @@ class OwnerTransfer(Resource): if not current_user.current_tenant: raise ValueError("No current tenant") - if not TenantService.is_member(member, current_user.current_tenant): + if not TenantService.is_member(member, current_user.current_tenant, session=db.session): raise MemberNotInTenantError() try: assert member is not None, "Member not found" - TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user) + TenantService.update_member_role( + current_user.current_tenant, member, "owner", current_user, session=db.session + ) AccountService.send_new_owner_transfer_notify_email( account=member, diff --git a/api/controllers/console/workspace/rbac.py b/api/controllers/console/workspace/rbac.py index f672833061a..c3a3420b908 100644 --- a/api/controllers/console/workspace/rbac.py +++ b/api/controllers/console/workspace/rbac.py @@ -201,21 +201,23 @@ def _legacy_workspace_roles( This keeps the new `/rbac/roles` endpoint compatible with the original Dify role model when enterprise RBAC is disabled. """ - - legacy_roles = [ - svc.RBACRole( - id=role_name, - tenant_id="", - type=svc.RBACRoleType.WORKSPACE.value, - category="global_system_default", - name=role_name, - description="", - is_builtin=True, - permission_keys=list(dict.fromkeys(_LEGACY_ROLE_PERMISSION_KEYS[role_name])), - role_tag="owner" if role_name == "owner" else "", + legacy_roles = [] + for role_name in ("owner", "admin", "editor", "normal", "dataset_operator"): + if not dify_config.DATASET_OPERATOR_ENABLED and role_name == "dataset_operator": + continue + legacy_roles.append( + svc.RBACRole( + id=role_name, + tenant_id="", + type=svc.RBACRoleType.WORKSPACE.value, + category="global_system_default", + name=role_name, + description="", + is_builtin=True, + permission_keys=list(dict.fromkeys(_LEGACY_ROLE_PERMISSION_KEYS[role_name])), + role_tag="owner" if role_name == "owner" else "", + ) ) - for role_name in ("owner", "admin", "editor", "normal", "dataset_operator") - ] if not include_owner: legacy_roles = [r for r in legacy_roles if r.name != "owner"] diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 59a33fe0385..0afd7e06bf7 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -325,10 +325,10 @@ class TenantApi(Resource): raise ValueError("No current tenant") if tenant.status == TenantStatus.ARCHIVE: - tenants = TenantService.get_join_tenants(current_user) + tenants = TenantService.get_join_tenants(current_user, session=db.session) # if there is any tenant, switch to the first one if len(tenants) > 0: - TenantService.switch_tenant(current_user, tenants[0].id) + TenantService.switch_tenant(current_user, tenants[0].id, session=db.session) tenant = tenants[0] # else, raise Unauthorized else: @@ -351,7 +351,7 @@ class SwitchWorkspaceApi(Resource): # check if tenant_id is valid, 403 if not try: - TenantService.switch_tenant(current_user, args.tenant_id) + TenantService.switch_tenant(current_user, args.tenant_id, session=db.session) except Exception: raise AccountNotLinkTenantError("Account not link tenant") diff --git a/api/controllers/inner_api/knowledge/retrieval.py b/api/controllers/inner_api/knowledge/retrieval.py index 1c1320fde42..e34dedea286 100644 --- a/api/controllers/inner_api/knowledge/retrieval.py +++ b/api/controllers/inner_api/knowledge/retrieval.py @@ -14,6 +14,7 @@ from controllers.common.schema import register_response_schema_models, register_ from controllers.inner_api import inner_api_ns from controllers.inner_api.wraps import plugin_inner_api_only from core.workflow.nodes.knowledge_retrieval import exc as retrieval_exc +from extensions.ext_database import db from libs.exception import BaseHTTPException from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest, InnerKnowledgeRetrieveResponse from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError, InnerKnowledgeRetrievalServiceError @@ -81,7 +82,7 @@ class InnerKnowledgeRetrieveApi(Resource): ) from exc try: - response = InnerKnowledgeRetrievalService().retrieve(payload) + response = InnerKnowledgeRetrievalService().retrieve(payload, session=db.session) except InnerKnowledgeRetrievalServiceError as exc: raise InnerKnowledgeRetrievalHttpError( error_code=exc.error_code, diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index dd93616e6b1..1f25eb576d3 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -47,7 +47,7 @@ class EnterpriseWorkspace(Resource): if account is None: return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) + tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session) TenantService.create_tenant_member(tenant, account, db.session, role="owner") tenant_was_created.send(tenant) @@ -84,7 +84,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): def post(self): args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {}) - tenant = TenantService.create_tenant(args.name, is_from_dashboard=True) + tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session) tenant_was_created.send(tenant) diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index 0ff225271df..49f8fb9656f 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -128,7 +128,7 @@ class WorkspaceSwitchApi(Resource): account = _load_account(auth_data.account_id) try: - TenantService.switch_tenant(account, workspace_id) + TenantService.switch_tenant(account, workspace_id, session=db.session) except AccountNotLinkTenantError: raise NotFound("workspace not found") @@ -152,7 +152,7 @@ class WorkspaceMembersApi(Resource): @accepts(query=MemberListQuery) def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery): tenant = _load_tenant(workspace_id) - members = TenantService.get_tenant_members(tenant) + members = TenantService.get_tenant_members(tenant, session=db.session) total = len(members) start = (query.page - 1) * query.limit page_items = members[start : start + query.limit] @@ -184,6 +184,7 @@ class WorkspaceMembersApi(Resource): language=None, role=body.role, inviter=inviter, + session=db.session, ) except AccountAlreadyInTenantError as exc: raise BadRequest(str(exc)) @@ -193,7 +194,7 @@ class WorkspaceMembersApi(Resource): raise BadRequest(str(exc)) normalized_email = body.email.lower() - member = AccountService.get_account_by_email_with_case_fallback(normalized_email) + member = AccountService.get_account_by_email_with_case_fallback(db.session, normalized_email) if member is None: # invite_new_member just created or fetched this account. raise RuntimeError("invited member missing from DB after invite") @@ -232,7 +233,7 @@ class WorkspaceMemberApi(Resource): raise NotFound("member not found") try: - TenantService.remove_member_from_tenant(tenant, member, operator) + TenantService.remove_member_from_tenant(tenant, member, operator, session=db.session) except CannotOperateSelfError as exc: raise BadRequest(str(exc)) except NoPermissionError as exc: @@ -266,7 +267,7 @@ class WorkspaceMemberRoleApi(Resource): raise NotFound("member not found") try: - TenantService.update_member_role(tenant, member, body.role, operator) + TenantService.update_member_role(tenant, member, body.role, operator, session=db.session) except CannotOperateSelfError as exc: raise BadRequest(str(exc)) except NoPermissionError as exc: diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index 627545d7168..8a57ec9818a 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -10,6 +10,7 @@ from controllers.common.schema import query_params_from_model, register_response from controllers.console.wraps import edit_permission_required from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_app_token +from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.annotation_fields import Annotation, AnnotationList from fields.base import ResponseModel @@ -281,7 +282,9 @@ class AnnotationUpdateDeleteApi(Resource): """Update an existing annotation.""" payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}) update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer} - annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id)) + annotation = AppAnnotationService.update_app_annotation_directly( + update_args, app_model.id, str(annotation_id), db.session + ) response = Annotation.model_validate(annotation, from_attributes=True) return response.model_dump(mode="json") @@ -310,5 +313,5 @@ class AnnotationUpdateDeleteApi(Resource): @edit_permission_required def delete(self, app_model: App, annotation_id: UUID): """Delete an annotation.""" - AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id)) + AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id), db.session) return "", 204 diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 2b5a9ba83a1..59ed4b4a4b1 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -23,6 +23,7 @@ from controllers.service_api.app.error import ( from controllers.service_api.schema import binary_response, expect_with_user, multipart_file_params from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService @@ -177,7 +178,12 @@ class TextApi(Resource): text = payload.text voice = payload.voice response = AudioService.transcript_tts( - app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id + app_model=app_model, + session=db.session, + text=text, + voice=voice, + end_user=end_user.external_user_id, + message_id=message_id, ) return response diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 426d008c412..7363e6bdfd4 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -8,6 +8,7 @@ from controllers.common.controller_schemas import MetadataUpdatePayload from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check +from extensions.ext_database import db from fields.dataset_fields import ( DatasetMetadataActionResponse, DatasetMetadataBuiltInFieldsResponse, @@ -85,7 +86,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) + metadata = MetadataService.create_metadata(db.session(), dataset_id_str, metadata_args) return dump_response(DatasetMetadataResponse, metadata), 201 @service_api_ns.doc( @@ -118,7 +119,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - metadata = MetadataService.get_dataset_metadatas(dataset) + metadata = MetadataService.get_dataset_metadatas(db.session(), dataset) return dump_response(DatasetMetadataListResponse, metadata), 200 @@ -158,7 +159,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name) + metadata = MetadataService.update_metadata_name(db.session(), dataset_id_str, metadata_id_str, payload.name) return dump_response(DatasetMetadataResponse, metadata), 200 @service_api_ns.doc( @@ -193,7 +194,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - MetadataService.delete_metadata(dataset_id_str, metadata_id_str) + MetadataService.delete_metadata(db.session(), dataset_id_str, metadata_id_str) return "", 204 @@ -263,9 +264,9 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): match action: case "enable": - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db.session(), dataset) case "disable": - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db.session(), dataset) return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200 @@ -309,6 +310,6 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {}) - MetadataService.update_documents_metadata(dataset, metadata_args) + MetadataService.update_documents_metadata(db.session(), dataset, metadata_args) return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200 diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index c762c914861..801c1f5a629 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -22,6 +22,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App, EndUser @@ -130,7 +131,12 @@ class TextApi(WebApiResource): text = payload.text voice = payload.voice response = AudioService.transcript_tts( - app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id + app_model=app_model, + session=db.session, + text=text, + voice=voice, + end_user=end_user.external_user_id, + message_id=message_id, ) return response diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 2bb7db015ad..7871b411c4b 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -32,6 +32,7 @@ from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService +from services.conversation_service import ConversationService from services.errors.llm import InvokeRateLimitError logger = logging.getLogger(__name__) @@ -202,6 +203,12 @@ class ChatApi(WebApiResource): args["auto_generate_name"] = False try: + # Eagerly validate conversation to avoid hanging on invalid conversation_id + if payload.conversation_id: + ConversationService.get_conversation( + app_model=app_model, conversation_id=payload.conversation_id, user=end_user + ) + response = AppGenerateService.generate( app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py index d0e023e40ee..ecc91113c32 100644 --- a/api/controllers/web/forgot_password.py +++ b/api/controllers/web/forgot_password.py @@ -69,7 +69,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - account = AccountService.get_account_by_email_with_case_fallback(request_email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, request_email) if account is None: raise AuthenticationFailedError() else: @@ -168,7 +168,7 @@ class ForgotPasswordResetApi(Resource): email = reset_data.get("email", "") - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if account: account = db.session.merge(account) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index d8ca19b5fc3..c0d6952871e 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -41,6 +41,7 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, QueuePingEvent, + QueueReasoningChunkEvent, QueueRetrieverResourcesEvent, QueueStopEvent, QueueTextChunkEvent, @@ -62,6 +63,7 @@ from core.app.entities.task_entities import ( MessageAudioStreamResponse, MessageEndStreamResponse, PingStreamResponse, + ReasoningChunkStreamResponse, StreamResponse, WorkflowPauseStreamResponse, WorkflowTaskState, @@ -473,6 +475,17 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) ) + # Collect terminal reasoning (separated mode) per LLM node id for persistence. This is the + # authoritative source (outputs.reasoning_content), decoupled from the live delta stream. + # Accumulate across iteration/loop passes (same node_id) to match the live stream, which + # appends every pass under the same key — overwriting would keep only the last pass. + if event.node_type == BuiltinNodeTypes.LLM: + reasoning_content = (event.outputs or {}).get("reasoning_content") + if isinstance(reasoning_content, str) and reasoning_content: + self._task_state.metadata.reasoning[event.node_id] = ( + self._task_state.metadata.reasoning.get(event.node_id, "") + reasoning_content + ) + node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, @@ -535,6 +548,27 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) + def _handle_reasoning_chunk_event( + self, event: QueueReasoningChunkEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle out-of-band reasoning chunk events. + + Pure emit: reasoning is streamed on its own channel and never written to the + answer. The terminal marker (is_final) may carry an empty reasoning string, in + which case it is still forwarded as the "thinking finished" signal. + """ + if not event.reasoning and not event.is_final: + return + yield ReasoningChunkStreamResponse( + task_id=self._application_generate_entity.task_id, + data=ReasoningChunkStreamResponse.Data( + message_id=self._message_id, + reasoning=event.reasoning, + node_id=event.from_node_id, + is_final=event.is_final, + ), + ) + def _handle_iteration_start_event( self, event: QueueIterationStartEvent, **kwargs ) -> Generator[StreamResponse, None, None]: @@ -872,6 +906,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueuePingEvent: self._handle_ping_event, QueueErrorEvent: self._handle_error_event, QueueTextChunkEvent: self._handle_text_chunk_event, + QueueReasoningChunkEvent: self._handle_reasoning_chunk_event, # Workflow events QueueWorkflowStartedEvent: self._handle_workflow_started_event, QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, diff --git a/api/core/app/apps/agent_app/runtime_request_builder.py b/api/core/app/apps/agent_app/runtime_request_builder.py index fc1fcb0b168..9790f2fbca0 100644 --- a/api/core/app/apps/agent_app/runtime_request_builder.py +++ b/api/core/app/apps/agent_app/runtime_request_builder.py @@ -197,7 +197,7 @@ class AgentAppRuntimeRequestBuilder: def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str: """Return the transport plugin id expected by plugin-daemon headers.""" if plugin_id.count("/") == 1: - return plugin_id + return plugin_id.split(":", 1)[0].split("@", 1)[0] if plugin_id: return ModelProviderID(plugin_id).plugin_id return ModelProviderID(model_provider).plugin_id diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index e52e1e9c9da..6ca4053d7de 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -31,6 +31,7 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, QueuePingEvent, + QueueReasoningChunkEvent, QueueStopEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, @@ -47,6 +48,7 @@ from core.app.entities.task_entities import ( MessageAudioEndStreamResponse, MessageAudioStreamResponse, PingStreamResponse, + ReasoningChunkStreamResponse, StreamResponse, TextChunkStreamResponse, WorkflowAppBlockingResponse, @@ -571,6 +573,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector) + def _handle_reasoning_chunk_event( + self, event: QueueReasoningChunkEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle reasoning chunk events.""" + # is_final with empty reasoning is still forwarded as the "thinking finished" signal + if not event.reasoning and not event.is_final: + return + yield ReasoningChunkStreamResponse( + task_id=self._application_generate_entity.task_id, + data=ReasoningChunkStreamResponse.Data( + reasoning=event.reasoning, + node_id=event.from_node_id, + is_final=event.is_final, + ), + ) + def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: """Handle agent log events.""" yield self._workflow_response_converter.handle_agent_log( @@ -600,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): QueuePingEvent: self._handle_ping_event, QueueErrorEvent: self._handle_error_event, QueueTextChunkEvent: self._handle_text_chunk_event, + QueueReasoningChunkEvent: self._handle_reasoning_chunk_event, # Workflow events QueueWorkflowStartedEvent: self._handle_workflow_started_event, QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 69f6c5b69b7..3dae32b742b 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -24,6 +24,7 @@ from core.app.entities.queue_entities import ( QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueReasoningChunkEvent, QueueRetrieverResourcesEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, @@ -74,6 +75,7 @@ from graphon.graph_events import ( NodeRunLoopNextEvent, NodeRunLoopStartedEvent, NodeRunLoopSucceededEvent, + NodeRunReasoningChunkEvent, NodeRunRetrieverResourceEvent, NodeRunRetryEvent, NodeRunStartedEvent, @@ -576,6 +578,16 @@ class WorkflowBasedAppRunner: in_loop_id=event.in_loop_id, ) ) + case NodeRunReasoningChunkEvent(): + self._publish_event( + QueueReasoningChunkEvent( + reasoning=event.chunk, + from_node_id=event.node_id, + is_final=event.is_final, + in_iteration_id=event.in_iteration_id, + in_loop_id=event.in_loop_id, + ) + ) case NodeRunRetrieverResourceEvent(): self._publish_event( QueueRetrieverResourcesEvent( diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index a0e7881edeb..102d92ccf9e 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -40,6 +40,7 @@ class QueueEvent(StrEnum): NODE_FAILED = "node_failed" NODE_EXCEPTION = "node_exception" RETRIEVER_RESOURCES = "retriever_resources" + REASONING_CHUNK = "reasoning_chunk" ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" MESSAGE_FILE = "message_file" @@ -197,6 +198,26 @@ class QueueTextChunkEvent(AppQueueEvent): """loop id if node is in loop""" +class QueueReasoningChunkEvent(AppQueueEvent): + """ + QueueReasoningChunkEvent entity + + Out-of-band reasoning (chain-of-thought) delta from an LLM node in "separated" + mode. It never touches the answer; it is emitted on a dedicated channel. + """ + + event: QueueEvent = QueueEvent.REASONING_CHUNK + reasoning: str + from_node_id: str | None = None + """id of the LLM node that produced this reasoning""" + is_final: bool = False + """marks the terminal reasoning chunk for the node run (may carry empty reasoning)""" + in_iteration_id: str | None = None + """iteration id if node is in iteration""" + in_loop_id: str | None = None + """loop id if node is in loop""" + + class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 3a8107e0461..ca5a26db55b 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -27,6 +27,9 @@ class TaskStateMetadata(BaseModel): annotation_reply: AnnotationReply | None = None retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list) usage: LLMUsage | None = None + reasoning: dict[str, str] = Field(default_factory=dict) + """reasoning_content per LLM node id (separated mode), accumulated across iteration/loop + passes for that node; persisted to message_metadata""" class TaskState(BaseModel): @@ -85,6 +88,7 @@ class StreamEvent(StrEnum): LOOP_COMPLETED = "loop_completed" TEXT_CHUNK = "text_chunk" TEXT_REPLACE = "text_replace" + REASONING_CHUNK = "reasoning_chunk" AGENT_LOG = "agent_log" HUMAN_INPUT_REQUIRED = "human_input_required" HUMAN_INPUT_FORM_FILLED = "human_input_form_filled" @@ -726,6 +730,29 @@ class TextChunkStreamResponse(StreamResponse): data: Data +class ReasoningChunkStreamResponse(StreamResponse): + """ + ReasoningChunkStreamResponse entity + + Out-of-band reasoning (chain-of-thought) delta, parallel to text_chunk. Only + emitted in "separated" mode; the answer/message stream stays free of . + """ + + class Data(BaseModel): + """ + Data entity + """ + + # chat apps set this; workflow runs have no message + message_id: str | None = None + reasoning: str + node_id: str | None = None + is_final: bool = False + + event: StreamEvent = StreamEvent.REASONING_CHUNK + data: Data + + class TextReplaceStreamResponse(StreamResponse): """ TextReplaceStreamResponse entity diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index f68e5a4e6b3..dd173207b09 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -226,7 +226,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_multimodal_documents.append(file_document) doc.attachments = attachments else: - account = AccountService.load_user(document.created_by) + account = AccountService.load_user(document.created_by, db.session) if not account: raise ValueError("Invalid account") doc.attachments = self._get_content_files(doc, current_user=account) diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 9c186a9f046..78d8b7dcd53 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -291,7 +291,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): attachments.append(file_document) doc.attachments = attachments else: - account = AccountService.load_user(document.created_by) + account = AccountService.load_user(document.created_by, db.session) if not account: raise ValueError("Invalid account") doc.attachments = self._get_content_files(doc, current_user=account) diff --git a/api/core/workflow/nodes/agent_v2/runtime_request_builder.py b/api/core/workflow/nodes/agent_v2/runtime_request_builder.py index e3c2dcee839..e5a541ed350 100644 --- a/api/core/workflow/nodes/agent_v2/runtime_request_builder.py +++ b/api/core/workflow/nodes/agent_v2/runtime_request_builder.py @@ -265,7 +265,7 @@ class WorkflowAgentRuntimeRequestBuilder: def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str: """Return the transport plugin id expected by plugin-daemon headers.""" if plugin_id.count("/") == 1: - return plugin_id + return plugin_id.split(":", 1)[0].split("@", 1)[0] if plugin_id: return ModelProviderID(plugin_id).plugin_id return ModelProviderID(model_provider).plugin_id diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index fc118df5bc0..14778f49d5e 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -124,6 +124,7 @@ else exec python -m app else exec gunicorn \ + --no-control-socket \ --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ --workers ${SERVER_WORKER_AMOUNT:-1} \ --worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \ diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 0ae018f6a1d..f6496c70a78 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -84,7 +84,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non if not user_id: raise Unauthorized("Invalid Authorization token.") - logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + logged_in_account = AccountService.load_logged_in_account(account_id=user_id, session=db.session) return logged_in_account elif request.blueprint == "openapi": # Account-branch device-flow approval routes (approve / deny / @@ -103,7 +103,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non source = decoded.get("token_source") if source or not user_id: return None - return AccountService.load_logged_in_account(account_id=user_id) + return AccountService.load_logged_in_account(account_id=user_id, session=db.session) elif request.blueprint == "web": app_code = request.headers.get(HEADER_NAME_APP_CODE) webapp_token = extract_webapp_passport(app_code, request) if app_code else None diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 3f9c5bf0521..ba63b8bd1f4 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -77,6 +77,13 @@ class WebMessageListItem(MessageListItem): ) +class ExploreMessageListItem(MessageListItem): + metadata: JSONValueType | None = Field( + default=None, + validation_alias="message_metadata_dict", + ) + + class MessageInfiniteScrollPagination(ResponseModel): limit: int has_more: bool @@ -89,6 +96,12 @@ class WebMessageInfiniteScrollPagination(ResponseModel): data: list[WebMessageListItem] +class ExploreMessageInfiniteScrollPagination(ResponseModel): + limit: int + has_more: bool + data: list[ExploreMessageListItem] + + class SavedMessageItem(ResponseModel): id: str inputs: dict[str, JSONValueType] diff --git a/api/libs/password.py b/api/libs/password.py index 3313278492a..4170abbf225 100644 --- a/api/libs/password.py +++ b/api/libs/password.py @@ -16,7 +16,7 @@ def valid_password(password): raise ValueError("Password must contain letters and numbers, and the length must be at least 8 characters.") -def hash_password(password_str, salt_byte): +def hash_password(password_str: str, salt_byte: bytes): dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000) return binascii.hexlify(dk) diff --git a/api/migrations/versions/2026_06_17_0900-e4f8a2c61d35_add_dataset_id_to_api_tokens.py b/api/migrations/versions/2026_06_17_0900-e4f8a2c61d35_add_dataset_id_to_api_tokens.py index 161961da4aa..844a8d6c932 100644 --- a/api/migrations/versions/2026_06_17_0900-e4f8a2c61d35_add_dataset_id_to_api_tokens.py +++ b/api/migrations/versions/2026_06_17_0900-e4f8a2c61d35_add_dataset_id_to_api_tokens.py @@ -1,7 +1,7 @@ """add dataset_id to api_tokens Revision ID: e4f8a2c61d35 -Revises: c8f4a6b2d3e1 +Revises: d9e8f7a6b5c4 Create Date: 2026-06-17 09:00:00.000000 Reintroduces the nullable `dataset_id` column on `api_tokens` (it was dropped in @@ -23,7 +23,7 @@ import models as models # revision identifiers, used by Alembic. revision = "e4f8a2c61d35" -down_revision = "c8f4a6b2d3e1" +down_revision = "d9e8f7a6b5c4" branch_labels = None depends_on = None diff --git a/api/migrations/versions/2026_06_23_1800-d9e8f7a6b5c4_add_cloud_only_flag_to_recommended_apps.py b/api/migrations/versions/2026_06_23_1800-d9e8f7a6b5c4_add_cloud_only_flag_to_recommended_apps.py new file mode 100644 index 00000000000..77bf5118bec --- /dev/null +++ b/api/migrations/versions/2026_06_23_1800-d9e8f7a6b5c4_add_cloud_only_flag_to_recommended_apps.py @@ -0,0 +1,26 @@ +"""add cloud only flag to recommended apps + +Revision ID: d9e8f7a6b5c4 +Revises: c8f4a6b2d3e1 +Create Date: 2026-06-23 18:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "d9e8f7a6b5c4" +down_revision = "c8f4a6b2d3e1" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("recommended_apps", schema=None) as batch_op: + batch_op.add_column(sa.Column("is_cloud_only", sa.Boolean(), server_default=sa.text("false"), nullable=False)) + + +def downgrade(): + with op.batch_alter_table("recommended_apps", schema=None) as batch_op: + batch_op.drop_column("is_cloud_only") diff --git a/api/models/model.py b/api/models/model.py index 0d0b51f65cc..82034010092 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -925,6 +925,9 @@ class RecommendedApp(TypeBase): is_learn_dify: Mapped[bool] = mapped_column( sa.Boolean, nullable=False, server_default=sa.text("false"), default=False ) + is_cloud_only: Mapped[bool] = mapped_column( + sa.Boolean, nullable=False, server_default=sa.text("false"), default=False + ) install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) language: Mapped[str] = mapped_column( String(255), diff --git a/api/openapi/markdown/console-openapi.md b/api/openapi/markdown/console-openapi.md index 49a053ea1c4..acb54bf05ca 100644 --- a/api/openapi/markdown/console-openapi.md +++ b/api/openapi/markdown/console-openapi.md @@ -3807,6 +3807,26 @@ Submit human input form preview for workflow | ---- | ----------- | ------ | | 200 | Workflow agent composer candidates | **application/json**: [AgentComposerCandidatesResponse](#agentcomposercandidatesresponse)
| +### [POST] /apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/copy-from-roster +#### Parameters + +| Name | Located in | Description | Required | Schema | +| ---- | ---------- | ----------- | -------- | ------ | +| app_id | path | | Yes | string (uuid) | +| node_id | path | | Yes | string | + +#### Request Body + +| Required | Schema | +| -------- | ------ | +| Yes | **application/json**: [WorkflowComposerCopyFromRosterPayload](#workflowcomposercopyfromrosterpayload)
| + +#### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Workflow roster agent copied to inline agent | **application/json**: [WorkflowAgentComposerResponse](#workflowagentcomposerresponse)
| + ### [POST] /apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/impact #### Parameters @@ -6574,7 +6594,7 @@ Request body: | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Success | **application/json**: [MessageInfiniteScrollPagination](#messageinfinitescrollpagination)
| +| 200 | Success | **application/json**: [ExploreMessageInfiniteScrollPagination](#exploremessageinfinitescrollpagination)
| ### [POST] /installed-apps/{installed_app_id}/messages/{message_id}/feedbacks #### Parameters @@ -14386,9 +14406,14 @@ Button styles for user actions. | agent_soul | [AgentSoulConfig](#agentsoulconfig) | | No | | binding | [ComposerBindingPayload](#composerbindingpayload) | | No | | client_revision_id | string | | No | +| description | string | | No | +| icon | string | | No | +| icon_background | string | | No | +| icon_type | [AgentIconType](#agenticontype) | | No | | idempotency_key | string | | No | | new_agent_name | string | | No | | node_job | [WorkflowNodeJobConfig](#workflownodejobconfig) | | No | +| role | string | | No | | save_strategy | [ComposerSaveStrategy](#composersavestrategy) | | Yes | | soul_lock | [ComposerSoulLockPayload](#composersoullockpayload) | | No | | variant | [ComposerVariant](#composervariant) | | Yes | @@ -15997,6 +16022,34 @@ Request payload for bulk downloading documents as a zip archive. | ---- | ---- | ----------- | -------- | | tool_icons | object | | No | +#### ExploreMessageInfiniteScrollPagination + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| data | [ [ExploreMessageListItem](#exploremessagelistitem) ] | | Yes | +| has_more | boolean | | Yes | +| limit | integer | | Yes | + +#### ExploreMessageListItem + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| agent_thoughts | [ [AgentThought](#agentthought) ] | | Yes | +| answer | string | | Yes | +| conversation_id | string | | Yes | +| created_at | integer | | No | +| error | string | | No | +| extra_contents | [ [HumanInputContent](#humaninputcontent) ] | | Yes | +| feedback | [SimpleFeedback](#simplefeedback) | | No | +| id | string | | Yes | +| inputs | object | | Yes | +| message_files | [ [MessageFile](#messagefile) ] | | Yes | +| metadata | [JSONValueType](#jsonvaluetype) | | No | +| parent_message_id | string | | No | +| query | string | | Yes | +| retriever_resources | [ [RetrieverResource](#retrieverresource) ] | | Yes | +| status | string | | Yes | + #### ExternalApiTemplateListQuery | Name | Type | Description | Required | @@ -17071,14 +17124,6 @@ Enum class for large language model mode. | upload_file_id | string | | No | | url | string | | No | -#### MessageInfiniteScrollPagination - -| Name | Type | Description | Required | -| ---- | ---- | ----------- | -------- | -| data | [ [MessageListItem](#messagelistitem) ] | | Yes | -| has_more | boolean | | Yes | -| limit | integer | | Yes | - #### MessageInfiniteScrollPaginationResponse | Name | Type | Description | Required | @@ -17087,25 +17132,6 @@ Enum class for large language model mode. | has_more | boolean | | Yes | | limit | integer | | Yes | -#### MessageListItem - -| Name | Type | Description | Required | -| ---- | ---- | ----------- | -------- | -| agent_thoughts | [ [AgentThought](#agentthought) ] | | Yes | -| answer | string | | Yes | -| conversation_id | string | | Yes | -| created_at | integer | | No | -| error | string | | No | -| extra_contents | [ [HumanInputContent](#humaninputcontent) ] | | Yes | -| feedback | [SimpleFeedback](#simplefeedback) | | No | -| id | string | | Yes | -| inputs | object | | Yes | -| message_files | [ [MessageFile](#messagefile) ] | | Yes | -| parent_message_id | string | | No | -| query | string | | Yes | -| retriever_resources | [ [RetrieverResource](#retrieverresource) ] | | Yes | -| status | string | | Yes | - #### MessageListQuery | Name | Type | Description | Required | @@ -19580,6 +19606,7 @@ Model class for provider system configuration response. | enable_email_code_login | boolean | | Yes | | enable_email_password_login | boolean,
**Default:** true | | Yes | | enable_explore_banner | boolean | | Yes | +| enable_learn_app | boolean,
**Default:** true | | Yes | | enable_marketplace | boolean | | Yes | | enable_social_oauth_login | boolean | | Yes | | enable_trial_app | boolean | | Yes | @@ -20560,6 +20587,14 @@ How a workflow node is bound to an Agent. | position_x | number | Comment X position | No | | position_y | number | Comment Y position | No | +#### WorkflowComposerCopyFromRosterPayload + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| idempotency_key | string | | No | +| source_agent_id | string | | Yes | +| source_snapshot_id | string | | No | + #### WorkflowConversationVariableResponse | Name | Type | Description | Required | diff --git a/api/openapi/markdown/web-openapi.md b/api/openapi/markdown/web-openapi.md index 569e3706caa..0f368895ab6 100644 --- a/api/openapi/markdown/web-openapi.md +++ b/api/openapi/markdown/web-openapi.md @@ -1603,6 +1603,7 @@ Default configuration for form inputs. | enable_email_code_login | boolean | | Yes | | enable_email_password_login | boolean,
**Default:** true | | Yes | | enable_explore_banner | boolean | | Yes | +| enable_learn_app | boolean,
**Default:** true | | Yes | | enable_marketplace | boolean | | Yes | | enable_social_oauth_login | boolean | | Yes | | enable_trial_app | boolean | | Yes | diff --git a/api/providers/trace/trace-langsmith/pyproject.toml b/api/providers/trace/trace-langsmith/pyproject.toml index 80eb9ae3238..618bec79d0e 100644 --- a/api/providers/trace/trace-langsmith/pyproject.toml +++ b/api/providers/trace/trace-langsmith/pyproject.toml @@ -2,7 +2,7 @@ name = "dify-trace-langsmith" version = "0.0.1" dependencies = [ - "langsmith==0.8.5", + "langsmith==0.8.18", ] description = "Dify ops tracing provider (LangSmith)." diff --git a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py index 54524b09ca6..4d4898b1173 100644 --- a/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py +++ b/api/providers/trace/trace-tencent/tests/unit_tests/tencent_trace/test_tencent_trace.py @@ -181,30 +181,34 @@ class TestTencentDataTrace: mock_trace_utils.convert_to_trace_id.return_value = 123 mock_trace_utils.create_link.return_value = "link" - with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"): - with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc: - with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur: - mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()] + with ( + patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"), + patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc, + patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur, + ): + mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()] - tencent_data_trace.workflow_trace(trace_info) + tencent_data_trace.workflow_trace(trace_info) - mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id") - mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") - mock_span_builder.build_workflow_spans.assert_called_once() - assert tencent_data_trace.trace_client.add_span.call_count == 2 - mock_proc.assert_called_once_with(trace_info, 123) - mock_dur.assert_called_once_with(trace_info) + mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id") + mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") + mock_span_builder.build_workflow_spans.assert_called_once() + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_proc.assert_called_once_with(trace_info, 123) + mock_dur.assert_called_once_with(trace_info) - def test_workflow_trace_exception(self, tencent_data_trace): + def test_workflow_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.workflow_run_id = "run-id" - with patch( - "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + with ( + patch( + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + ), + caplog.at_level(logging.ERROR), ): - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.workflow_trace(trace_info) - mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace") + tencent_data_trace.workflow_trace(trace_info) + assert "[Tencent APM] Failed to process workflow trace" in caplog.text def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): trace_info = MagicMock(spec=MessageTraceInfo) @@ -214,29 +218,33 @@ class TestTencentDataTrace: mock_trace_utils.convert_to_trace_id.return_value = 123 mock_trace_utils.create_link.return_value = "link" - with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"): - with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics: - with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur: - mock_span_builder.build_message_span.return_value = MagicMock() + with ( + patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"), + patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics, + patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur, + ): + mock_span_builder.build_message_span.return_value = MagicMock() - tencent_data_trace.message_trace(trace_info) + tencent_data_trace.message_trace(trace_info) - mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") - mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") - mock_span_builder.build_message_span.assert_called_once() - tencent_data_trace.trace_client.add_span.assert_called_once() - mock_metrics.assert_called_once_with(trace_info) - mock_dur.assert_called_once_with(trace_info) + mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") + mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") + mock_span_builder.build_message_span.assert_called_once() + tencent_data_trace.trace_client.add_span.assert_called_once() + mock_metrics.assert_called_once_with(trace_info) + mock_dur.assert_called_once_with(trace_info) - def test_message_trace_exception(self, tencent_data_trace): + def test_message_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=MessageTraceInfo) - with patch( - "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + with ( + patch( + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + ), + caplog.at_level(logging.ERROR), ): - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.message_trace(trace_info) - mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace") + tencent_data_trace.message_trace(trace_info) + assert "[Tencent APM] Failed to process message trace" in caplog.text def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): trace_info = MagicMock(spec=ToolTraceInfo) @@ -259,16 +267,18 @@ class TestTencentDataTrace: tencent_data_trace.tool_trace(trace_info) tencent_data_trace.trace_client.add_span.assert_not_called() - def test_tool_trace_exception(self, tencent_data_trace): + def test_tool_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=ToolTraceInfo) trace_info.message_id = "msg-id" - with patch( - "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + with ( + patch( + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + ), + caplog.at_level(logging.ERROR), ): - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.tool_trace(trace_info) - mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace") + tencent_data_trace.tool_trace(trace_info) + assert "[Tencent APM] Failed to process tool trace" in caplog.text def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) @@ -291,29 +301,34 @@ class TestTencentDataTrace: tencent_data_trace.dataset_retrieval_trace(trace_info) tencent_data_trace.trace_client.add_span.assert_not_called() - def test_dataset_retrieval_trace_exception(self, tencent_data_trace): + def test_dataset_retrieval_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) trace_info.message_id = "msg-id" - with patch( - "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + with ( + patch( + "dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + ), + caplog.at_level(logging.ERROR), ): - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.dataset_retrieval_trace(trace_info) - mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace") + tencent_data_trace.dataset_retrieval_trace(trace_info) + assert "[Tencent APM] Failed to process dataset retrieval trace" in caplog.text - def test_suggested_question_trace(self, tencent_data_trace): + def test_suggested_question_trace(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log: + with caplog.at_level(logging.INFO): tencent_data_trace.suggested_question_trace(trace_info) - mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") + assert "[Tencent APM] Processing suggested question trace" in caplog.text - def test_suggested_question_trace_exception(self, tencent_data_trace): + def test_suggested_question_trace_exception( + self, tencent_data_trace, monkeypatch, caplog: pytest.LogCaptureFixture + ): trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) - with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")): - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: - tencent_data_trace.suggested_question_trace(trace_info) - mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") + target_logger = logging.getLogger("dify_trace_tencent.tencent_trace") + monkeypatch.setattr(target_logger, "info", MagicMock(side_effect=Exception("error"))) + with caplog.at_level(logging.ERROR): + tencent_data_trace.suggested_question_trace(trace_info) + assert "[Tencent APM] Failed to process suggested question trace" in caplog.text def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils): trace_info = MagicMock(spec=WorkflowTraceInfo) @@ -327,35 +342,42 @@ class TestTencentDataTrace: node2.id = "n2" node2.node_type = BuiltinNodeTypes.TOOL - with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]): - with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]): - with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics: - tencent_data_trace._process_workflow_nodes(trace_info, 123) + with ( + patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]), + patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]), + patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics, + ): + tencent_data_trace._process_workflow_nodes(trace_info, 123) - assert tencent_data_trace.trace_client.add_span.call_count == 2 - mock_metrics.assert_called_once_with(node1) + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_metrics.assert_called_once_with(node1) - def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils): + def test_process_workflow_nodes_node_exception( + self, tencent_data_trace, mock_trace_utils, caplog: pytest.LogCaptureFixture + ): trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.return_value = 111 node = MagicMock(spec=WorkflowNodeExecution) node.id = "n1" - with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): - with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: - tencent_data_trace._process_workflow_nodes(trace_info, 123) - # The exception should be caught by the outer handler since convert_to_span_id is called first - mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") + with ( + patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]), + patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")), + caplog.at_level(logging.ERROR), + ): + tencent_data_trace._process_workflow_nodes(trace_info, 123) + assert "[Tencent APM] Failed to process workflow nodes" in caplog.text - def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils): + def test_process_workflow_nodes_exception( + self, tencent_data_trace, mock_trace_utils, caplog: pytest.LogCaptureFixture + ): trace_info = MagicMock(spec=WorkflowTraceInfo) mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: + with caplog.at_level(logging.ERROR): tencent_data_trace._process_workflow_nodes(trace_info, 123) - mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") + assert "[Tencent APM] Failed to process workflow nodes" in caplog.text def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder): trace_info = MagicMock(spec=WorkflowTraceInfo) @@ -377,16 +399,18 @@ class TestTencentDataTrace: assert result == "span" builder_method.assert_called_once_with(123, 456, trace_info, node) - def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder): + def test_build_workflow_node_span_exception( + self, tencent_data_trace, mock_span_builder, caplog: pytest.LogCaptureFixture + ): node = MagicMock(spec=WorkflowNodeExecution) node.node_type = BuiltinNodeTypes.LLM node.id = "n1" mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") - with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: + with caplog.at_level(logging.DEBUG): result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456) - assert result is None - mock_log.assert_called_once() + assert result is None + assert len([r for r in caplog.records if r.levelno == logging.DEBUG]) >= 1 def test_get_workflow_node_executions(self, tencent_data_trace): trace_info = MagicMock(spec=WorkflowTraceInfo) @@ -419,16 +443,16 @@ class TestTencentDataTrace: assert results == mock_executions account.set_tenant_id.assert_called_once_with("tenant-1") - def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace): + def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {} - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: + with caplog.at_level(logging.ERROR): results = tencent_data_trace._get_workflow_node_executions(trace_info) - assert results == [] - mock_log.assert_called_once() + assert results == [] + assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1 - def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace): + def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.metadata = {"app_id": "app-1"} @@ -439,23 +463,25 @@ class TestTencentDataTrace: session = mock_session_ctx.return_value.__enter__.return_value session.scalar.return_value = None - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: + with caplog.at_level(logging.ERROR): results = tencent_data_trace._get_workflow_node_executions(trace_info) - assert results == [] - mock_log.assert_called_once() + assert results == [] + assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1 def test_get_user_id_workflow(self, tencent_data_trace): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.tenant_id = "tenant-1" trace_info.metadata = {"user_id": "user-1"} - with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")): - with patch("dify_trace_tencent.tencent_trace.db") as mock_db: - mock_db.init_app = MagicMock() - mock_db.engine = MagicMock() + with ( + patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")), + patch("dify_trace_tencent.tencent_trace.db") as mock_db, + ): + mock_db.init_app = MagicMock() + mock_db.engine = MagicMock() - user_id = tencent_data_trace._get_user_id(trace_info) - assert user_id == "unknown" + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "unknown" def test_get_user_id_only_user_id(self, tencent_data_trace): trace_info = MagicMock(spec=MessageTraceInfo) @@ -471,16 +497,18 @@ class TestTencentDataTrace: user_id = tencent_data_trace._get_user_id(trace_info) assert user_id == "anonymous" - def test_get_user_id_exception(self, tencent_data_trace): + def test_get_user_id_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.tenant_id = "t" trace_info.metadata = {"user_id": "u"} - with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")): - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: - user_id = tencent_data_trace._get_user_id(trace_info) - assert user_id == "unknown" - mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID") + with ( + patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")), + caplog.at_level(logging.ERROR), + ): + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "unknown" + assert "[Tencent APM] Failed to get user ID" in caplog.text def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace): node = MagicMock(spec=WorkflowNodeExecution) @@ -514,14 +542,14 @@ class TestTencentDataTrace: tencent_data_trace.trace_client.record_llm_duration.assert_called_once() tencent_data_trace.trace_client.record_token_usage.assert_called_once() - def test_record_llm_metrics_exception(self, tencent_data_trace): + def test_record_llm_metrics_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): node = MagicMock(spec=WorkflowNodeExecution) node.process_data = None node.outputs = None - with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: + with caplog.at_level(logging.DEBUG): tencent_data_trace._record_llm_metrics(node) - # Should not crash + # Should not crash def test_record_message_llm_metrics(self, tencent_data_trace): trace_info = MagicMock(spec=MessageTraceInfo) @@ -553,13 +581,13 @@ class TestTencentDataTrace: tencent_data_trace._record_message_llm_metrics(trace_info) tencent_data_trace.trace_client.record_llm_duration.assert_called_once() - def test_record_message_llm_metrics_exception(self, tencent_data_trace): + def test_record_message_llm_metrics_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=MessageTraceInfo) trace_info.metadata = None - with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: + with caplog.at_level(logging.DEBUG): tencent_data_trace._record_message_llm_metrics(trace_info) - # Should not crash + # Should not crash def test_record_workflow_trace_duration(self, tencent_data_trace): trace_info = MagicMock(spec=WorkflowTraceInfo) @@ -605,11 +633,11 @@ class TestTencentDataTrace: attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {} assert attributes["has_conversation"] == "false" - def test_record_workflow_trace_duration_exception(self, tencent_data_trace): + def test_record_workflow_trace_duration_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=WorkflowTraceInfo) trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right - with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: + with caplog.at_level(logging.DEBUG): tencent_data_trace._record_workflow_trace_duration(trace_info) def test_record_message_trace_duration(self, tencent_data_trace): @@ -627,11 +655,11 @@ class TestTencentDataTrace: 2.0, {"conversation_mode": "chat", "stream": "true"} ) - def test_record_message_trace_duration_exception(self, tencent_data_trace): + def test_record_message_trace_duration_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): trace_info = MagicMock(spec=MessageTraceInfo) trace_info.start_time = None - with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log: + with caplog.at_level(logging.DEBUG): tencent_data_trace._record_message_trace_duration(trace_info) def test_close(self, tencent_data_trace): @@ -647,11 +675,11 @@ class TestTencentDataTrace: client.shutdown.assert_called_once() - def test_close_exception(self, tencent_data_trace): + def test_close_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture): tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") - with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log: + with caplog.at_level(logging.ERROR): tencent_data_trace.close() - mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") + assert "[Tencent APM] Failed to shutdown trace client during cleanup" in caplog.text def test_close_handles_async_shutdown_mock(self, tencent_data_trace): shutdown = AsyncMock() diff --git a/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py index 0e37cf6e4a3..d066aafff94 100644 --- a/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py +++ b/api/providers/vdb/vdb-lindorm/src/dify_vdb_lindorm/lindorm_vector.py @@ -113,7 +113,7 @@ class LindormVectorStore(BaseVector): ) def _bulk_with_retry(actions): try: - response = self._client.bulk(actions, timeout=timeout) + response = self._client.bulk(body=actions, timeout=timeout) if response["errors"]: error_items = [item for item in response["items"] if "error" in item["index"]] error_msg = f"Bulk indexing had {len(error_items)} errors" @@ -231,7 +231,7 @@ class LindormVectorStore(BaseVector): routing_filter_query = { "query": {"bool": {"must": [{"term": {f"{ROUTING_FIELD}.keyword": self._routing}}]}} } - self._client.delete_by_query(self._collection_name, body=routing_filter_query) + self._client.delete_by_query(index=self._collection_name, body=routing_filter_query) self.refresh() else: if self._client.indices.exists(index=self._collection_name): diff --git a/api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py b/api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py index 4a408d1b101..182a4f4ea97 100644 --- a/api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py +++ b/api/providers/vdb/vdb-lindorm/tests/unit_tests/test_lindorm_vector.py @@ -127,7 +127,7 @@ def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch: pytes vector.add_texts(docs, embeddings, batch_size=2, timeout=9) assert vector._client.bulk.call_count == 2 - actions = vector._client.bulk.call_args_list[0].args[0] + actions = vector._client.bulk.call_args_list[0].kwargs["body"] assert actions[0]["index"]["routing"] == "route" assert actions[1][lindorm_module.ROUTING_FIELD] == "route" vector.refresh() diff --git a/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py index 8aea8634fbe..79c06ea6028 100644 --- a/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py +++ b/api/providers/vdb/vdb-weaviate/tests/unit_tests/test_weaviate_vector.py @@ -268,9 +268,11 @@ class TestWeaviateVector(unittest.TestCase): wv._client = MagicMock() wv._client.collections.exists.side_effect = RuntimeError("create failed") - with patch.object(weaviate_vector_module.logger, "exception") as mock_exception: - with pytest.raises(RuntimeError, match="create failed"): - wv._create_collection() + with ( + patch.object(weaviate_vector_module.logger, "exception") as mock_exception, + pytest.raises(RuntimeError, match="create failed"), + ): + wv._create_collection() mock_exception.assert_called_once() @@ -835,9 +837,11 @@ class TestWeaviateVector(unittest.TestCase): wv._client.collections.use.return_value = mock_col mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500) - with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError): - with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"): - wv.delete_by_ids(["bad-id"]) + with ( + patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError), + pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"), + ): + wv.delete_by_ids(["bad-id"]) def test_json_serializable_converts_datetime(self): wv = WeaviateVector.__new__(WeaviateVector) diff --git a/api/pyproject.toml b/api/pyproject.toml index 17efcef9db0..e6a214fe700 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -5,7 +5,7 @@ requires-python = "~=3.12.0" dependencies = [ # Legacy: mature and widely deployed - "bleach>=6.3.0,<7.0.0", + "bleach>=6.4.0,<7.0.0", "boto3>=1.43.24,<2.0.0", "celery>=5.6.3,<6.0.0", "croniter>=6.2.2,<7.0.0", @@ -44,7 +44,7 @@ dependencies = [ "resend>=2.27.0,<3.0.0", # Emerging: newer and fast-moving, use compatible pins "fastopenapi[flask]==0.7.0", - "graphon==0.5.2", + "graphon==0.5.3", "httpx-sse==0.4.3", "json-repair==0.59.4", ] @@ -103,7 +103,11 @@ dify-trace-weave = { workspace = true } [tool.uv] default-groups = ["storage", "tools", "vdb-all", "trace-all"] package = false -override-dependencies = ["litellm>=1.83.10,<2.0.0", "pyarrow>=23.0.1,<24.0.0"] +override-dependencies = [ + "litellm>=1.83.10,<2.0.0", + "pyarrow>=23.0.1,<24.0.0", + "cryptography>=49.0.0,<50.0.0", +] [dependency-groups] diff --git a/api/services/account_service.py b/api/services/account_service.py index 21b5f1eedba..445b1acd3b9 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1,3 +1,10 @@ +"""Account, workspace, and invitation services. + +Database access in this module is caller-scoped: methods that read or mutate ORM state accept an explicit +``session`` so controllers, tasks, and tests can control transaction lifetime and avoid hidden Flask-scoped session +usage inside service logic. +""" + import base64 import json import logging @@ -14,7 +21,6 @@ from werkzeug.exceptions import Unauthorized from configs import dify_config from constants.languages import get_valid_language, language_timezone_mapping -from core.db.session_factory import session_factory from events.tenant_event import tenant_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback @@ -236,7 +242,7 @@ class AccountService: ) @staticmethod - def _refresh_account_last_active(account: Account) -> None: + def _refresh_account_last_active(account: Account, session: scoped_session | Session) -> None: now = naive_utc_now() refresh_before = now - ACCOUNT_LAST_ACTIVE_REFRESH_INTERVAL @@ -246,12 +252,12 @@ class AccountService: if not AccountService._should_refresh_account_last_active(account.id): return - db.session.execute( + session.execute( update(Account) .where(Account.id == account.id, Account.last_active_at < refresh_before) .values(last_active_at=now, updated_at=func.current_timestamp()) ) - db.session.commit() + session.commit() @staticmethod def _store_refresh_token(refresh_token: str, account_id: str): @@ -296,20 +302,20 @@ class AccountService: side-effects (current-tenant assignment, commit) are unwanted. ``session`` is injected by the caller so this service stays free - of the Flask-scoped ``db.session`` import. + of a Flask-scoped session import. """ return session.get(Account, account_id) @staticmethod - def load_user(user_id: str) -> None | Account: - account = db.session.get(Account, user_id) + def load_user(user_id: str, session: scoped_session | Session) -> None | Account: + account = session.get(Account, user_id) if not account: return None if account.status == AccountStatus.BANNED: raise Unauthorized("Account is banned.") - current_tenant = db.session.scalar( + current_tenant = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True) .limit(1) @@ -317,7 +323,7 @@ class AccountService: if current_tenant: account.set_tenant_id(current_tenant.tenant_id) else: - available_ta = db.session.scalar( + available_ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.account_id == account.id) .order_by(TenantAccountJoin.id.asc()) @@ -329,13 +335,13 @@ class AccountService: account.set_tenant_id(available_ta.tenant_id) available_ta.current = True available_ta.last_opened_at = naive_utc_now() - db.session.commit() + session.commit() - AccountService._refresh_account_last_active(account) + AccountService._refresh_account_last_active(account, session) # NOTE: make sure account is accessible outside of a db session # This ensures that it will work correctly after upgrading to Flask version 3.1.2 - db.session.refresh(account) - db.session.close() + session.refresh(account) + session.close() return account @staticmethod @@ -353,10 +359,12 @@ class AccountService: return token @staticmethod - def authenticate(email: str, password: str, invite_token: str | None = None) -> Account: + def authenticate( + email: str, password: str, invite_token: str | None = None, *, session: scoped_session | Session + ) -> Account: """authenticate account with email and password""" - account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) + account = session.scalar(select(Account).where(Account.email == email).limit(1)) if not account: raise AccountPasswordError("Invalid email or password.") @@ -379,12 +387,14 @@ class AccountService: account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() - db.session.commit() + session.commit() return account @staticmethod - def update_account_password(account, password, new_password): + def update_account_password( + account: Account, password: str, new_password: str, *, session: scoped_session | Session + ): """update account password""" if account.password and not compare_password(password, account.password, account.password_salt): raise CurrentPasswordIncorrectError("Current password is incorrect.") @@ -401,8 +411,8 @@ class AccountService: base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - db.session.add(account) - db.session.commit() + session.add(account) + session.commit() return account @staticmethod @@ -414,6 +424,8 @@ class AccountService: interface_theme: str = "light", is_setup: bool | None = False, timezone: str | None = None, + *, + session: scoped_session | Session, ) -> Account: """Create an account, preferring explicit user timezone over language-derived defaults.""" if not FeatureService.get_system_features().is_allow_register and not is_setup: @@ -459,13 +471,19 @@ class AccountService: timezone=resolved_timezone, ) - db.session.add(account) - db.session.commit() + session.add(account) + session.commit() return account @staticmethod def create_account_and_tenant( - email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None + email: str, + name: str, + interface_language: str, + password: str | None = None, + timezone: str | None = None, + *, + session: scoped_session | Session, ) -> Account: """Create an account and owner workspace.""" account = AccountService.create_account( @@ -474,10 +492,11 @@ class AccountService: interface_language=interface_language, password=password, timezone=timezone, + session=session, ) try: - TenantService.create_owner_tenant_if_not_exist(account=account) + TenantService.create_owner_tenant_if_not_exist(account=account, session=session) except Exception: # Enterprise-only side-effect should run independently from personal workspace creation. _try_join_enterprise_default_workspace(str(account.id)) @@ -537,11 +556,11 @@ class AccountService: delete_account_task.delay(account.id) @staticmethod - def link_account_integrate(provider: str, open_id: str, account: Account): + def link_account_integrate(provider: str, open_id: str, account: Account, *, session: scoped_session | Session): """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: AccountIntegrate | None = db.session.scalar( + account_integrate: AccountIntegrate | None = session.scalar( select(AccountIntegrate) .where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider) .limit(1) @@ -557,62 +576,62 @@ class AccountService: account_integrate = AccountIntegrate( account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" ) - db.session.add(account_integrate) + session.add(account_integrate) - db.session.commit() + session.commit() logger.info("Account %s linked %s account %s.", account.id, provider, open_id) except Exception as e: logger.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id) raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod - def close_account(account: Account): + def close_account(account: Account, *, session: scoped_session | Session): """Close account""" account.status = AccountStatus.CLOSED - db.session.commit() + session.commit() @staticmethod - def update_account(account, **kwargs): + def update_account(account: Account, *, session: scoped_session | Session, **kwargs): """Update account fields""" - account = db.session.merge(account) + account = session.merge(account) for field, value in kwargs.items(): if hasattr(account, field): setattr(account, field, value) else: raise AttributeError(f"Invalid field: {field}") - db.session.commit() + session.commit() return account @staticmethod - def update_account_email(account: Account, email: str) -> Account: + def update_account_email(account: Account, email: str, session: scoped_session | Session) -> Account: """Update account email""" account.email = email - account_integrate = db.session.scalar( + account_integrate = session.scalar( select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1) ) if account_integrate: - db.session.delete(account_integrate) - db.session.add(account) - db.session.commit() + session.delete(account_integrate) + session.add(account) + session.commit() return account @staticmethod - def update_login_info(account: Account, *, ip_address: str): + def update_login_info(account: Account, session: scoped_session | Session, *, ip_address: str): """Update last login time and ip""" account.last_login_at = naive_utc_now() account.last_login_ip = ip_address - db.session.add(account) - db.session.commit() + session.add(account) + session.commit() @staticmethod - def login(account: Account, *, ip_address: str | None = None) -> TokenPair: + def login(account: Account, *, session: scoped_session | Session, ip_address: str | None = None) -> TokenPair: if ip_address: - AccountService.update_login_info(account=account, ip_address=ip_address) + AccountService.update_login_info(account=account, session=session, ip_address=ip_address) if account.status == AccountStatus.PENDING: account.status = AccountStatus.ACTIVE - db.session.commit() + session.commit() access_token = AccountService.get_account_jwt_token(account=account) refresh_token = _generate_refresh_token() @@ -629,13 +648,13 @@ class AccountService: AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) @staticmethod - def refresh_token(refresh_token: str) -> TokenPair: + def refresh_token(refresh_token: str, *, session: scoped_session | Session) -> TokenPair: # Verify the refresh token account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) if not account_id: raise ValueError("Invalid refresh token") - account = AccountService.load_user(account_id.decode("utf-8")) + account = AccountService.load_user(account_id.decode("utf-8"), session) if not account: raise ValueError("Invalid account") @@ -650,8 +669,8 @@ class AccountService: return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token) @staticmethod - def load_logged_in_account(*, account_id: str): - return AccountService.load_user(account_id) + def load_logged_in_account(*, account_id: str, session: scoped_session | Session): + return AccountService.load_user(account_id, session) @classmethod def send_reset_password_email( @@ -981,19 +1000,18 @@ class AccountService: return token @staticmethod - def get_account_by_email_with_case_fallback(email: str) -> Account | None: + def get_account_by_email_with_case_fallback(session: Session | scoped_session, email: str) -> Account | None: """ Retrieve an account by email and fall back to the lowercase email if the original lookup fails. This keeps backward compatibility for older records that stored uppercase emails while the rest of the system gradually normalizes new inputs. """ - with session_factory.create_session() as session: - account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() - if account or email == email.lower(): - return account + account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none() + if account or email == email.lower(): + return account - return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() + return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none() @classmethod def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None: @@ -1004,7 +1022,7 @@ class AccountService: TokenManager.revoke_token(token, "email_code_login") @classmethod - def get_user_through_email(cls, email: str): + def get_user_through_email(cls, email: str, *, session: scoped_session | Session): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email): raise AccountRegisterError( description=( @@ -1013,7 +1031,7 @@ class AccountService: ) ) - account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) + account = session.scalar(select(Account).where(Account.email == email).limit(1)) if not account: return None @@ -1212,13 +1230,19 @@ class AccountService: return False @staticmethod - def check_email_unique(email: str) -> bool: - return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None + def check_email_unique(email: str, *, session: scoped_session | Session) -> bool: + return session.scalar(select(Account).where(Account.email == email).limit(1)) is None class TenantService: @staticmethod - def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant: + def create_tenant( + name: str, + is_setup: bool | None = False, + is_from_dashboard: bool | None = False, + *, + session: scoped_session | Session, + ) -> Tenant: """Create tenant""" if ( not FeatureService.get_system_features().is_allow_create_workspace @@ -1230,8 +1254,8 @@ class TenantService: raise NotAllowedCreateWorkspace() tenant = Tenant(name=name) - db.session.add(tenant) - db.session.commit() + session.add(tenant) + session.commit() for category in TenantPluginAutoUpgradeStrategy.PluginCategory: plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy( @@ -1243,11 +1267,11 @@ class TenantService: exclude_plugins=[], include_plugins=[], ) - db.session.add(plugin_upgrade_strategy) - db.session.commit() + session.add(plugin_upgrade_strategy) + session.commit() tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.commit() + session.commit() from services.credit_pool_service import CreditPoolService @@ -1256,9 +1280,11 @@ class TenantService: return tenant @staticmethod - def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False): + def create_owner_tenant_if_not_exist( + account: Account, name: str | None = None, is_setup: bool | None = False, *, session: scoped_session | Session + ): """Check if user have a workspace or not""" - available_ta = db.session.scalar( + available_ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.account_id == account.id) .order_by(TenantAccountJoin.id.asc()) @@ -1277,10 +1303,10 @@ class TenantService: raise WorkspacesLimitExceededError() if name: - tenant = TenantService.create_tenant(name=name, is_setup=is_setup) + tenant = TenantService.create_tenant(name=name, is_setup=is_setup, session=session) else: - tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) - TenantService.create_tenant_member(tenant, account, db.session, role="owner") + tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup, session=session) + TenantService.create_tenant_member(tenant, account, session, role="owner") if dify_config.RBAC_ENABLED: owner_role_id = AccountService._resolve_legacy_role_id(str(tenant.id), account.id, TenantAccountRole.OWNER) RBACService.MemberRoles.replace( @@ -1290,16 +1316,16 @@ class TenantService: role_ids=[owner_role_id], ) account.current_tenant = tenant - db.session.commit() + session.commit() tenant_was_created.send(tenant) @staticmethod def create_tenant_member( - tenant: Tenant, account: Account, session: scoped_session, role: str = "normal" + tenant: Tenant, account: Account, session: scoped_session | Session, role: str = "normal" ) -> TenantAccountJoin: """Create tenant member""" if role == TenantAccountRole.OWNER: - if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]): + if TenantService.has_roles(tenant, [TenantAccountRole.OWNER], session=session): logger.error("Tenant %s has already an owner.", tenant.id) raise Exception("Tenant already has an owner.") @@ -1320,10 +1346,10 @@ class TenantService: return ta @staticmethod - def get_join_tenants(account: Account) -> list[Tenant]: + def get_join_tenants(account: Account, *, session: scoped_session | Session) -> list[Tenant]: """Get account join tenants""" return list( - db.session.scalars( + session.scalars( select(Tenant) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) @@ -1342,7 +1368,7 @@ class TenantService: membership + pick the default workspace. ``session`` is injected by the caller so this service stays free - of the Flask-scoped ``db.session`` import. + of a Flask-scoped session import. No tenant-status filter: parity with the legacy controller query (the openapi identity endpoint listed all joined tenants). @@ -1415,7 +1441,7 @@ class TenantService: bearers (no account) collapse to the non-member path. Mirrors the session-injection style of :meth:`account_belongs_to_tenant` rather than :meth:`get_user_role`, which loads full ``Account``/``Tenant`` - objects against the Flask-scoped ``db.session``. + objects against the Flask-scoped session. """ if not account_id: return None @@ -1481,13 +1507,13 @@ class TenantService: ).first() @staticmethod - def get_current_tenant_by_account(account: Account): + def get_current_tenant_by_account(account: Account, *, session: scoped_session | Session): """Get tenant by account and add the role""" tenant = account.current_tenant if not tenant: raise TenantNotFoundError("Tenant not found.") - ta = db.session.scalar( + ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1499,14 +1525,14 @@ class TenantService: return tenant @staticmethod - def switch_tenant(account: Account, tenant_id: str | None = None): + def switch_tenant(account: Account, tenant_id: str | None = None, *, session: scoped_session | Session): """Switch the current workspace for the account""" # Ensure tenant_id is provided if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = db.session.scalar( + tenant_account_join = session.scalar( select(TenantAccountJoin) .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) .where( @@ -1520,7 +1546,7 @@ class TenantService: if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - db.session.execute( + session.execute( update(TenantAccountJoin) .where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id) .values(current=False) @@ -1529,10 +1555,10 @@ class TenantService: tenant_account_join.last_opened_at = naive_utc_now() # Set the current tenant for the account account.set_tenant_id(tenant_account_join.tenant_id) - db.session.commit() + session.commit() @staticmethod - def get_tenant_members(tenant: Tenant) -> list[Account]: + def get_tenant_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]: """Get tenant members""" stmt = ( select(Account, TenantAccountJoin.role) @@ -1544,14 +1570,14 @@ class TenantService: # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in db.session.execute(stmt): + for account, role in session.execute(stmt): account.role = role updated_accounts.append(account) return updated_accounts @staticmethod - def get_dataset_operator_members(tenant: Tenant) -> list[Account]: + def get_dataset_operator_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]: """Get dataset admin members""" stmt = ( select(Account, TenantAccountJoin.role) @@ -1564,20 +1590,20 @@ class TenantService: # Initialize an empty list to store the updated accounts updated_accounts = [] - for account, role in db.session.execute(stmt): + for account, role in session.execute(stmt): account.role = role updated_accounts.append(account) return updated_accounts @staticmethod - def has_roles(tenant: Tenant, roles: list[TenantAccountRole]) -> bool: + def has_roles(tenant: Tenant, roles: list[TenantAccountRole], *, session: scoped_session | Session) -> bool: """Check if user has any of the given roles for a tenant""" if not all(isinstance(role, TenantAccountRole) for role in roles): raise ValueError("all roles must be TenantAccountRole") return ( - db.session.scalar( + session.scalar( select(TenantAccountJoin) .where( TenantAccountJoin.tenant_id == tenant.id, @@ -1589,9 +1615,11 @@ class TenantService: ) @staticmethod - def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None: + def get_user_role( + account: Account, tenant: Tenant, *, session: scoped_session | Session + ) -> TenantAccountRole | None: """Get the role of the current account for a given tenant""" - join = db.session.scalar( + join = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1599,12 +1627,14 @@ class TenantService: return TenantAccountRole(join.role) if join else None @staticmethod - def get_tenant_count() -> int: + def get_tenant_count(*, session: scoped_session | Session) -> int: """Get tenant count""" - return cast(int, db.session.scalar(select(func.count(Tenant.id)))) + return cast(int, session.scalar(select(func.count(Tenant.id)))) @staticmethod - def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str): + def check_member_permission( + tenant: Tenant, operator: Account, member: Account | None, action: str, *, session: scoped_session | Session + ): """Check member permission""" if action not in {"add", "remove", "update"}: raise InvalidActionError("Invalid action.") @@ -1638,7 +1668,7 @@ class TenantService: "update": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], } - ta_operator = db.session.scalar( + ta_operator = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id) .limit(1) @@ -1648,7 +1678,7 @@ class TenantService: raise NoPermissionError(f"No permission to {action} member.") if action == "remove" and ta_operator.role == TenantAccountRole.ADMIN and member: - ta_member = db.session.scalar( + ta_member = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id) .limit(1) @@ -1657,7 +1687,9 @@ class TenantService: raise NoPermissionError(f"No permission to {action} member.") @staticmethod - def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account): + def remove_member_from_tenant( + tenant: Tenant, account: Account, operator: Account, *, session: scoped_session | Session + ): """Remove member from tenant. Apps and datasets maintained by the removed member are reassigned to @@ -1669,9 +1701,9 @@ class TenantService: if operator.id == account.id: raise CannotOperateSelfError("Cannot operate self.") - TenantService.check_member_permission(tenant, operator, account, "remove") + TenantService.check_member_permission(tenant, operator, account, "remove", session=session) - ta = db.session.scalar( + ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1688,7 +1720,7 @@ class TenantService: if dify_config.RBAC_ENABLED: owner_id = AccountService.get_rbac_workspace_owner_account_id(str(tenant.id), str(operator.id)) else: - owner_id = db.session.scalar( + owner_id = session.scalar( select(TenantAccountJoin.account_id) .where( TenantAccountJoin.tenant_id == tenant.id, @@ -1699,7 +1731,7 @@ class TenantService: if owner_id is None: raise ValueError(f"Workspace owner not found for tenant {tenant.id}.") - db.session.execute( + session.execute( update(App) .where( App.tenant_id == tenant.id, @@ -1707,7 +1739,7 @@ class TenantService: ) .values(maintainer=owner_id) ) - db.session.execute( + session.execute( update(Dataset) .where( Dataset.tenant_id == tenant.id, @@ -1715,23 +1747,23 @@ class TenantService: ) .values(maintainer=owner_id) ) - db.session.delete(ta) + session.delete(ta) # Clean up orphaned pending accounts (invited but never activated) should_delete_account = False if account.status == AccountStatus.PENDING: # autoflush flushes ta deletion before this query, so 0 means no remaining joins remaining_joins = ( - db.session.scalar( + session.scalar( select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id) ) or 0 ) if remaining_joins == 0: - db.session.delete(account) + session.delete(account) should_delete_account = True - db.session.commit() + session.commit() if should_delete_account: logger.info( @@ -1756,13 +1788,18 @@ class TenantService: account_id, ) + if dify_config.RBAC_ENABLED: + RBACService.MemberRoles.delete_rbac_bindings(tenant_id=tenant.id, account_id=account_id) + @staticmethod - def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account): + def update_member_role( + tenant: Tenant, member: Account, new_role: str, operator: Account, *, session: scoped_session | Session + ): """Update member role""" - TenantService.check_member_permission(tenant, operator, member, "update") + TenantService.check_member_permission(tenant, operator, member, "update", session=session) new_tenant_role = TenantAccountRole(new_role) - target_member_join = db.session.scalar( + target_member_join = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id) .limit(1) @@ -1771,7 +1808,7 @@ class TenantService: if not target_member_join: raise MemberNotInTenantError("Member not in tenant.") - operator_role = TenantService.get_user_role(operator, tenant) + operator_role = TenantService.get_user_role(operator, tenant, session=session) target_role = TenantAccountRole(target_member_join.role) if operator_role == TenantAccountRole.ADMIN and (TenantAccountRole.OWNER in {target_role, new_tenant_role}): raise NoPermissionError("No permission to update member.") @@ -1781,7 +1818,7 @@ class TenantService: if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = db.session.scalar( + current_owner_join = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") .limit(1) @@ -1817,7 +1854,7 @@ class TenantService: ) else: target_member_join.role = new_tenant_role - db.session.commit() + session.commit() @staticmethod def get_custom_config(tenant_id: str): @@ -1826,13 +1863,13 @@ class TenantService: return tenant.custom_config_dict @staticmethod - def is_owner(account: Account, tenant: Tenant) -> bool: - return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER + def is_owner(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool: + return TenantService.get_user_role(account, tenant, session=session) == TenantAccountRole.OWNER @staticmethod - def is_member(account: Account, tenant: Tenant) -> bool: + def is_member(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool: """Check if the account is a member of the tenant""" - return TenantService.get_user_role(account, tenant) is not None + return TenantService.get_user_role(account, tenant, session=session) is not None class RegisterService: @@ -1841,7 +1878,16 @@ class RegisterService: return f"member_invite:token:{token}" @classmethod - def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None): + def setup( + cls, + email: str, + name: str, + password: str, + ip_address: str, + language: str | None, + *, + session: scoped_session | Session, + ): """ Setup dify @@ -1858,22 +1904,23 @@ class RegisterService: interface_language=get_valid_language(language), password=password, is_setup=True, + session=session, ) account.last_login_ip = ip_address account.initialized_at = naive_utc_now() - TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) + TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True, session=session) dify_setup = DifySetup(version=dify_config.project.version) - db.session.add(dify_setup) - db.session.commit() + session.add(dify_setup) + session.commit() except Exception as e: - db.session.execute(delete(DifySetup)) - db.session.execute(delete(TenantAccountJoin)) - db.session.execute(delete(Account)) - db.session.execute(delete(Tenant)) - db.session.commit() + session.execute(delete(DifySetup)) + session.execute(delete(TenantAccountJoin)) + session.execute(delete(Account)) + session.execute(delete(Tenant)) + session.commit() logger.exception("Setup account failed, email: %s, name: %s", email, name) raise ValueError(f"Setup failed: {e}") @@ -1891,9 +1938,11 @@ class RegisterService: is_setup: bool | None = False, create_workspace_required: bool | None = True, timezone: str | None = None, + *, + session: scoped_session | Session, ) -> Account: """Register account""" - db.session.begin_nested() + session.begin_nested() try: interface_language = get_valid_language(language) account = AccountService.create_account( @@ -1903,12 +1952,13 @@ class RegisterService: password=password, is_setup=is_setup, timezone=timezone, + session=session, ) account.status = status or AccountStatus.ACTIVE account.initialized_at = naive_utc_now() if open_id is not None and provider is not None: - AccountService.link_account_integrate(provider, open_id, account) + AccountService.link_account_integrate(provider, open_id, account, session=session) if ( FeatureService.get_system_features().is_allow_create_workspace @@ -1916,27 +1966,27 @@ class RegisterService: and FeatureService.get_system_features().license.workspaces.is_available() ): try: - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, db.session, role="owner") + tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=session) + TenantService.create_tenant_member(tenant, account, session, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) except Exception: _try_join_enterprise_default_workspace(str(account.id)) raise - db.session.commit() + session.commit() _try_join_enterprise_default_workspace(str(account.id)) except WorkSpaceNotAllowedCreateError: - db.session.rollback() + session.rollback() logger.exception("Register failed") raise AccountRegisterError("Workspace is not allowed to create.") except AccountRegisterError as are: - db.session.rollback() + session.rollback() logger.exception("Register failed") raise are except Exception as e: - db.session.rollback() + session.rollback() logger.exception("Register failed") raise AccountRegisterError(f"Registration failed: {e}") from e @@ -1944,7 +1994,14 @@ class RegisterService: @classmethod def invite_new_member( - cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None + cls, + tenant: Tenant, + email: str, + language: str | None, + role: str = "normal", + inviter: Account | None = None, + *, + session: scoped_session | Session, ) -> str: if not inviter: raise ValueError("Inviter is required") @@ -1958,11 +2015,11 @@ class RegisterService: check_workspace_member_invite_permission(tenant.id) - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) requires_setup = False if not account: - TenantService.check_member_permission(tenant, inviter, None, "add") + TenantService.check_member_permission(tenant, inviter, None, "add", session=session) name = normalized_email.split("@")[0] account = cls.register( @@ -1971,13 +2028,14 @@ class RegisterService: language=language, status=AccountStatus.PENDING, is_setup=True, + session=session, ) - TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role) - TenantService.switch_tenant(account, tenant.id) + TenantService.create_tenant_member(tenant, account, session, tenant_join_role) + TenantService.switch_tenant(account, tenant.id, session=session) requires_setup = True else: - TenantService.check_member_permission(tenant, inviter, account, "add") - ta = db.session.scalar( + TenantService.check_member_permission(tenant, inviter, account, "add", session=session) + ta = session.scalar( select(TenantAccountJoin) .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) .limit(1) @@ -1985,7 +2043,7 @@ class RegisterService: requires_setup = account.status == AccountStatus.PENDING if not ta and (account.status == AccountStatus.PENDING or dify_config.RBAC_ENABLED): - TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role) + TenantService.create_tenant_member(tenant, account, session, tenant_join_role) # Support resend invitation email when the account is pending status if account.status != AccountStatus.PENDING: @@ -2054,20 +2112,20 @@ class RegisterService: @classmethod def get_invitation_if_token_valid( - cls, workspace_id: str | None, email: str | None, token: str + cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session ) -> InvitationDetailDict | None: invitation_data = cls.get_invitation_by_token(token, workspace_id, email) if not invitation_data: return None - tenant = db.session.scalar( + tenant = session.scalar( select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1) ) if not tenant: return None - account = db.session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1)) + account = session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1)) if not account: return None @@ -2107,13 +2165,13 @@ class RegisterService: @classmethod def get_invitation_with_case_fallback( - cls, workspace_id: str | None, email: str | None, token: str + cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session ) -> InvitationDetailDict | None: - invitation = cls.get_invitation_if_token_valid(workspace_id, email, token) + invitation = cls.get_invitation_if_token_valid(workspace_id, email, token, session=session) if invitation or not email or email == email.lower(): return invitation normalized_email = email.lower() - return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token) + return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token, session=session) def _generate_refresh_token(length: int = 64): diff --git a/api/services/agent/composer_service.py b/api/services/agent/composer_service.py index 0a17c06300f..8c83ee80031 100644 --- a/api/services/agent/composer_service.py +++ b/api/services/agent/composer_service.py @@ -4,15 +4,18 @@ from typing import Any from sqlalchemy import func, or_, select from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql.elements import ColumnElement from extensions.ext_database import db from libs.helper import to_timestamp +from models import Account from models.agent import ( Agent, AgentConfigRevision, AgentConfigRevisionOperation, AgentConfigSnapshot, AgentDriveFile, + AgentIconType, AgentKind, AgentScope, AgentSource, @@ -20,9 +23,7 @@ from models.agent import ( WorkflowAgentBindingType, WorkflowAgentNodeBinding, ) -from models.agent_config_entities import ( - DeclaredOutputConfig, -) +from models.agent_config_entities import DeclaredOutputConfig from models.agent_config_entities import ( effective_declared_outputs as _effective_declared_outputs, ) @@ -32,8 +33,12 @@ from services.agent.composer_validator import ComposerConfigValidator from services.agent.errors import ( AgentNameConflictError, AgentNotFoundError, + AgentVersionConflictError, AgentVersionNotFoundError, + InvalidComposerConfigError, ) +from services.agent.roster_service import AgentRosterService +from services.app_service import AppService, CreateAppParams from services.entities.agent_entities import ( AgentSoulConfig, ComposerCandidatesResponse, @@ -172,6 +177,86 @@ class AgentComposerService: ) return state + @classmethod + def copy_workflow_composer_from_roster( + cls, + *, + tenant_id: str, + app_id: str, + node_id: str, + account_id: str, + source_agent_id: str, + source_snapshot_id: str | None = None, + idempotency_key: str | None = None, + ) -> dict[str, Any]: + workflow = cls._get_draft_workflow(tenant_id=tenant_id, app_id=app_id) + binding = cls._require_binding( + cls._get_workflow_binding(tenant_id=tenant_id, workflow_id=workflow.id, node_id=node_id) + ) + + if binding.binding_type == WorkflowAgentBindingType.INLINE_AGENT and idempotency_key: + agent = cls._get_agent_if_present(tenant_id=tenant_id, agent_id=binding.agent_id) + version = cls._get_version_if_present( + tenant_id=tenant_id, + agent_id=agent.id if agent else None, + version_id=binding.current_snapshot_id, + ) + return cls._serialize_workflow_state(binding=binding, agent=agent, version=version) + + if binding.binding_type != WorkflowAgentBindingType.ROSTER_AGENT: + raise InvalidComposerConfigError("Workflow agent node must be bound to a roster agent.") + if binding.agent_id != source_agent_id: + raise InvalidComposerConfigError("Source agent does not match the current workflow node binding.") + + source_agent = cls._require_agent(tenant_id=tenant_id, agent_id=source_agent_id) + if source_agent.scope != AgentScope.ROSTER or source_agent.status != AgentStatus.ACTIVE: + raise InvalidComposerConfigError("Source agent must be an active roster agent.") + source_version = cls._require_version( + tenant_id=tenant_id, + agent_id=source_agent.id, + version_id=source_agent.active_config_snapshot_id, + ) + if source_snapshot_id and source_snapshot_id != source_version.id: + raise AgentVersionConflictError() + + agent_soul = AgentSoulConfig.model_validate(source_version.config_snapshot_dict) + inline_agent = cls._create_workflow_only_agent( + tenant_id=tenant_id, + app_id=app_id, + workflow_id=workflow.id, + node_id=node_id, + account_id=account_id, + agent_soul=agent_soul, + name=source_agent.name, + description=source_agent.description, + role=source_agent.role, + icon_type=source_agent.icon_type, + icon=source_agent.icon, + icon_background=source_agent.icon_background, + ) + cls._copy_agent_drive_rows( + tenant_id=tenant_id, + source_agent_id=source_agent.id, + target_agent_id=inline_agent.id, + account_id=account_id, + agent_soul=agent_soul, + node_job=WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict), + ) + + binding.binding_type = WorkflowAgentBindingType.INLINE_AGENT + binding.agent_id = inline_agent.id + binding.current_snapshot_id = inline_agent.active_config_snapshot_id + binding.updated_by = account_id + db.session.flush() + db.session.commit() + + version = cls._require_version( + tenant_id=tenant_id, + agent_id=inline_agent.id, + version_id=inline_agent.active_config_snapshot_id, + ) + return cls._serialize_workflow_state(binding=binding, agent=inline_agent, version=version) + @classmethod def load_agent_app_composer(cls, *, tenant_id: str, app_id: str) -> dict[str, Any]: agent = db.session.scalar( @@ -849,6 +934,11 @@ class AgentComposerService: tenant_id=tenant_id, account_id=account_id, name=agent_name, + description=payload.description or "", + role=payload.role or "", + icon_type=payload.icon_type, + icon=payload.icon, + icon_background=payload.icon_background, agent_soul=payload.agent_soul, operation=AgentConfigRevisionOperation.SAVE_NEW_AGENT, version_note=payload.version_note, @@ -894,10 +984,25 @@ class AgentComposerService: tenant_id=tenant_id, account_id=account_id, name=agent_name, + description=payload.description if payload.description is not None else source_agent.description, + role=payload.role if payload.role is not None else source_agent.role, + icon_type=payload.icon_type if payload.icon_type is not None else source_agent.icon_type, + icon=payload.icon if payload.icon is not None else source_agent.icon, + icon_background=payload.icon_background + if payload.icon_background is not None + else source_agent.icon_background, agent_soul=agent_soul, operation=AgentConfigRevisionOperation.SAVE_TO_ROSTER, version_note=payload.version_note, ) + cls._copy_agent_drive_rows( + tenant_id=tenant_id, + source_agent_id=source_agent.id, + target_agent_id=roster_agent.id, + account_id=account_id, + agent_soul=agent_soul, + node_job=payload.node_job or WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict), + ) binding.binding_type = WorkflowAgentBindingType.ROSTER_AGENT binding.agent_id = roster_agent.id binding.current_snapshot_id = roster_agent.active_config_snapshot_id @@ -916,11 +1021,21 @@ class AgentComposerService: node_id: str, account_id: str, agent_soul: AgentSoulConfig, + name: str | None = None, + description: str = "", + role: str = "", + icon_type: Any | None = None, + icon: str | None = None, + icon_background: str | None = None, ) -> Agent: agent = Agent( tenant_id=tenant_id, - name=f"Workflow Agent {node_id}", - description="", + name=name or f"Workflow Agent {node_id}", + description=description, + role=role, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, agent_kind=AgentKind.DIFY_AGENT, scope=AgentScope.WORKFLOW_ONLY, source=AgentSource.WORKFLOW, @@ -945,6 +1060,98 @@ class AgentComposerService: agent.active_config_has_model = agent_soul_has_model(agent_soul) return agent + @classmethod + def _copy_agent_drive_rows( + cls, + *, + tenant_id: str, + source_agent_id: str, + target_agent_id: str, + account_id: str, + agent_soul: AgentSoulConfig, + node_job: WorkflowNodeJobConfig | None = None, + ) -> None: + exact_keys, prefixes = cls._drive_copy_scopes_from_agent_configs(agent_soul=agent_soul, node_job=node_job) + predicates: list[ColumnElement[bool]] = [] + if exact_keys: + predicates.append(AgentDriveFile.key.in_(sorted(exact_keys))) + predicates.extend(AgentDriveFile.key.startswith(prefix) for prefix in sorted(prefixes)) + if not predicates: + return + + source_rows = list( + db.session.scalars( + select(AgentDriveFile).where( + AgentDriveFile.tenant_id == tenant_id, + AgentDriveFile.agent_id == source_agent_id, + or_(*predicates), + ) + ).all() + ) + if not source_rows: + return + + existing_target_keys = set( + db.session.scalars( + select(AgentDriveFile.key).where( + AgentDriveFile.tenant_id == tenant_id, + AgentDriveFile.agent_id == target_agent_id, + AgentDriveFile.key.in_([row.key for row in source_rows]), + ) + ).all() + ) + for row in source_rows: + if row.key in existing_target_keys: + continue + db.session.add( + AgentDriveFile( + tenant_id=tenant_id, + agent_id=target_agent_id, + key=row.key, + file_kind=row.file_kind, + file_id=row.file_id, + value_owned_by_drive=row.value_owned_by_drive, + is_skill=row.is_skill, + skill_metadata=row.skill_metadata, + size=row.size, + hash=row.hash, + mime_type=row.mime_type, + created_by=account_id, + ) + ) + + @staticmethod + def _drive_copy_scopes_from_agent_configs( + *, agent_soul: AgentSoulConfig, node_job: WorkflowNodeJobConfig | None = None + ) -> tuple[set[str], set[str]]: + from services.agent.prompt_mentions import MentionKind, parse_prompt_mentions + from services.agent_drive_service import decode_drive_mention_ref + + exact_keys: set[str] = set() + prefixes: set[str] = set() + + for mention in parse_prompt_mentions(agent_soul.prompt.system_prompt): + if mention.kind not in {MentionKind.SKILL, MentionKind.FILE}: + continue + drive_key = decode_drive_mention_ref(mention.ref_id) + if not drive_key: + continue + if mention.kind == MentionKind.SKILL and "/" in drive_key: + prefixes.add(f"{drive_key.rsplit('/', 1)[0]}/") + else: + exact_keys.add(drive_key) + + if node_job is not None: + for file_ref in node_job.metadata.file_refs or []: + if file_ref.drive_key: + exact_keys.add(file_ref.drive_key) + for output in node_job.declared_outputs: + benchmark_ref = output.check.benchmark_file_ref if output.check and output.check.enabled else None + if benchmark_ref and benchmark_ref.drive_key: + exact_keys.add(benchmark_ref.drive_key) + + return exact_keys, prefixes + @classmethod def _create_roster_agent_for_composer( cls, @@ -955,27 +1162,42 @@ class AgentComposerService: agent_soul: AgentSoulConfig, operation: AgentConfigRevisionOperation, version_note: str | None, + description: str = "", + role: str = "", + icon_type: AgentIconType | None = None, + icon: str | None = None, + icon_background: str | None = None, ) -> Agent: - agent = Agent( - tenant_id=tenant_id, - name=name, - description="", - agent_kind=AgentKind.DIFY_AGENT, - scope=AgentScope.ROSTER, - source=AgentSource.WORKFLOW, - status=AgentStatus.ACTIVE, - created_by=account_id, - updated_by=account_id, - ) - db.session.add(agent) + account = cls._require_account(account_id=account_id) try: - db.session.flush() + app = AppService().create_app( + tenant_id, + CreateAppParams( + name=name, + description=description, + mode="agent", + agent_role=role, + icon_type=icon_type.value if isinstance(icon_type, AgentIconType) else icon_type, + icon=icon, + icon_background=icon_background, + ), + account, + ) except IntegrityError as exc: db.session.rollback() raise AgentNameConflictError() from exc - version = cls._create_config_version( + + agent = AgentRosterService(db.session).get_app_backing_agent(tenant_id=tenant_id, app_id=app.id) + if agent is None: + raise AgentNotFoundError() + + current_snapshot = cls._require_version( tenant_id=tenant_id, agent_id=agent.id, + version_id=agent.active_config_snapshot_id, + ) + version = cls._update_current_version( + current_snapshot=current_snapshot, account_id=account_id, agent_soul=agent_soul, operation=operation, @@ -983,6 +1205,7 @@ class AgentComposerService: ) agent.active_config_snapshot_id = version.id agent.active_config_has_model = agent_soul_has_model(agent_soul) + agent.updated_by = account_id return agent @classmethod @@ -1111,6 +1334,13 @@ class AgentComposerService: raise AgentNotFoundError() return agent + @classmethod + def _require_account(cls, *, account_id: str) -> Account: + account = db.session.get(Account, account_id) + if not account: + raise ValueError("Account not found") + return account + @classmethod def _get_agent_if_present(cls, *, tenant_id: str, agent_id: str | None) -> Agent | None: if not agent_id: diff --git a/api/services/agent/errors.py b/api/services/agent/errors.py index dcc8f69961f..6a1dc6fb628 100644 --- a/api/services/agent/errors.py +++ b/api/services/agent/errors.py @@ -17,6 +17,10 @@ class AgentArchivedError(Conflict): description = "Archived agent cannot be modified." +class AgentVersionConflictError(Conflict): + description = "Agent config version changed. Please reload and try again." + + class AgentSoulLockedError(BadRequest): description = "Agent Soul is locked for this workflow node." diff --git a/api/services/agent/roster_service.py b/api/services/agent/roster_service.py index 6a9d5818647..97d91b50770 100644 --- a/api/services/agent/roster_service.py +++ b/api/services/agent/roster_service.py @@ -837,6 +837,7 @@ class AgentRosterService: if agent.source == AgentSource.AGENT_APP: return { AgentConfigRevisionOperation.SAVE_NEW_VERSION, + AgentConfigRevisionOperation.SAVE_TO_ROSTER, AgentConfigRevisionOperation.RESTORE_VERSION, } return { diff --git a/api/services/agent_app_feature_service.py b/api/services/agent_app_feature_service.py index b8e98653c8e..5fd794bb10f 100644 --- a/api/services/agent_app_feature_service.py +++ b/api/services/agent_app_feature_service.py @@ -69,7 +69,12 @@ class AgentAppFeatureConfigService: @classmethod def update_features( - cls, *, app_model: App, account: Account, config: dict[str, Any], session: scoped_session + cls, + *, + app_model: App, + account: Account, + config: dict[str, Any], + session: scoped_session, ) -> AppModelConfig: """Persist the presentation features as a new app_model_config version. diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index e1762c686ff..4f69c4b44a9 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -4,6 +4,7 @@ from typing import TypedDict import pandas as pd from sqlalchemy import delete, or_, select, update +from sqlalchemy.orm import scoped_session from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -300,17 +301,19 @@ class AppAnnotationService: return annotation @classmethod - def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str): + def update_app_annotation_directly( + cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str, session: scoped_session + ): # get app info _, current_tenant_id = current_account_with_tenant() - app = db.session.scalar( + app = session.scalar( select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.get(MessageAnnotation, annotation_id) + annotation = session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") @@ -326,9 +329,9 @@ class AppAnnotationService: annotation.content = answer annotation.question = question - db.session.commit() + session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = db.session.scalar( + app_annotation_setting = session.scalar( select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) @@ -344,33 +347,33 @@ class AppAnnotationService: return annotation @classmethod - def delete_app_annotation(cls, app_id: str, annotation_id: str): + def delete_app_annotation(cls, app_id: str, annotation_id: str, session: scoped_session): # get app info _, current_tenant_id = current_account_with_tenant() - app = db.session.scalar( + app = session.scalar( select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) ) if not app: raise NotFound("App not found") - annotation = db.session.get(MessageAnnotation, annotation_id) + annotation = session.get(MessageAnnotation, annotation_id) if not annotation: raise NotFound("Annotation not found") - db.session.delete(annotation) + session.delete(annotation) - annotation_hit_histories = db.session.scalars( + annotation_hit_histories = session.scalars( select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id) ).all() if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: - db.session.delete(annotation_hit_history) + session.delete(annotation_hit_history) - db.session.commit() + session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = db.session.scalar( + app_annotation_setting = session.scalar( select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a9024eb3bdd..14c5c0111e5 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,11 +5,11 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context +from sqlalchemy.orm import Session, scoped_session from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from extensions.ext_database import db from graphon.model_runtime.entities.model_entities import ModelType from models.enums import MessageStatus from models.model import App, AppMode, Message @@ -77,6 +77,8 @@ class AudioService: def transcript_tts( cls, app_model: App, + *, + session: Session | scoped_session, text: str | None = None, voice: str | None = None, end_user: str | None = None, @@ -87,7 +89,7 @@ class AudioService: if voice is None: if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: if is_draft: - workflow = WorkflowService().get_draft_workflow(app_model=app_model) + workflow = WorkflowService().get_draft_workflow(app_model=app_model, session=session) else: workflow = app_model.workflow if ( @@ -132,7 +134,7 @@ class AudioService: uuid.UUID(message_id) except ValueError: return None - message = db.session.get(Message, message_id) + message = session.get(Message, message_id) if message is None: return None if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}: diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 36b15170567..42f1d4d8d40 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -2,17 +2,17 @@ import json from typing import Any from sqlalchemy import select +from sqlalchemy.orm import Session from core.helper import encrypter -from extensions.ext_database import db from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: @staticmethod - def get_provider_auth_list(tenant_id: str): - data_source_api_key_bindings = db.session.scalars( + def get_provider_auth_list(session: Session, tenant_id: str): + data_source_api_key_bindings = session.scalars( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False) ) @@ -20,7 +20,7 @@ class ApiKeyAuthService: return data_source_api_key_bindings @staticmethod - def create_provider_auth(tenant_id: str, args: dict[str, Any]): + def create_provider_auth(session: Session, tenant_id: str, args: dict[str, Any]): auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key @@ -31,12 +31,12 @@ class ApiKeyAuthService: tenant_id=tenant_id, category=args["category"], provider=args["provider"] ) data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) - db.session.add(data_source_api_key_binding) - db.session.commit() + session.add(data_source_api_key_binding) + session.commit() @staticmethod - def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = db.session.scalar( + def get_auth_credentials(session: Session, tenant_id: str, category: str, provider: str): + data_source_api_key_bindings = session.scalar( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.category == category, @@ -52,16 +52,16 @@ class ApiKeyAuthService: return credentials @staticmethod - def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = db.session.scalar( + def delete_provider_auth(session: Session, tenant_id: str, binding_id: str): + data_source_api_key_binding = session.scalar( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id, ) ) if data_source_api_key_binding: - db.session.delete(data_source_api_key_binding) - db.session.commit() + session.delete(data_source_api_key_binding) + session.commit() @classmethod def validate_api_key_auth_args(cls, args): diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 4e441b53af0..96c362b3dfc 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -185,8 +185,8 @@ class EnterpriseRequest(BaseRequest): if ( not cls.rbac_base_url.startswith("http") - or not cls.rbac_base_url.startswith("https") - or not cls.rbac_base_url + and not cls.rbac_base_url.startswith("https") + and not cls.rbac_base_url ): raise ValueError("ENTERPRISE_RBAC_API_URL is required when RBAC_ENABLED=true") diff --git a/api/services/enterprise/rbac_service.py b/api/services/enterprise/rbac_service.py index b5585932b29..1072c7051b0 100644 --- a/api/services/enterprise/rbac_service.py +++ b/api/services/enterprise/rbac_service.py @@ -379,6 +379,9 @@ _LEGACY_WORKSPACE_EDITOR_KEYS: list[str] = [ "snippets.create_and_modify", "tool.manage", "snippets.create_and_modify", + "billing.view", + "billing.subscription.manage", + "billing.manage", ] _LEGACY_WORKSPACE_NORMAL_KEYS: list[str] = [ @@ -386,6 +389,9 @@ _LEGACY_WORKSPACE_NORMAL_KEYS: list[str] = [ "plugin.install", "credential.use", "app_library.access", + "billing.view", + "billing.subscription.manage", + "billing.manage", ] _LEGACY_WORKSPACE_DATASET_OPERATOR_KEYS: list[str] = [ @@ -834,6 +840,7 @@ class RBACService: options: ListOption | None = None, ) -> Paginated[RBACRole]: params = (options or ListOption()).to_params({"include_owner": include_owner}) + params["dataset_operator_enabled"] = dify_config.DATASET_OPERATOR_ENABLED data = _inner_call( "GET", f"{_INNER_PREFIX}/roles", @@ -1678,6 +1685,17 @@ class RBACService: ) return MemberRolesResponse.model_validate(data or {}) + @staticmethod + def delete_rbac_bindings(tenant_id: str, account_id: str): + data = _inner_call( + "DELETE", + f"{_INNER_PREFIX}/members/rbac-bindings", + tenant_id=tenant_id, + account_id=account_id, + params={"account_id": account_id}, + ) + return data + class CheckAccess: """Call the ``/inner/api/rbac/check-access`` endpoint.""" diff --git a/api/services/entities/agent_entities.py b/api/services/entities/agent_entities.py index e7b5cbd7c6d..a8634bceb09 100644 --- a/api/services/entities/agent_entities.py +++ b/api/services/entities/agent_entities.py @@ -42,6 +42,11 @@ class ComposerSavePayload(BaseModel): idempotency_key: str | None = None client_revision_id: str | None = None new_agent_name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = None + role: str | None = Field(default=None, max_length=255) + icon_type: AgentIconType | None = None + icon: str | None = Field(default=None, max_length=255) + icon_background: str | None = Field(default=None, max_length=255) @model_validator(mode="after") def validate_variant_sections(self) -> "ComposerSavePayload": @@ -58,6 +63,12 @@ class ComposerSavePayload(BaseModel): return self +class WorkflowComposerCopyFromRosterPayload(BaseModel): + source_agent_id: str = Field(min_length=1, max_length=255) + source_snapshot_id: str | None = Field(default=None, max_length=255) + idempotency_key: str | None = Field(default=None, max_length=255) + + class RosterAgentCreatePayload(BaseModel): name: str = Field(min_length=1, max_length=255) mode: Literal["agent"] = "agent" diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 2ae0c63ff75..c9d86ee4578 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -181,6 +181,7 @@ class SystemFeatureModel(FeatureResponseModel): enable_creators_platform: bool = False enable_trial_app: bool = False enable_explore_banner: bool = False + enable_learn_app: bool = True rbac_enabled: bool = False @@ -282,6 +283,7 @@ class FeatureService: system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER + system_features.enable_learn_app = dify_config.ENABLE_LEARN_APP @classmethod def _fulfill_trial_models_from_env(cls) -> list[str]: diff --git a/api/services/feedback_service.py b/api/services/feedback_service.py index 24cfb8aa852..62885c901b7 100644 --- a/api/services/feedback_service.py +++ b/api/services/feedback_service.py @@ -14,8 +14,9 @@ from models.model import Account, App, Conversation, Message, MessageFeedback class FeedbackService: @staticmethod def export_feedbacks( - session: Session, app_id: str, + *, + session: Session, from_source: str | None = None, rating: str | None = None, has_comment: bool | None = None, @@ -28,6 +29,7 @@ class FeedbackService: Args: app_id: Application ID + session: Database session used to run the export query from_source: Filter by feedback source ('user' or 'admin') rating: Filter by rating ('like' or 'dislike') has_comment: Only include feedback with comments diff --git a/api/services/knowledge_retrieval_inner_service.py b/api/services/knowledge_retrieval_inner_service.py index fccc81c4a29..8759413f533 100644 --- a/api/services/knowledge_retrieval_inner_service.py +++ b/api/services/knowledge_retrieval_inner_service.py @@ -13,11 +13,11 @@ of a separate validation error. """ from sqlalchemy import select +from sqlalchemy.orm import scoped_session from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from extensions.ext_database import db from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.nodes.llm.entities import ModelConfig from models.dataset import Dataset @@ -38,7 +38,11 @@ from services.errors.knowledge_retrieval import ( class InnerKnowledgeRetrievalService: """Validate inner caller scope and delegate to workflow dataset retrieval.""" - def retrieve(self, request: InnerKnowledgeRetrieveRequest) -> InnerKnowledgeRetrieveResponse: + def retrieve( + self, + request: InnerKnowledgeRetrieveRequest, + session: scoped_session, + ) -> InnerKnowledgeRetrieveResponse: """Run tenant-scoped retrieval for a trusted internal caller. This method only rejects caller app existence/tenant mismatches and @@ -56,8 +60,8 @@ class InnerKnowledgeRetrievalService: InnerKnowledgeRetrieveDatasetTenantMismatchError: At least one requested dataset is outside the caller tenant. """ - self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id) - self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids) + self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id, session=session) + self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids, session=session) rag = DatasetRetrieval() results = rag.knowledge_retrieval(request=self._to_rag_request(request)) @@ -66,8 +70,8 @@ class InnerKnowledgeRetrievalService: usage=InnerKnowledgeRetrieveUsage.model_validate(jsonable_encoder(rag.llm_usage)), ) - def _validate_caller_app(self, *, tenant_id: str, app_id: str) -> None: - app = db.session.scalar(select(App).where(App.id == app_id).limit(1)) + def _validate_caller_app(self, *, tenant_id: str, app_id: str, session: scoped_session) -> None: + app = session.scalar(select(App).where(App.id == app_id).limit(1)) if app is None: raise InnerKnowledgeRetrieveAppNotFoundError(f"App '{app_id}' not found") if app.tenant_id != tenant_id: @@ -75,8 +79,8 @@ class InnerKnowledgeRetrievalService: f"App '{app_id}' does not belong to tenant '{tenant_id}'" ) - def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str]) -> None: - datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() + def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str], session: scoped_session) -> None: + datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() found_ids = {dataset.id for dataset in datasets} missing_ids = sorted(set(dataset_ids) - found_ids) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index f9dcfd25c7f..d9cd65b2b39 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -2,9 +2,9 @@ import copy import logging from sqlalchemy import delete, func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from libs.login import resolve_account_fallback @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) class MetadataService: @staticmethod def create_metadata( + session: Session, dataset_id: str, metadata_args: MetadataArgs, current_user: Account | None = None, # TODO: the service_api is not migrated yet @@ -33,7 +34,7 @@ class MetadataService: raise ValueError("Metadata name cannot exceed 255 characters.") current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id) # check if metadata name already exists - if db.session.scalar( + if session.scalar( select(DatasetMetadata) .where( DatasetMetadata.tenant_id == current_tenant_id, @@ -53,12 +54,13 @@ class MetadataService: name=metadata_args.name, created_by=current_user.id, ) - db.session.add(metadata) - db.session.commit() + session.add(metadata) + session.commit() return metadata @staticmethod def update_metadata_name( + session: Session, dataset_id: str, metadata_id: str, name: str, @@ -72,7 +74,7 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id) - if db.session.scalar( + if session.scalar( select(DatasetMetadata) .where( DatasetMetadata.tenant_id == current_tenant_id, @@ -87,7 +89,7 @@ class MetadataService: raise ValueError("Metadata name already exists in Built-in fields.") try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = db.session.scalar( + metadata = session.scalar( select(DatasetMetadata) .where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id) .limit(1) @@ -100,7 +102,7 @@ class MetadataService: metadata.updated_at = naive_utc_now() # update related documents - dataset_metadata_bindings = db.session.scalars( + dataset_metadata_bindings = session.scalars( select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id) ).all() if dataset_metadata_bindings: @@ -114,8 +116,8 @@ class MetadataService: value = doc_metadata.pop(old_name, None) doc_metadata[name] = value document.doc_metadata = doc_metadata - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() return metadata except Exception: logger.exception("Update metadata name failed") @@ -124,21 +126,21 @@ class MetadataService: redis_client.delete(lock_key) @staticmethod - def delete_metadata(dataset_id: str, metadata_id: str): + def delete_metadata(session: Session, dataset_id: str, metadata_id: str): lock_key = f"dataset_metadata_lock_{dataset_id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = db.session.scalar( + metadata = session.scalar( select(DatasetMetadata) .where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id) .limit(1) ) if metadata is None: raise ValueError("Metadata not found.") - db.session.delete(metadata) + session.delete(metadata) # deal related documents - dataset_metadata_bindings = db.session.scalars( + dataset_metadata_bindings = session.scalars( select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id) ).all() if dataset_metadata_bindings: @@ -151,8 +153,8 @@ class MetadataService: doc_metadata = copy.deepcopy(document.doc_metadata) doc_metadata.pop(metadata.name, None) document.doc_metadata = doc_metadata - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() return metadata except Exception: logger.exception("Delete metadata failed") @@ -170,13 +172,13 @@ class MetadataService: ] @staticmethod - def enable_built_in_field(dataset: Dataset): + def enable_built_in_field(session: Session, dataset: Dataset): if dataset.built_in_field_enabled: return lock_key = f"dataset_metadata_lock_{dataset.id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) - db.session.add(dataset) + session.add(dataset) documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) if documents: for document in documents: @@ -190,22 +192,22 @@ class MetadataService: doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata - db.session.add(document) + session.add(document) dataset.built_in_field_enabled = True - db.session.commit() + session.commit() except Exception: logger.exception("Enable built-in field failed") finally: redis_client.delete(lock_key) @staticmethod - def disable_built_in_field(dataset: Dataset): + def disable_built_in_field(session: Session, dataset: Dataset): if not dataset.built_in_field_enabled: return lock_key = f"dataset_metadata_lock_{dataset.id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) - db.session.add(dataset) + session.add(dataset) documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) document_ids = [] if documents: @@ -220,10 +222,10 @@ class MetadataService: doc_metadata.pop(BuiltInField.last_update_date, None) doc_metadata.pop(BuiltInField.source, None) document.doc_metadata = doc_metadata - db.session.add(document) + session.add(document) document_ids.append(document.id) dataset.built_in_field_enabled = False - db.session.commit() + session.commit() except Exception: logger.exception("Disable built-in field failed") finally: @@ -231,6 +233,7 @@ class MetadataService: @staticmethod def update_documents_metadata( + session: Session, dataset: Dataset, metadata_args: MetadataOperationData, current_user: Account | None = None, # TODO: the service_api is not migrated yet @@ -259,11 +262,11 @@ class MetadataService: doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp() doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type] document.doc_metadata = doc_metadata - db.session.add(document) + session.add(document) # deal metadata binding (in the same transaction as the doc_metadata update) if not operation.partial_update: - db.session.execute( + session.execute( delete(DatasetMetadataBinding).where( DatasetMetadataBinding.document_id == operation.document_id ) @@ -272,7 +275,7 @@ class MetadataService: for metadata_value in operation.metadata_list: # check if binding already exists if operation.partial_update: - existing_binding = db.session.scalar( + existing_binding = session.scalar( select(DatasetMetadataBinding) .where( DatasetMetadataBinding.document_id == operation.document_id, @@ -290,10 +293,10 @@ class MetadataService: metadata_id=metadata_value.id, created_by=current_user.id, ) - db.session.add(dataset_metadata_binding) - db.session.commit() + session.add(dataset_metadata_binding) + session.commit() except Exception: - db.session.rollback() + session.rollback() logger.exception("Update documents metadata failed") raise finally: @@ -313,14 +316,14 @@ class MetadataService: redis_client.set(lock_key, 1, ex=3600) @staticmethod - def get_dataset_metadatas(dataset: Dataset): + def get_dataset_metadatas(session: Session, dataset: Dataset): return { "doc_metadata": [ { "id": item.get("id"), "name": item.get("name"), "type": item.get("type"), - "count": db.session.scalar( + "count": session.scalar( select(func.count(DatasetMetadataBinding.id)).where( DatasetMetadataBinding.metadata_id == item.get("id"), DatasetMetadataBinding.dataset_id == dataset.id, diff --git a/api/services/oauth_server.py b/api/services/oauth_server.py index 22648070f01..5f3277c9525 100644 --- a/api/services/oauth_server.py +++ b/api/services/oauth_server.py @@ -91,4 +91,4 @@ class OAuthServerService: user_id_str = user_account_id.decode("utf-8") - return AccountService.load_user(user_id_str) + return AccountService.load_user(user_id_str, db.session) diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py index e48286303cb..03b72a4f57c 100644 --- a/api/services/recommend_app/buildin/buildin_retrieval.py +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -5,6 +5,7 @@ from typing import Any, override from flask import current_app +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase from services.recommend_app.recommend_app_type import RecommendAppType @@ -25,6 +26,11 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): result = self.fetch_recommended_apps_from_builtin(language) return result + @override + def get_learn_dify_apps(self, language: str): + result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language) + return result + @override def get_recommend_app_detail(self, app_id: str): result = self.fetch_recommended_app_detail_from_builtin(app_id) diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 9d6c28c2117..f6786175896 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -49,6 +49,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): result = self.fetch_recommended_apps_from_db(language) return result + @override + def get_learn_dify_apps(self, language: str) -> RecommendedAppsResultDict: + result = self.fetch_learn_dify_apps_from_db(language) + return result + @override def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None: result = self.fetch_recommended_app_detail_from_db(app_id) diff --git a/api/services/recommend_app/recommend_app_base.py b/api/services/recommend_app/recommend_app_base.py index 4214d56e4aa..f819cc3a937 100644 --- a/api/services/recommend_app/recommend_app_base.py +++ b/api/services/recommend_app/recommend_app_base.py @@ -6,6 +6,8 @@ class RecommendAppRetrievalBase(Protocol): def get_recommended_apps_and_categories(self, language: str) -> Any: ... + def get_learn_dify_apps(self, language: str) -> Any: ... + def get_recommend_app_detail(self, app_id: str) -> Any: ... def get_type(self) -> str: ... diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py index 890fb132faa..2e3222bb978 100644 --- a/api/services/recommend_app/remote/remote_retrieval.py +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -2,15 +2,28 @@ import logging from typing import Any, override import httpx +from flask import has_request_context, request from configs import dify_config from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase from services.recommend_app.recommend_app_type import RecommendAppType logger = logging.getLogger(__name__) +def _current_origin_headers() -> dict[str, str]: + origin = request.headers.get("Origin") if has_request_context() else None + if origin: + return {"Origin": origin} + + console_web_url = getattr(dify_config, "CONSOLE_WEB_URL", "") + if not isinstance(console_web_url, str) or not console_web_url: + return {} + return {"Origin": console_web_url} + + class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): """ Retrieval recommended app from dify official. @@ -37,6 +50,15 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language) return result + @override + def get_learn_dify_apps(self, language: str): + try: + result = self.fetch_learn_dify_apps_from_dify_official(language) + except Exception as e: + logger.warning("fetch learn dify apps from dify official failed: %s, switch to database.", e) + result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language) + return result + @override def get_type(self) -> str: return RecommendAppType.REMOTE @@ -50,7 +72,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/apps/{app_id}" - response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) + response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: return None data: dict[str, Any] = response.json() @@ -65,9 +87,25 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/apps?language={language}" - response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) + response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") result: dict[str, Any] = response.json() return result + + @classmethod + def fetch_learn_dify_apps_from_dify_official(cls, language: str): + """ + Fetch Learn Dify apps from dify official. + :param language: language + :return: + """ + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/apps/learn-dify?language={language}" + response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0)) + if response.status_code != 200: + raise ValueError(f"fetch learn dify apps failed, status code: {response.status_code}") + + result: dict[str, Any] = response.json() + return result diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index bc8bb58acba..2d247ba5b71 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import scoped_session from configs import dify_config from models.model import AccountTrialAppRecord, TrialApp from services.feature_service import FeatureService -from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory @@ -38,11 +37,13 @@ class RecommendedAppService: @classmethod def get_learn_dify_apps(cls, session: scoped_session, language: str) -> dict[str, Any]: """ - Get database-backed recommended apps marked as Learn Dify. + Get recommended apps marked for the Learn Dify section. :param language: language :return: """ - result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language) + mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE + retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() + result = retrieval_instance.get_learn_dify_apps(language) if FeatureService.get_system_features().enable_trial_app: for app in result["recommended_apps"]: diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py index 2b63d9171e9..6ecc8eb8bc9 100644 --- a/api/services/webapp_auth_service.py +++ b/api/services/webapp_auth_service.py @@ -35,7 +35,7 @@ class WebAppAuthService: @staticmethod def authenticate(email: str, password: str) -> Account: """authenticate account with email and password""" - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: raise AccountNotFoundError() @@ -55,7 +55,7 @@ class WebAppAuthService: @classmethod def get_user_through_email(cls, email: str): - account = AccountService.get_account_by_email_with_case_fallback(email) + account = AccountService.get_account_by_email_with_case_fallback(db.session, email) if not account: return None diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 9f8e4b83093..262ccc18f83 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast from sqlalchemy import exists, select -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session, scoped_session, sessionmaker from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -142,7 +142,7 @@ class WorkflowService: return db.session.execute(stmt).scalar_one() def get_draft_workflow( - self, app_model: App, workflow_id: str | None = None, session: Session | None = None + self, app_model: App, workflow_id: str | None = None, session: Session | scoped_session | None = None ) -> Workflow | None: """ Get draft workflow @@ -169,7 +169,7 @@ class WorkflowService: return workflow def get_published_workflow_by_id( - self, app_model: App, workflow_id: str, session: Session | None = None + self, app_model: App, workflow_id: str, session: Session | scoped_session | None = None ) -> Workflow | None: """ fetch published workflow by workflow_id diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 70114a83f0b..180c077b88a 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -36,7 +36,9 @@ class WorkspaceService: feature = FeatureService.get_features(tenant.id, exclude_vector_space=True) can_replace_logo = feature.can_replace_logo - if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]): + if can_replace_logo and TenantService.has_roles( + tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], session=db.session + ): base_url = dify_config.FILES_URL replace_webapp_logo = ( f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py index 70988eb0a13..ea875e63fe8 100644 --- a/api/tests/integration_tests/conftest.py +++ b/api/tests/integration_tests/conftest.py @@ -84,6 +84,7 @@ def setup_account(request) -> Generator[Account, None, None]: password=secrets.token_hex(16), ip_address="localhost", language="en-US", + session=db.session, ) with _CACHED_APP.test_request_context(): diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py index 5eb9f71e695..e55b46d38bf 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -1,7 +1,7 @@ """Controller integration tests for API key data source auth routes.""" import json -from unittest.mock import patch +from unittest.mock import ANY, patch from flask.testing import FlaskClient from sqlalchemy import select @@ -85,7 +85,7 @@ def test_create_binding_successful( assert response.status_code == 200 assert response.get_json() == {"result": "success"} - create_auth.assert_called_once_with(tenant_id, payload) + create_auth.assert_called_once_with(ANY, tenant_id, payload) def test_create_binding_failure( diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index bb7921a5f45..109332e16c9 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -270,10 +270,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 014c1588fee..812aa299c1b 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -165,10 +165,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index d043c0d413a..464e0134a2f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -494,10 +494,7 @@ class TestAccountGeneration: second_result.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first_result, second_result] - with patch("services.account_service.session_factory") as mock_factory: - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 @@ -551,6 +548,7 @@ class TestAccountGeneration: provider="github", language="en-US", timezone=None, + session=ANY, ) else: mock_register_service.register.assert_not_called() @@ -584,6 +582,7 @@ class TestAccountGeneration: provider="github", language="en-US", timezone=None, + session=ANY, ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) @@ -615,6 +614,7 @@ class TestAccountGeneration: provider="github", language="zh-Hans", timezone="Asia/Shanghai", + session=ANY, ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None) @@ -646,6 +646,7 @@ class TestAccountGeneration: provider="github", language="zh-Hans", timezone=None, + session=ANY, ) @patch("controllers.console.auth.oauth._get_account_by_openid_or_email") @@ -676,7 +677,7 @@ class TestAccountGeneration: assert result == mock_account assert oauth_new_user is False - mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace") + mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace", session=ANY) mock_tenant_service.create_tenant_member.assert_called_once_with( mock_new_tenant, mock_account, ANY, role="owner" ) diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py b/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py index 5fe0f787524..00f605b74c5 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/conftest.py @@ -1,7 +1,7 @@ from __future__ import annotations import uuid -from collections.abc import Callable, Iterator +from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Literal from unittest.mock import patch @@ -12,7 +12,6 @@ from flask import Flask from sqlalchemy.orm import Session from controllers.openapi.auth.data import AuthData -from extensions.ext_database import db from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx from models import Account, Tenant from services.account_service import AccountService, TenantService @@ -46,20 +45,25 @@ def make_account(db_session_with_containers: Session) -> Callable[..., Account]: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) if with_owner_tenant: - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist( + account, name=fake.company(), session=db_session_with_containers + ) return account return _make -def add_tenant_for_account(account: Account, *, role: str = "normal", name: str = "Second WS") -> Tenant: +def add_tenant_for_account( + account: Account, *, session: Session, role: str = "normal", name: str = "Second WS" +) -> Tenant: """Create an additional tenant and join ``account`` to it (real service calls).""" with patch("services.account_service.FeatureService") as mock_feature_service: mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True - tenant = TenantService.create_tenant(name=name) - TenantService.create_tenant_member(tenant, account, db.session, role=role) + tenant = TenantService.create_tenant(name=name, session=session) + TenantService.create_tenant_member(tenant, account, session, role=role) return tenant @@ -93,7 +97,7 @@ def account_auth_context( *, token_id: uuid.UUID, client_id: str = "integration-cli", -) -> Iterator[AuthContext]: +) -> Generator[AuthContext]: """Publish an account ``AuthContext`` for handlers that read ``get_auth_ctx()``. The auth pipeline normally sets this ContextVar; the integration suite diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/test_account.py b/api/tests/test_containers_integration_tests/controllers/openapi/test_account.py index 77c812c0b34..7b5bef7b613 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/test_account.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/test_account.py @@ -4,6 +4,7 @@ from collections.abc import Callable from inspect import unwrap from flask import Flask +from sqlalchemy.orm import Session from controllers.openapi.account import AccountApi from models import Account @@ -34,11 +35,13 @@ class TestAccountInfo: # the only workspace the account belongs to. assert result.default_workspace_id == owner_tenant.id - def test_lists_all_joined_workspaces(self, app: Flask, make_account: Callable[..., Account]) -> None: + def test_lists_all_joined_workspaces( + self, app: Flask, db_session_with_containers: Session, make_account: Callable[..., Account] + ) -> None: account = make_account() owner_tenant = account.current_tenant assert owner_tenant is not None - second = add_tenant_for_account(account, role="normal", name="Second WS") + second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS") api = AccountApi() with app.test_request_context("/openapi/v1/account"): diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/test_app_dsl.py b/api/tests/test_containers_integration_tests/controllers/openapi/test_app_dsl.py index 12018c3c67c..93e8927cfef 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/test_app_dsl.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/test_app_dsl.py @@ -81,8 +81,9 @@ def _app_and_account(db_session: Session, *, mode: str = "chat") -> tuple[App, A name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session) tenant = account.current_tenant assert tenant is not None app_args = CreateAppParams( diff --git a/api/tests/test_containers_integration_tests/controllers/openapi/test_workspaces.py b/api/tests/test_containers_integration_tests/controllers/openapi/test_workspaces.py index aed8c415454..18075704325 100644 --- a/api/tests/test_containers_integration_tests/controllers/openapi/test_workspaces.py +++ b/api/tests/test_containers_integration_tests/controllers/openapi/test_workspaces.py @@ -41,7 +41,7 @@ class TestWorkspacesList: account = make_account() owner_tenant = account.current_tenant assert owner_tenant is not None - second = add_tenant_for_account(account, role="normal", name="Second WS") + second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS") api = WorkspacesApi() with app.test_request_context("/openapi/v1/workspaces"): @@ -90,7 +90,9 @@ class TestWorkspaceSwitch: account = make_account() owner_tenant = account.current_tenant assert owner_tenant is not None - target = add_tenant_for_account(account, role="normal", name="Switch Target") + target = add_tenant_for_account( + account, session=db_session_with_containers, role="normal", name="Switch Target" + ) api = WorkspaceSwitchApi() with app.test_request_context(f"/openapi/v1/workspaces/{target.id}/switch", method="POST"): diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 2c6a9902401..d568a1c0b04 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -4,7 +4,7 @@ from __future__ import annotations import base64 from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -57,7 +57,7 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com") + mock_get_account.assert_called_once_with(ANY, "User@Example.com") mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans") mock_extract_ip.assert_called_once() mock_rate_limit.assert_called_once_with("127.0.0.1") @@ -177,7 +177,7 @@ class TestForgotPasswordResetApi: response = ForgotPasswordResetApi().post() assert response == {"result": "success"} - mock_get_account.assert_called_once_with("User@Example.com") + mock_get_account.assert_called_once_with(ANY, "User@Example.com") mock_update_account.assert_called_once() mock_revoke_token.assert_called_once_with("token-123") diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 9da6b04a2c4..c0da09278e3 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -27,8 +27,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -88,8 +89,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -141,8 +143,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -194,8 +197,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -257,8 +261,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -291,8 +296,11 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + account1, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company()) tenant1 = account1.current_tenant account2 = AccountService.create_account( @@ -300,8 +308,11 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + account2, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company()) tenant2 = account2.current_tenant # Create dataset for tenant1 @@ -367,8 +378,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Don't create any datasets @@ -391,8 +403,9 @@ class TestGetAvailableDatasetsIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create multiple datasets @@ -452,8 +465,9 @@ class TestKnowledgeRetrievalIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( @@ -520,8 +534,9 @@ class TestKnowledgeRetrievalIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset but no documents @@ -568,8 +583,9 @@ class TestKnowledgeRetrievalIntegration: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant dataset = Dataset( diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py index c93e61b2bfb..e2f8c8fc703 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -51,7 +51,7 @@ class TestApiKeyAuthService: self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) assert len(result) >= 1 tenant_results = [r for r in result if r.tenant_id == tenant_id] @@ -61,7 +61,7 @@ class TestApiKeyAuthService: def test_get_provider_auth_list_empty( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id ): - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] @@ -74,7 +74,7 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] @@ -95,7 +95,7 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) mock_factory.assert_called_once() mock_auth_instance.validate_credentials.assert_called_once() @@ -118,7 +118,7 @@ class TestApiKeyAuthService: mock_auth_instance.validate_credentials.return_value = False mock_factory.return_value = mock_auth_instance - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) db_session_with_containers.expire_all() bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() @@ -142,7 +142,7 @@ class TestApiKeyAuthService: original_key = mock_args["credentials"]["config"]["api_key"] - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123" assert mock_args["credentials"]["config"]["api_key"] != original_key @@ -166,14 +166,14 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result == mock_credentials def test_get_auth_credentials_not_found( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, category, provider ): - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result is None @@ -190,7 +190,7 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result == special_credentials assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" @@ -204,7 +204,7 @@ class TestApiKeyAuthService: binding_id = binding.id db_session_with_containers.expire_all() - ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id) + ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, binding_id) db_session_with_containers.expire_all() remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() @@ -214,7 +214,7 @@ class TestApiKeyAuthService: self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id ): # Should not raise when binding not found - ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) + ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, str(uuid4())) def test_validate_api_key_auth_args_success(self, mock_args): ApiKeyAuthService.validate_api_key_auth_args(mock_args) @@ -288,16 +288,16 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.return_value = "encrypted_key" - with patch("services.auth.api_key_auth_service.db.session") as mock_session: - mock_session.commit.side_effect = Exception("Database error") - with pytest.raises(Exception, match="Database error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + mock_session = MagicMock() + mock_session.commit.side_effect = Exception("Database error") + with pytest.raises(Exception, match="Database error"): + ApiKeyAuthService.create_provider_auth(mock_session, tenant_id, mock_args) @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args): mock_factory.side_effect = Exception("Factory error") with pytest.raises(Exception, match="Factory error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args) @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") @@ -307,7 +307,7 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") with pytest.raises(Exception, match="Encryption error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args) def test_validate_api_key_auth_args_none_input(self): with pytest.raises(TypeError): diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index 1de9ce38a0b..9b86ab41f2b 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -13,6 +13,7 @@ import pytest from flask import Flask from sqlalchemy.orm import Session +from extensions.ext_database import db from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory from services.auth.api_key_auth_service import ApiKeyAuthService @@ -56,7 +57,7 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_fc_test_key_123" args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) mock_http.assert_called_once() call_args = mock_http.call_args @@ -100,15 +101,15 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_key" args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args1) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args1) args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_2, args2) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_2, args2) db_session_with_containers.expire_all() - result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1) - result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2) + result1 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_1) + result2 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_2) assert len(result1) == 1 assert result1[0].tenant_id == tenant_id_1 @@ -118,7 +119,9 @@ class TestAuthIntegration: def test_cross_tenant_access_prevention( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id_2, category ): - result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) + result = ApiKeyAuthService.get_auth_credentials( + db_session_with_containers, tenant_id_2, category, AuthType.FIRECRAWL + ) assert result is None @@ -160,7 +163,7 @@ class TestAuthIntegration: "provider": AuthType.FIRECRAWL, "credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}, } - ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args) + ApiKeyAuthService.create_provider_auth(db.session(), tenant_id_1, thread_args) results.append("success") except Exception as e: exceptions.append(e) @@ -213,7 +216,7 @@ class TestAuthIntegration: args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} with pytest.raises(httpx.RequestError): - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) db_session_with_containers.expire_all() bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() @@ -250,11 +253,13 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_key" args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL) + result = ApiKeyAuthService.get_auth_credentials( + db_session_with_containers, tenant_id_1, category, AuthType.FIRECRAWL + ) assert result is not None assert result["config"]["api_key"] == "encrypted_key" diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index a2f5370cb76..65a5b0a96bf 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -64,12 +64,13 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) assert account.email == email assert account.status == AccountStatus.ACTIVE # Login with correct password - logged_in = AccountService.authenticate(email, password) + logged_in = AccountService.authenticate(email, password, session=db_session_with_containers) assert logged_in.id == account.id def test_create_account_without_password( @@ -90,6 +91,7 @@ class TestAccountService: name=name, interface_language="en-US", password=None, + session=db_session_with_containers, ) assert account.email == email assert account.password is None @@ -115,6 +117,7 @@ class TestAccountService: name=name, interface_language="en-US", password="invalid_new_password", + session=db_session_with_containers, ) def test_create_account_registration_disabled( @@ -135,6 +138,7 @@ class TestAccountService: name=name, interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) def test_create_account_email_in_freeze( @@ -158,6 +162,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) dify_config.BILLING_ENABLED = False # Reset config for other tests @@ -172,7 +177,7 @@ class TestAccountService: email = fake.email() password = generate_valid_password(fake) with pytest.raises(AccountPasswordError): - AccountService.authenticate(email, password) + AccountService.authenticate(email, password, session=db_session_with_containers) def test_authenticate_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -192,6 +197,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Ban the account @@ -200,7 +206,7 @@ class TestAccountService: db_session_with_containers.commit() with pytest.raises(AccountLoginError): - AccountService.authenticate(email, password) + AccountService.authenticate(email, password, session=db_session_with_containers) def test_authenticate_wrong_password(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -221,10 +227,11 @@ class TestAccountService: name=name, interface_language="en-US", password=correct_password, + session=db_session_with_containers, ) with pytest.raises(AccountPasswordError): - AccountService.authenticate(email, wrong_password) + AccountService.authenticate(email, wrong_password, session=db_session_with_containers) def test_authenticate_with_invite_token( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -246,6 +253,7 @@ class TestAccountService: name=name, interface_language="en-US", password=None, + session=db_session_with_containers, ) # Authenticate with invite token to set password @@ -253,6 +261,7 @@ class TestAccountService: email, new_password, invite_token="valid_invite_token", + session=db_session_with_containers, ) assert authenticated_account.id == account.id @@ -279,13 +288,14 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) account.status = AccountStatus.PENDING db_session_with_containers.commit() # Authenticate should activate the account - authenticated_account = AccountService.authenticate(email, password) + authenticated_account = AccountService.authenticate(email, password, session=db_session_with_containers) assert authenticated_account.status == AccountStatus.ACTIVE assert authenticated_account.initialized_at is not None @@ -310,13 +320,16 @@ class TestAccountService: name=name, interface_language="en-US", password=old_password, + session=db_session_with_containers, ) # Update password - updated_account = AccountService.update_account_password(account, old_password, new_password) + updated_account = AccountService.update_account_password( + account, old_password, new_password, session=db_session_with_containers + ) # Verify new password works - authenticated_account = AccountService.authenticate(email, new_password) + authenticated_account = AccountService.authenticate(email, new_password, session=db_session_with_containers) assert authenticated_account.id == account.id def test_update_account_password_wrong_current_password( @@ -341,10 +354,13 @@ class TestAccountService: name=name, interface_language="en-US", password=old_password, + session=db_session_with_containers, ) with pytest.raises(CurrentPasswordIncorrectError): - AccountService.update_account_password(account, wrong_password, new_password) + AccountService.update_account_password( + account, wrong_password, new_password, session=db_session_with_containers + ) def test_update_account_password_invalid_new_password( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -366,11 +382,12 @@ class TestAccountService: name=name, interface_language="en-US", password=old_password, + session=db_session_with_containers, ) # Test with too short password (assuming minimum length validation) with pytest.raises(ValueError): # Password validation error - AccountService.update_account_password(account, old_password, "123") + AccountService.update_account_password(account, old_password, "123", session=db_session_with_containers) def test_create_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -394,6 +411,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) assert account.email == email @@ -427,6 +445,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) def test_create_account_and_tenant_workspace_limit_exceeded( @@ -455,6 +474,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) def test_link_account_integrate_new_provider( @@ -476,10 +496,13 @@ class TestAccountService: name=name, interface_language="en-US", password=None, + session=db_session_with_containers, ) # Link with new provider - AccountService.link_account_integrate("new-google", "google_open_id_123", account) + AccountService.link_account_integrate( + "new-google", "google_open_id_123", account, session=db_session_with_containers + ) # Verify integration was created from models import AccountIntegrate @@ -511,13 +534,18 @@ class TestAccountService: name=name, interface_language="en-US", password=None, + session=db_session_with_containers, ) # Link with provider first time - AccountService.link_account_integrate("exists-google", "google_open_id_123", account) + AccountService.link_account_integrate( + "exists-google", "google_open_id_123", account, session=db_session_with_containers + ) # Link with same provider but different open_id (should update) - AccountService.link_account_integrate("exists-google", "google_open_id_456", account) + AccountService.link_account_integrate( + "exists-google", "google_open_id_456", account, session=db_session_with_containers + ) # Verify integration was updated from models import AccountIntegrate @@ -547,10 +575,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Close account - AccountService.close_account(account) + AccountService.close_account(account, session=db_session_with_containers) # Verify account status changed @@ -576,10 +605,13 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Update account fields - updated_account = AccountService.update_account(account, name=updated_name, interface_theme="dark") + updated_account = AccountService.update_account( + account, name=updated_name, interface_theme="dark", session=db_session_with_containers + ) assert updated_account.name == updated_name assert updated_account.interface_theme == "dark" @@ -604,10 +636,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) with pytest.raises(AttributeError): - AccountService.update_account(account, invalid_field="value") + AccountService.update_account(account, invalid_field="value", session=db_session_with_containers) def test_update_login_info(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -628,10 +661,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Update login info - AccountService.update_login_info(account, ip_address=ip_address) + AccountService.update_login_info(account, db_session_with_containers, ip_address=ip_address) # Verify login info was updated @@ -659,10 +693,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Login - token_pair = AccountService.login(account, ip_address=ip_address) + token_pair = AccountService.login(account, ip_address=ip_address, session=db_session_with_containers) assert isinstance(token_pair, TokenPair) assert token_pair.access_token == "mock_access_token" @@ -697,13 +732,14 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) account.status = AccountStatus.PENDING db_session_with_containers.commit() # Login should activate the account - token_pair = AccountService.login(account) + token_pair = AccountService.login(account, session=db_session_with_containers) db_session_with_containers.refresh(account) assert account.status == AccountStatus.ACTIVE @@ -727,10 +763,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Login first to get refresh token - token_pair = AccountService.login(account) + token_pair = AccountService.login(account, session=db_session_with_containers) # Logout AccountService.logout(account=account) @@ -761,15 +798,20 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create associated Tenant - TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + TenantService.create_owner_tenant_if_not_exist( + account=account, name=tenant_name, is_setup=True, session=db_session_with_containers + ) # Login to get initial tokens - initial_token_pair = AccountService.login(account) + initial_token_pair = AccountService.login(account, session=db_session_with_containers) # Refresh token - new_token_pair = AccountService.refresh_token(initial_token_pair.refresh_token) + new_token_pair = AccountService.refresh_token( + initial_token_pair.refresh_token, session=db_session_with_containers + ) assert isinstance(new_token_pair, TokenPair) assert new_token_pair.access_token == "new_mock_access_token" @@ -782,7 +824,7 @@ class TestAccountService: fake = Faker() invalid_token = fake.uuid4() with pytest.raises(ValueError, match="Invalid refresh token"): - AccountService.refresh_token(invalid_token) + AccountService.refresh_token(invalid_token, session=db_session_with_containers) def test_refresh_token_invalid_account( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -805,10 +847,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Login to get tokens - token_pair = AccountService.login(account) + token_pair = AccountService.login(account, session=db_session_with_containers) # Delete account @@ -817,7 +860,7 @@ class TestAccountService: # Try to refresh token with deleted account with pytest.raises(ValueError, match="Invalid account"): - AccountService.refresh_token(token_pair.refresh_token) + AccountService.refresh_token(token_pair.refresh_token, session=db_session_with_containers) def test_load_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -838,12 +881,15 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create associated Tenant - TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + TenantService.create_owner_tenant_if_not_exist( + account=account, name=tenant_name, is_setup=True, session=db_session_with_containers + ) # Load user - loaded_user = AccountService.load_user(account.id) + loaded_user = AccountService.load_user(account.id, db_session_with_containers) assert loaded_user is not None assert loaded_user.id == account.id @@ -855,7 +901,7 @@ class TestAccountService: """ fake = Faker() non_existent_user_id = fake.uuid4() - loaded_user = AccountService.load_user(non_existent_user_id) + loaded_user = AccountService.load_user(non_existent_user_id, db_session_with_containers) assert loaded_user is None def test_load_user_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): @@ -876,6 +922,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Ban the account @@ -884,7 +931,7 @@ class TestAccountService: db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception - AccountService.load_user(account.id) + AccountService.load_user(account.id, db_session_with_containers) def test_get_account_jwt_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -905,6 +952,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate JWT token @@ -939,12 +987,17 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create associated Tenant - TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True) + TenantService.create_owner_tenant_if_not_exist( + account=account, name=tenant_name, is_setup=True, session=db_session_with_containers + ) # Load logged in account - loaded_account = AccountService.load_logged_in_account(account_id=account.id) + loaded_account = AccountService.load_logged_in_account( + account_id=account.id, session=db_session_with_containers + ) assert loaded_account is not None assert loaded_account.id == account.id @@ -969,10 +1022,11 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Get user through email - found_user = AccountService.get_user_through_email(email) + found_user = AccountService.get_user_through_email(email, session=db_session_with_containers) assert found_user is not None assert found_user.id == account.id @@ -986,7 +1040,7 @@ class TestAccountService: fake = Faker() domain = f"test-{fake.random_letters(10)}.com" non_existent_email = fake.email(domain=domain) - found_user = AccountService.get_user_through_email(non_existent_email) + found_user = AccountService.get_user_through_email(non_existent_email, session=db_session_with_containers) assert found_user is None def test_get_user_through_email_banned_account( @@ -1009,6 +1063,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Ban the account @@ -1017,7 +1072,7 @@ class TestAccountService: db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception - AccountService.get_user_through_email(email) + AccountService.get_user_through_email(email, session=db_session_with_containers) def test_get_user_through_email_in_freeze( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1032,7 +1087,7 @@ class TestAccountService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True with pytest.raises(AccountRegisterError): - AccountService.get_user_through_email(email_in_freeze) + AccountService.get_user_through_email(email_in_freeze, session=db_session_with_containers) # Reset config dify_config.BILLING_ENABLED = False @@ -1055,6 +1110,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) with ( @@ -1092,6 +1148,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate verification code @@ -1122,6 +1179,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate verification code @@ -1152,6 +1210,7 @@ class TestAccountService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate verification code @@ -1206,7 +1265,7 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) assert tenant.name == tenant_name assert tenant.plan == "basic" @@ -1227,7 +1286,7 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = False with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception - TenantService.create_tenant(name=tenant_name) + TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) def test_create_tenant_with_custom_name( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1243,7 +1302,9 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = False # Create tenant with setup flag (should bypass workspace creation restriction) - tenant = TenantService.create_tenant(name=custom_tenant_name, is_setup=True, is_from_dashboard=True) + tenant = TenantService.create_tenant( + name=custom_tenant_name, is_setup=True, is_from_dashboard=True, session=db_session_with_containers + ) assert tenant.name == custom_tenant_name assert tenant.plan == "basic" @@ -1267,12 +1328,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create tenant member @@ -1302,18 +1364,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account1 = AccountService.create_account( email=email1, name=name1, interface_language="en-US", password=password1, + session=db_session_with_containers, ) account2 = AccountService.create_account( email=email2, name=name2, interface_language="en-US", password=password2, + session=db_session_with_containers, ) # Create first owner @@ -1340,12 +1404,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create member with initial role @@ -1379,16 +1444,17 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - tenant1 = TenantService.create_tenant(name=tenant1_name) - tenant2 = TenantService.create_tenant(name=tenant2_name) + tenant1 = TenantService.create_tenant(name=tenant1_name, session=db_session_with_containers) + tenant2 = TenantService.create_tenant(name=tenant2_name, session=db_session_with_containers) # Add account to both tenants TenantService.create_tenant_member(tenant1, account, db_session_with_containers, role="normal") TenantService.create_tenant_member(tenant2, account, db_session_with_containers, role="admin") # Get join tenants - join_tenants = TenantService.get_join_tenants(account) + join_tenants = TenantService.get_join_tenants(account, session=db_session_with_containers) assert len(join_tenants) == 2 tenant_names = [tenant.name for tenant in join_tenants] @@ -1417,8 +1483,9 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Add account to tenant and set as current TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="owner") @@ -1427,7 +1494,7 @@ class TestTenantService: db_session_with_containers.commit() # Get current tenant - current_tenant = TenantService.get_current_tenant_by_account(account) + current_tenant = TenantService.get_current_tenant_by_account(account, session=db_session_with_containers) assert current_tenant.id == tenant.id assert current_tenant.name == tenant.name @@ -1454,11 +1521,12 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Try to get current tenant (should fail) with pytest.raises((AttributeError, TenantNotFoundError)): - TenantService.get_current_tenant_by_account(account) + TenantService.get_current_tenant_by_account(account, session=db_session_with_containers) def test_switch_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1481,9 +1549,10 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - tenant1 = TenantService.create_tenant(name=tenant1_name) - tenant2 = TenantService.create_tenant(name=tenant2_name) + tenant1 = TenantService.create_tenant(name=tenant1_name, session=db_session_with_containers) + tenant2 = TenantService.create_tenant(name=tenant2_name, session=db_session_with_containers) # Add account to both tenants TenantService.create_tenant_member(tenant1, account, db_session_with_containers, role="owner") @@ -1495,7 +1564,7 @@ class TestTenantService: db_session_with_containers.commit() # Switch to second tenant - TenantService.switch_tenant(account, tenant2.id) + TenantService.switch_tenant(account, tenant2.id, session=db_session_with_containers) # Verify tenant was switched db_session_with_containers.refresh(account) @@ -1520,11 +1589,12 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Try to switch tenant without providing tenant ID with pytest.raises(ValueError, match="Tenant ID must be provided"): - TenantService.switch_tenant(account, None) + TenantService.switch_tenant(account, None, session=db_session_with_containers) def test_switch_tenant_account_not_member( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1548,12 +1618,13 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Try to switch to tenant where account is not a member with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): - TenantService.switch_tenant(account, tenant.id) + TenantService.switch_tenant(account, tenant.id, session=db_session_with_containers) def test_has_roles_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1573,18 +1644,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) admin_account = AccountService.create_account( email=admin_email, name=admin_name, interface_language="en-US", password=admin_password, + session=db_session_with_containers, ) # Add members with different roles @@ -1594,15 +1667,15 @@ class TestTenantService: # Check if tenant has owner role from models.account import TenantAccountRole - has_owner = TenantService.has_roles(tenant, [TenantAccountRole.OWNER]) + has_owner = TenantService.has_roles(tenant, [TenantAccountRole.OWNER], session=db_session_with_containers) assert has_owner is True # Check if tenant has admin role - has_admin = TenantService.has_roles(tenant, [TenantAccountRole.ADMIN]) + has_admin = TenantService.has_roles(tenant, [TenantAccountRole.ADMIN], session=db_session_with_containers) assert has_admin is True # Check if tenant has normal role (should be False) - has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) + has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL], session=db_session_with_containers) assert has_normal is False def test_has_roles_invalid_role_type(self, db_session_with_containers: Session, mock_external_service_dependencies): @@ -1618,11 +1691,11 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Try to check roles with invalid role type with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): - TenantService.has_roles(tenant, [invalid_role]) + TenantService.has_roles(tenant, [invalid_role], session=db_session_with_containers) def test_get_user_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1639,19 +1712,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Add account to tenant with specific role TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="editor") # Get user role - user_role = TenantService.get_user_role(account, tenant) + user_role = TenantService.get_user_role(account, tenant, session=db_session_with_containers) assert user_role == "editor" @@ -1675,18 +1749,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -1694,7 +1770,9 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="normal") # Check owner permission to add member (should succeed) - TenantService.check_member_permission(tenant, owner_account, member_account, "add") + TenantService.check_member_permission( + tenant, owner_account, member_account, "add", session=db_session_with_containers + ) def test_check_member_permission_invalid_action( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1714,12 +1792,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Add account to tenant @@ -1727,7 +1806,9 @@ class TestTenantService: # Try to check permission with invalid action with pytest.raises(Exception, match="Invalid action"): - TenantService.check_member_permission(tenant, account, None, invalid_action) + TenantService.check_member_permission( + tenant, account, None, invalid_action, session=db_session_with_containers + ) def test_check_member_permission_operate_self( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1746,12 +1827,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Add account to tenant @@ -1759,7 +1841,9 @@ class TestTenantService: # Try to check permission to operate self with pytest.raises(Exception, match="Cannot operate self"): - TenantService.check_member_permission(tenant, account, account, "remove") + TenantService.check_member_permission( + tenant, account, account, "remove", session=db_session_with_containers + ) def test_remove_member_from_tenant_success( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1781,18 +1865,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -1827,7 +1913,9 @@ class TestTenantService: ): mock_sync.return_value = True - TenantService.remove_member_from_tenant(tenant, member_account, owner_account) + TenantService.remove_member_from_tenant( + tenant, member_account, owner_account, session=db_session_with_containers + ) # Verify sync was called mock_sync.assert_called_once_with( @@ -1867,12 +1955,13 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Add account to tenant @@ -1880,7 +1969,7 @@ class TestTenantService: # Try to remove self with pytest.raises(Exception, match="Cannot operate self"): - TenantService.remove_member_from_tenant(tenant, account, account) + TenantService.remove_member_from_tenant(tenant, account, account, session=db_session_with_containers) def test_remove_member_from_tenant_not_member( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1902,18 +1991,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) non_member_account = AccountService.create_account( email=non_member_email, name=non_member_name, interface_language="en-US", password=non_member_password, + session=db_session_with_containers, ) # Add only owner to tenant @@ -1921,7 +2012,9 @@ class TestTenantService: # Try to remove non-member with pytest.raises(Exception, match="Member not in tenant"): - TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) + TenantService.remove_member_from_tenant( + tenant, non_member_account, owner_account, session=db_session_with_containers + ) def test_update_member_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1941,18 +2034,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -1960,7 +2055,9 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="normal") # Update member role - TenantService.update_member_role(tenant, member_account, "admin", owner_account) + TenantService.update_member_role( + tenant, member_account, "admin", owner_account, session=db_session_with_containers + ) # Verify role was updated from models.account import TenantAccountJoin @@ -1990,18 +2087,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -2009,7 +2108,9 @@ class TestTenantService: TenantService.create_tenant_member(tenant, member_account, db_session_with_containers, role="admin") # Update member role to owner - TenantService.update_member_role(tenant, member_account, "owner", owner_account) + TenantService.update_member_role( + tenant, member_account, "owner", owner_account, session=db_session_with_containers + ) # Verify roles were updated correctly from models.account import TenantAccountJoin @@ -2047,18 +2148,20 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) member_account = AccountService.create_account( email=member_email, name=member_name, interface_language="en-US", password=member_password, + session=db_session_with_containers, ) # Add members with different roles @@ -2067,7 +2170,9 @@ class TestTenantService: # Try to update member role to already assigned role with pytest.raises(Exception, match="The provided role is already assigned to the member"): - TenantService.update_member_role(tenant, member_account, "admin", owner_account) + TenantService.update_member_role( + tenant, member_account, "admin", owner_account, session=db_session_with_containers + ) def test_get_tenant_count_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -2083,12 +2188,12 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create multiple tenants - tenant1 = TenantService.create_tenant(name=tenant1_name) - tenant2 = TenantService.create_tenant(name=tenant2_name) - tenant3 = TenantService.create_tenant(name=tenant3_name) + tenant1 = TenantService.create_tenant(name=tenant1_name, session=db_session_with_containers) + tenant2 = TenantService.create_tenant(name=tenant2_name, session=db_session_with_containers) + tenant3 = TenantService.create_tenant(name=tenant3_name, session=db_session_with_containers) # Get tenant count - tenant_count = TenantService.get_tenant_count() + tenant_count = TenantService.get_tenant_count(session=db_session_with_containers) # Should have at least 3 tenants (may be more from other tests) assert tenant_count >= 3 @@ -2118,10 +2223,11 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create owner tenant - TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name, session=db_session_with_containers) # Verify tenant was created and linked from models.account import TenantAccountJoin @@ -2158,15 +2264,18 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) - existing_tenant = TenantService.create_tenant(name=existing_tenant_name) + existing_tenant = TenantService.create_tenant(name=existing_tenant_name, session=db_session_with_containers) TenantService.create_tenant_member(existing_tenant, account, db_session_with_containers, role="owner") account.current_tenant = existing_tenant db_session_with_containers.commit() # Try to create owner tenant again (should not create new one) - TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) + TenantService.create_owner_tenant_if_not_exist( + account, name=new_workspace_name, session=db_session_with_containers + ) # Verify no new tenant was created tenant_joins = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).all() @@ -2195,11 +2304,14 @@ class TestTenantService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Try to create owner tenant (should fail) with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception - TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) + TenantService.create_owner_tenant_if_not_exist( + account, name=workspace_name, session=db_session_with_containers + ) def test_get_tenant_members_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -2222,24 +2334,27 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) admin_account = AccountService.create_account( email=admin_email, name=admin_name, interface_language="en-US", password=admin_password, + session=db_session_with_containers, ) normal_account = AccountService.create_account( email=normal_email, name=normal_name, interface_language="en-US", password=normal_password, + session=db_session_with_containers, ) # Add members with different roles @@ -2248,7 +2363,7 @@ class TestTenantService: TenantService.create_tenant_member(tenant, normal_account, db_session_with_containers, role="normal") # Get tenant members - members = TenantService.get_tenant_members(tenant) + members = TenantService.get_tenant_members(tenant, session=db_session_with_containers) assert len(members) == 3 member_emails = [member.email for member in members] @@ -2288,24 +2403,27 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant and accounts - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) owner_account = AccountService.create_account( email=owner_email, name=owner_name, interface_language="en-US", password=owner_password, + session=db_session_with_containers, ) dataset_operator_account = AccountService.create_account( email=operator_email, name=operator_name, interface_language="en-US", password=operator_password, + session=db_session_with_containers, ) normal_account = AccountService.create_account( email=normal_email, name=normal_name, interface_language="en-US", password=normal_password, + session=db_session_with_containers, ) # Add members with different roles @@ -2316,7 +2434,7 @@ class TestTenantService: TenantService.create_tenant_member(tenant, normal_account, db_session_with_containers, role="normal") # Get dataset operator members - dataset_operators = TenantService.get_dataset_operator_members(tenant) + dataset_operators = TenantService.get_dataset_operator_members(tenant, session=db_session_with_containers) assert len(dataset_operators) == 1 assert dataset_operators[0].email == operator_email @@ -2336,7 +2454,7 @@ class TestTenantService: ].get_system_features.return_value.is_allow_create_workspace = True # Create tenant with custom config - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Set custom config custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} @@ -2402,6 +2520,7 @@ class TestRegisterService: password=admin_password, ip_address=ip_address, language="en-US", + session=db_session_with_containers, ) # Verify account was created @@ -2450,6 +2569,7 @@ class TestRegisterService: password=admin_password, ip_address=ip_address, language="en-US", + session=db_session_with_containers, ) # Verify no entities were created (rollback worked) @@ -2491,6 +2611,7 @@ class TestRegisterService: name=name, password=password, language=language, + session=db_session_with_containers, ) # Verify account was created @@ -2536,6 +2657,7 @@ class TestRegisterService: open_id=open_id, provider=provider, language=language, + session=db_session_with_containers, ) # Verify account was created @@ -2585,6 +2707,7 @@ class TestRegisterService: password=password, language=language, status=AccountStatus.PENDING, + session=db_session_with_containers, ) # Verify account was created with pending status @@ -2624,6 +2747,7 @@ class TestRegisterService: name=name, password=password, language=language, + session=db_session_with_containers, ) # Verify account was created with no tenant @@ -2665,6 +2789,7 @@ class TestRegisterService: name=name, password=password, language=language, + session=db_session_with_containers, ) # Verify account was created with no tenant @@ -2699,6 +2824,7 @@ class TestRegisterService: password=password, language=language, create_workspace_required=False, + session=db_session_with_containers, ) # Verify account was created @@ -2737,12 +2863,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and inviter account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) inviter = AccountService.create_account( email=inviter_email, name=inviter_name, interface_language="en-US", password=inviter_password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") @@ -2757,6 +2884,7 @@ class TestRegisterService: language=language, role="normal", inviter=inviter, + session=db_session_with_containers, ) # Verify token was generated @@ -2803,12 +2931,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and inviter account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) inviter = AccountService.create_account( email=inviter_email, name=inviter_name, interface_language="en-US", password=inviter_password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") @@ -2818,6 +2947,7 @@ class TestRegisterService: name=existing_member_name, interface_language="en-US", password=existing_member_password, + session=db_session_with_containers, ) # Mock the email task @@ -2830,6 +2960,7 @@ class TestRegisterService: language=language, role="admin", inviter=inviter, + session=db_session_with_containers, ) assert token is not None @@ -2846,7 +2977,9 @@ class TestRegisterService: ) assert tenant_join is None - invitation = RegisterService.get_invitation_if_token_valid(None, None, token) + invitation = RegisterService.get_invitation_if_token_valid( + None, None, token, session=db_session_with_containers + ) assert invitation is not None assert invitation["account"].id == existing_account.id assert invitation["data"]["role"] == "admin" @@ -2872,12 +3005,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and inviter account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) inviter = AccountService.create_account( email=inviter_email, name=inviter_name, interface_language="en-US", password=inviter_password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") @@ -2887,6 +3021,7 @@ class TestRegisterService: name=existing_pending_member_name, interface_language="en-US", password=existing_pending_member_password, + session=db_session_with_containers, ) existing_account.status = AccountStatus.PENDING @@ -2906,6 +3041,7 @@ class TestRegisterService: language=language, role="normal", inviter=inviter, + session=db_session_with_containers, ) # Verify token was generated @@ -2930,7 +3066,7 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) # Execute invitation without inviter (should fail) with pytest.raises(ValueError, match="Inviter is required"): @@ -2940,6 +3076,7 @@ class TestRegisterService: language=language, role="normal", inviter=None, + session=db_session_with_containers, ) def test_invite_new_member_account_already_in_tenant( @@ -2962,12 +3099,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and inviter account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) inviter = AccountService.create_account( email=inviter_email, name=inviter_name, interface_language="en-US", password=inviter_password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, inviter, db_session_with_containers, role="owner") @@ -2977,6 +3115,7 @@ class TestRegisterService: name=already_in_tenant_name, interface_language="en-US", password=already_in_tenant_password, + session=db_session_with_containers, ) existing_account.status = AccountStatus.ACTIVE @@ -2993,6 +3132,7 @@ class TestRegisterService: language=language, role="normal", inviter=inviter, + session=db_session_with_containers, ) def test_generate_invite_token_success( @@ -3011,12 +3151,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Execute token generation @@ -3055,12 +3196,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate a real token @@ -3102,12 +3244,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate a real token @@ -3145,12 +3288,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Generate a real token @@ -3188,12 +3332,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") @@ -3211,6 +3356,7 @@ class TestRegisterService: workspace_id=tenant.id, email=account.email, token=token, + session=db_session_with_containers, ) # Verify result contains expected data @@ -3236,6 +3382,7 @@ class TestRegisterService: workspace_id=workspace_id, email=email, token=invalid_token, + session=db_session_with_containers, ) # Verify result is None @@ -3263,6 +3410,7 @@ class TestRegisterService: name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) # Create a real token but with non-existent tenant ID @@ -3283,6 +3431,7 @@ class TestRegisterService: workspace_id=invalid_tenant_id, email=account.email, token=token, + session=db_session_with_containers, ) # Verify result is None (tenant not found) @@ -3308,12 +3457,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") @@ -3333,6 +3483,7 @@ class TestRegisterService: workspace_id=tenant.id, email=account.email, token=token, + session=db_session_with_containers, ) # Verify result is None (account ID mismatch) @@ -3358,12 +3509,13 @@ class TestRegisterService: mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False # Create tenant and account - tenant = TenantService.create_tenant(name=tenant_name) + tenant = TenantService.create_tenant(name=tenant_name, session=db_session_with_containers) account = AccountService.create_account( email=email, name=name, interface_language="en-US", password=password, + session=db_session_with_containers, ) TenantService.create_tenant_member(tenant, account, db_session_with_containers, role="normal") @@ -3390,6 +3542,7 @@ class TestRegisterService: workspace_id=tenant.id, email=account.email, token=token, + session=db_session_with_containers, ) # Verify result is None (tenant not in normal status) diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 21a768e3446..0ee0cb84e75 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -114,8 +114,9 @@ class TestAgentService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index bc75562d159..6c6b9338d7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -81,8 +81,9 @@ class TestAnnotationService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Setup app creation arguments @@ -280,7 +281,9 @@ class TestAnnotationService: "question": fake.sentence(), "answer": fake.text(max_nb_chars=200), } - updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + updated_annotation = AppAnnotationService.update_app_annotation_directly( + updated_args, app.id, annotation.id, db_session_with_containers + ) # Verify annotation was updated correctly assert updated_annotation.id == annotation.id @@ -567,7 +570,7 @@ class TestAnnotationService: annotation_id = annotation.id # Delete the annotation - AppAnnotationService.delete_app_annotation(app.id, annotation_id) + AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers) # Verify annotation was deleted @@ -595,7 +598,7 @@ class TestAnnotationService: # Try to delete annotation with non-existent app with pytest.raises(NotFound, match="App not found"): - AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) + AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id, db_session_with_containers) def test_delete_app_annotation_annotation_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -609,7 +612,7 @@ class TestAnnotationService: # Try to delete non-existent annotation with pytest.raises(NotFound, match="Annotation not found"): - AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) + AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id, db_session_with_containers) def test_enable_app_annotation_success( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1225,7 +1228,9 @@ class TestAnnotationService: "question": fake.sentence(), "answer": fake.text(max_nb_chars=200), } - updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id) + updated_annotation = AppAnnotationService.update_app_annotation_directly( + updated_args, app.id, annotation.id, db_session_with_containers + ) # Verify annotation was updated correctly assert updated_annotation.id == annotation.id @@ -1295,7 +1300,7 @@ class TestAnnotationService: mock_external_service_dependencies["delete_task"].delay.reset_mock() # Delete the annotation - AppAnnotationService.delete_app_annotation(app.id, annotation_id) + AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers) # Verify annotation was deleted deleted_annotation = ( diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 8bd4069639f..1f88ce90621 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -57,8 +57,9 @@ class TestAPIBasedExtensionService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant return account, tenant diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index ff74ca3039e..cee08c4c33e 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -145,8 +145,11 @@ class TestAppDslService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + account, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant app_args = CreateAppParams( name=fake.company(), diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index f8482f99c00..fce0d26c484 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -166,8 +166,9 @@ class TestAppGenerateService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppService, CreateAppParams diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 0f5cd184430..f9df99c5594 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -62,8 +62,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Setup app creation arguments @@ -119,8 +120,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -162,8 +164,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -210,8 +213,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -261,8 +265,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppListParams, AppService, CreateAppParams @@ -344,8 +349,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from models import AppStar @@ -404,8 +410,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppService, CreateAppParams, StarredAppListParams @@ -500,8 +507,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -566,14 +574,18 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + first_account, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(first_account, name=fake.company()) tenant = first_account.current_tenant second_account = AccountService.create_account( email=fake.email(), name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) from services.app_service import AppListParams, AppService, CreateAppParams @@ -623,8 +635,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -685,8 +698,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -755,8 +769,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppService, CreateAppParams @@ -807,8 +822,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant from services.app_service import AppService, CreateAppParams @@ -857,8 +873,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -910,8 +927,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -971,8 +989,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1030,8 +1049,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1089,8 +1109,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1139,8 +1160,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1190,8 +1212,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1249,8 +1272,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1287,8 +1311,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1326,8 +1351,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app first @@ -1375,8 +1401,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) # Import here to avoid circular dependency from services.app_service import CreateAppParams @@ -1411,8 +1438,9 @@ class TestAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency diff --git a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py index 2593b53fe84..c9cf60bcfb1 100644 --- a/api/tests/test_containers_integration_tests/services/test_audio_service_db.py +++ b/api/tests/test_containers_integration_tests/services/test_audio_service_db.py @@ -158,6 +158,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: with patch("services.audio_service.ModelManager.for_tenant", return_value=mock_model_manager): result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=message.id, voice="en-US-Neural", ) @@ -174,6 +175,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id="invalid-uuid", ) @@ -185,6 +187,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=str(uuid4()), ) @@ -205,6 +208,7 @@ class TestAudioServiceTranscriptTTSMessageLookup: result = AudioService.transcript_tts( app_model=app, + session=db_session_with_containers, message_id=message.id, ) diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index e4fd81b53e7..c2b0385fc74 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -97,8 +97,9 @@ class TestFeedbackService: ) # Test CSV export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") - + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="csv" + ) # Verify response structure assert hasattr(result, "headers") assert "text/csv" in result.headers["Content-Type"] @@ -128,7 +129,9 @@ class TestFeedbackService: ) # Test JSON export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="json" + ) # Verify response structure assert hasattr(result, "headers") @@ -158,8 +161,8 @@ class TestFeedbackService: # Test with filters result = FeedbackService.export_feedbacks( - mock_db_session, app_id=sample_data["app"].id, + session=mock_db_session, from_source=FeedbackFromSource.ADMIN, rating=FeedbackRating.DISLIKE, has_comment=True, @@ -175,7 +178,9 @@ class TestFeedbackService: """Test exporting feedback when no data exists.""" mock_db_session.execute.return_value = _execute_result([]) - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="csv" + ) # Should return an empty CSV with headers only assert hasattr(result, "headers") @@ -194,13 +199,13 @@ class TestFeedbackService: # Test with invalid start_date with pytest.raises(ValueError, match="Invalid start_date format"): FeedbackService.export_feedbacks( - mock_db_session, app_id=sample_data["app"].id, start_date="invalid-date-format" + app_id=sample_data["app"].id, session=mock_db_session, start_date="invalid-date-format" ) # Test with invalid end_date with pytest.raises(ValueError, match="Invalid end_date format"): FeedbackService.export_feedbacks( - mock_db_session, app_id=sample_data["app"].id, end_date="invalid-date-format" + app_id=sample_data["app"].id, session=mock_db_session, end_date="invalid-date-format" ) def test_export_feedbacks_invalid_format(self, mock_db_session, sample_data): @@ -208,8 +213,8 @@ class TestFeedbackService: with pytest.raises(ValueError, match="Unsupported format"): FeedbackService.export_feedbacks( - mock_db_session, app_id=sample_data["app"].id, + session=mock_db_session, format_type="xml", # Unsupported format ) @@ -239,7 +244,9 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="json" + ) # Check JSON content json_content = json.loads(result.get_data(as_text=True)) @@ -290,7 +297,9 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="csv") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="csv" + ) # Check that unicode content is preserved csv_content = result.get_data(as_text=True) @@ -320,7 +329,9 @@ class TestFeedbackService: ) # Test export - result = FeedbackService.export_feedbacks(mock_db_session, app_id=sample_data["app"].id, format_type="json") + result = FeedbackService.export_feedbacks( + app_id=sample_data["app"].id, session=mock_db_session, format_type="json" + ) # Check JSON content for emoji ratings json_content = json.loads(result.get_data(as_text=True)) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 6d0d281c6ba..f2d682be3bf 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -98,8 +98,9 @@ class TestMessageService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Setup app creation arguments @@ -648,8 +649,11 @@ class TestMessageService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + other_account, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company()) # Test getting message with different user with pytest.raises(MessageNotExistsError): diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py index 5844441e6a5..fbdc265265d 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -95,7 +95,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args, current_account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, metadata_args, current_account) db_session_with_containers.expire_all() updated_doc = db_session_with_containers.get(Document, document.id) @@ -126,7 +126,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args, current_account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, metadata_args, current_account) db_session_with_containers.expire_all() updated_doc = db_session_with_containers.get(Document, document.id) @@ -168,7 +168,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args, current_account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, metadata_args, current_account) db_session_with_containers.expire_all() bindings = db_session_with_containers.scalars( @@ -202,6 +202,8 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")): + with patch.object(db_session_with_containers, "commit", side_effect=RuntimeError("database connection lost")): with pytest.raises(RuntimeError, match="database connection lost"): - MetadataService.update_documents_metadata(dataset, metadata_args, current_account) + MetadataService.update_documents_metadata( + db_session_with_containers, dataset, metadata_args, current_account + ) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 0c9e3830430..7cc9fc7e696 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -183,7 +183,9 @@ class TestMetadataService: metadata_args = MetadataArgs(type="string", name="test_metadata") # Act: Execute the method under test - result = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + result = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Assert: Verify the expected outcomes assert result is not None @@ -218,7 +220,7 @@ class TestMetadataService: # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): - MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + MetadataService.create_metadata(db_session_with_containers, dataset.id, metadata_args, account, tenant.id) def test_create_metadata_name_already_exists( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -236,14 +238,16 @@ class TestMetadataService: # Create first metadata first_metadata_args = MetadataArgs(type="string", name="duplicate_name") - MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id) + MetadataService.create_metadata(db_session_with_containers, dataset.id, first_metadata_args, account, tenant.id) # Try to create second metadata with same name second_metadata_args = MetadataArgs(type="number", name="duplicate_name") # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists."): - MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id) + MetadataService.create_metadata( + db_session_with_containers, dataset.id, second_metadata_args, account, tenant.id + ) def test_create_metadata_name_conflicts_with_built_in_field( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -265,7 +269,7 @@ class TestMetadataService: # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): - MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + MetadataService.create_metadata(db_session_with_containers, dataset.id, metadata_args, account, tenant.id) def test_update_metadata_name_success( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -283,11 +287,15 @@ class TestMetadataService: # Create metadata first metadata_args = MetadataArgs(type="string", name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Act: Execute the method under test new_name = "new_name" - result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name, account, tenant.id) + result = MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, metadata.id, new_name, account, tenant.id + ) # Assert: Verify the expected outcomes assert result is not None @@ -316,14 +324,18 @@ class TestMetadataService: # Create metadata first metadata_args = MetadataArgs(type="string", name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Try to update with too long name long_name = "a" * 256 # 256 characters, exceeding 255 limit # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): - MetadataService.update_metadata_name(dataset.id, metadata.id, long_name, account, tenant.id) + MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, metadata.id, long_name, account, tenant.id + ) def test_update_metadata_name_already_exists( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -341,14 +353,20 @@ class TestMetadataService: # Create two metadata entries first_metadata_args = MetadataArgs(type="string", name="first_metadata") - first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id) + first_metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, first_metadata_args, account, tenant.id + ) second_metadata_args = MetadataArgs(type="number", name="second_metadata") - second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id) + second_metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, second_metadata_args, account, tenant.id + ) # Try to update first metadata with second metadata's name with pytest.raises(ValueError, match="Metadata name already exists."): - MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata", account, tenant.id) + MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, first_metadata.id, "second_metadata", account, tenant.id + ) def test_update_metadata_name_conflicts_with_built_in_field( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -366,13 +384,17 @@ class TestMetadataService: # Create metadata first metadata_args = MetadataArgs(type="string", name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Try to update with built-in field name built_in_field_name = BuiltInField.document_name with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): - MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name, account, tenant.id) + MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, metadata.id, built_in_field_name, account, tenant.id + ) def test_update_metadata_name_not_found( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -395,7 +417,9 @@ class TestMetadataService: new_name = "new_name" # Act: Execute the method under test - result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name, account, tenant.id) + result = MetadataService.update_metadata_name( + db_session_with_containers, dataset.id, fake_metadata_id, new_name, account, tenant.id + ) # Assert: Verify the method returns None when metadata is not found assert result is None @@ -416,10 +440,12 @@ class TestMetadataService: # Create metadata first metadata_args = MetadataArgs(type="string", name="to_be_deleted") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Act: Execute the method under test - result = MetadataService.delete_metadata(dataset.id, metadata.id) + result = MetadataService.delete_metadata(db_session_with_containers, dataset.id, metadata.id) # Assert: Verify the expected outcomes assert result is not None @@ -450,7 +476,7 @@ class TestMetadataService: fake_metadata_id = str(uuid.uuid4()) # Use valid UUID format # Act: Execute the method under test - result = MetadataService.delete_metadata(dataset.id, fake_metadata_id) + result = MetadataService.delete_metadata(db_session_with_containers, dataset.id, fake_metadata_id) # Assert: Verify the method returns None when metadata is not found assert result is None @@ -474,7 +500,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Create metadata binding binding = DatasetMetadataBinding( @@ -494,7 +522,7 @@ class TestMetadataService: db_session_with_containers.commit() # Act: Execute the method under test - result = MetadataService.delete_metadata(dataset.id, metadata.id) + result = MetadataService.delete_metadata(db_session_with_containers, dataset.id, metadata.id) # Assert: Verify the expected outcomes assert result is not None @@ -559,7 +587,7 @@ class TestMetadataService: assert dataset.built_in_field_enabled is False # Act: Execute the method under test - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the expected outcomes @@ -595,7 +623,7 @@ class TestMetadataService: ]() # Act: Execute the method under test - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the method returns early without changes db_session_with_containers.refresh(dataset) @@ -621,7 +649,7 @@ class TestMetadataService: ]() # Act: Execute the method under test - MetadataService.enable_built_in_field(dataset) + MetadataService.enable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the expected outcomes @@ -668,7 +696,7 @@ class TestMetadataService: ] # Act: Execute the method under test - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the expected outcomes db_session_with_containers.refresh(dataset) @@ -700,7 +728,7 @@ class TestMetadataService: ]() # Act: Execute the method under test - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the method returns early without changes @@ -733,7 +761,7 @@ class TestMetadataService: ]() # Act: Execute the method under test - MetadataService.disable_built_in_field(dataset) + MetadataService.disable_built_in_field(db_session_with_containers, dataset) # Assert: Verify the expected outcomes db_session_with_containers.refresh(dataset) @@ -758,7 +786,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Mock DocumentService.get_document mock_external_service_dependencies["document_service"].get_document.return_value = document @@ -777,7 +807,7 @@ class TestMetadataService: operation_data = MetadataOperationData(operation_data=[operation]) # Act: Execute the method under test - MetadataService.update_documents_metadata(dataset, operation_data, account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, operation_data, account) # Assert: Verify the expected outcomes @@ -822,7 +852,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Mock DocumentService.get_document mock_external_service_dependencies["document_service"].get_document.return_value = document @@ -841,7 +873,7 @@ class TestMetadataService: operation_data = MetadataOperationData(operation_data=[operation]) # Act: Execute the method under test - MetadataService.update_documents_metadata(dataset, operation_data, account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, operation_data, account) # Assert: Verify the expected outcomes # Verify document metadata was updated with both custom and built-in fields @@ -869,7 +901,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Create metadata operation data from services.entities.knowledge_entities.knowledge_entities import ( @@ -890,7 +924,7 @@ class TestMetadataService: # Act & Assert: The method should raise ValueError("Document not found.") # because the exception is now re-raised after rollback with pytest.raises(ValueError, match="Document not found"): - MetadataService.update_documents_metadata(dataset, operation_data, account) + MetadataService.update_documents_metadata(db_session_with_containers, dataset, operation_data, account) def test_knowledge_base_metadata_lock_check_dataset_id( self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps @@ -986,7 +1020,9 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Create document and metadata binding document = self._create_test_document( @@ -1005,7 +1041,7 @@ class TestMetadataService: db_session_with_containers.commit() # Act: Execute the method under test - result = MetadataService.get_dataset_metadatas(dataset) + result = MetadataService.get_dataset_metadatas(db_session_with_containers, dataset) # Assert: Verify the expected outcomes assert result is not None @@ -1045,10 +1081,12 @@ class TestMetadataService: # Create metadata metadata_args = MetadataArgs(type="string", name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) + metadata = MetadataService.create_metadata( + db_session_with_containers, dataset.id, metadata_args, account, tenant.id + ) # Act: Execute the method under test - result = MetadataService.get_dataset_metadatas(dataset) + result = MetadataService.get_dataset_metadatas(db_session_with_containers, dataset) # Assert: Verify the expected outcomes assert result is not None @@ -1077,7 +1115,7 @@ class TestMetadataService: ) # Act: Execute the method under test - result = MetadataService.get_dataset_metadatas(dataset) + result = MetadataService.get_dataset_metadatas(db_session_with_containers, dataset) # Assert: Verify the expected outcomes assert result is not None diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py index 5fa5de6d80d..0969198ecf3 100644 --- a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -4,7 +4,7 @@ from __future__ import annotations import uuid from typing import cast -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch from uuid import uuid4 import pytest @@ -172,4 +172,4 @@ class TestOAuthServerServiceTokenOperations: result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") assert result is expected_user - mock_load.assert_called_once_with("user-88") + mock_load.assert_called_once_with("user-88", ANY) diff --git a/api/tests/test_containers_integration_tests/services/test_ops_service.py b/api/tests/test_containers_integration_tests/services/test_ops_service.py index ff76bce416a..9643fb61d44 100644 --- a/api/tests/test_containers_integration_tests/services/test_ops_service.py +++ b/api/tests/test_containers_integration_tests/services/test_ops_service.py @@ -51,8 +51,9 @@ class TestOpsService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant app_service = AppService() app = app_service.create_app( diff --git a/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py index 750e35843be..9b8eec08ef4 100644 --- a/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py @@ -267,36 +267,45 @@ class TestRecommendedAppServiceGetDetail: class TestRecommendedAppServiceGetLearnDifyApps: - def test_returns_database_learn_dify_apps_without_remote_factory(self, monkeypatch: pytest.MonkeyPatch) -> None: + @patch("services.recommended_app_service.FeatureService", autospec=True) + @patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True) + @patch("services.recommended_app_service.dify_config") + def test_uses_configured_retrieval_source( + self, mock_config: MagicMock, mock_factory_class: MagicMock, mock_feature_service: MagicMock + ) -> None: + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote" + mock_feature_service.get_system_features.return_value = SimpleNamespace(enable_trial_app=False) expected_app = RecommendedAppPayload(app_id="app-1", category="Workflow") - mock_database_retrieval = MagicMock() - mock_database_retrieval.fetch_learn_dify_apps_from_db.return_value = { + mock_instance = MagicMock() + mock_instance.get_learn_dify_apps.return_value = { "recommended_apps": [expected_app], "categories": ["Workflow"], } - monkeypatch.setattr(service_module, "DatabaseRecommendAppRetrieval", mock_database_retrieval) - monkeypatch.setattr( - service_module.FeatureService, - "get_system_features", - MagicMock(return_value=SimpleNamespace(enable_trial_app=False)), - ) - factory_mock = MagicMock() - monkeypatch.setattr(service_module.RecommendAppRetrievalFactory, "get_recommend_app_factory", factory_mock) + mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) result = RecommendedAppService.get_learn_dify_apps(db.session, "en-US") assert result == {"recommended_apps": [expected_app]} - mock_database_retrieval.fetch_learn_dify_apps_from_db.assert_called_once_with("en-US") - factory_mock.assert_not_called() + mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote") + mock_instance.get_learn_dify_apps.assert_called_once_with("en-US") - def test_sets_can_trial_when_trial_feature_enabled(self, monkeypatch: pytest.MonkeyPatch) -> None: + @patch("services.recommended_app_service.dify_config") + def test_sets_can_trial_when_trial_feature_enabled( + self, mock_config: MagicMock, monkeypatch: pytest.MonkeyPatch + ) -> None: + mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db" app = RecommendedAppPayload(app_id="app-1", category="Workflow") - mock_database_retrieval = MagicMock() - mock_database_retrieval.fetch_learn_dify_apps_from_db.return_value = { + mock_retrieval_instance = MagicMock() + mock_retrieval_instance.get_learn_dify_apps.return_value = { "recommended_apps": [app], "categories": ["Workflow"], } - monkeypatch.setattr(service_module, "DatabaseRecommendAppRetrieval", mock_database_retrieval) + mock_retrieval_factory = MagicMock(return_value=mock_retrieval_instance) + monkeypatch.setattr( + service_module.RecommendAppRetrievalFactory, + "get_recommend_app_factory", + MagicMock(return_value=mock_retrieval_factory), + ) monkeypatch.setattr( service_module.FeatureService, "get_system_features", diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index ad85ac67bc5..cfd1d4e86b4 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -68,8 +68,9 @@ class TestSavedMessageService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 3a6d635e63b..c464505ef9e 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -55,6 +55,7 @@ class TestTriggerProviderService: def _create_test_account_and_tenant( self, mock_external_service_dependencies: MockExternalServiceDependencies, + db_session_with_containers: Session, ) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. @@ -83,8 +84,9 @@ class TestTriggerProviderService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant assert tenant is not None @@ -164,7 +166,9 @@ class TestTriggerProviderService: - Database state is correctly updated """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -262,7 +266,9 @@ class TestTriggerProviderService: - Merged credentials contain only new values """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -320,7 +326,9 @@ class TestTriggerProviderService: - Original credentials are preserved """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -376,7 +384,9 @@ class TestTriggerProviderService: - UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -434,7 +444,9 @@ class TestTriggerProviderService: - Original subscription state is preserved """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY @@ -474,9 +486,8 @@ class TestTriggerProviderService: assert subscription.name == original_name assert subscription.parameters == original_parameters - @pytest.mark.usefixtures("db_session_with_containers") def test_rebuild_trigger_subscription_subscription_not_found( - self, mock_external_service_dependencies: MockExternalServiceDependencies + self, mock_external_service_dependencies: MockExternalServiceDependencies, db_session_with_containers: Session ) -> None: """ Test error when subscription is not found. @@ -485,7 +496,9 @@ class TestTriggerProviderService: - Proper error is raised when subscription doesn't exist """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") fake_subscription_id = fake.uuid4() @@ -509,7 +522,9 @@ class TestTriggerProviderService: - Error is raised when new name conflicts with existing subscription """ fake = Faker() - account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies) + account, tenant = self._create_test_account_and_tenant( + mock_external_service_dependencies, db_session_with_containers + ) provider_id = TriggerProviderID("test_org/test_plugin/test_provider") credential_type = CredentialType.API_KEY diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 8651636616c..664c1167994 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -72,8 +72,9 @@ class TestWebConversationService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 52b12293027..a7e29045e0d 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -63,8 +63,9 @@ class TestWebhookService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant assert tenant is not None diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py index 69cde847f8e..f37dc328b02 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import logging from types import SimpleNamespace from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -355,7 +356,10 @@ class TestWebhookServiceTriggerExecutionWithContainers: mock_mark_rate_limited.assert_called_once_with(tenant.id) def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( - self, db_session_with_containers: Session, flask_app_with_containers: Flask + self, + db_session_with_containers: Session, + flask_app_with_containers: Flask, + caplog: pytest.LogCaptureFixture, ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -367,13 +371,11 @@ class TestWebhookServiceTriggerExecutionWithContainers: webhook_trigger = factory.create_webhook_trigger( db_session_with_containers, app=app, account=account, node_id="node-1" ) + caplog.set_level(logging.ERROR, logger="services.trigger.webhook_service") - with ( - patch( - "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", - side_effect=RuntimeError("boom"), - ), - patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, + with patch( + "services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", + side_effect=RuntimeError("boom"), ): with pytest.raises(RuntimeError, match="boom"): WebhookService.trigger_workflow_execution( @@ -382,7 +384,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: workflow, ) - mock_logger_exception.assert_called_once() + assert caplog.messages.count(f"Failed to trigger workflow for webhook {webhook_trigger.webhook_id}") == 1 class TestWebhookServiceRelationshipSyncWithContainers: @@ -482,7 +484,10 @@ class TestWebhookServiceRelationshipSyncWithContainers: assert cached_payload["webhook_id"] == "cache-webhook-id-00001" def test_sync_webhook_relationships_logs_when_lock_release_fails( - self, db_session_with_containers: Session, flask_app_with_containers: Flask + self, + db_session_with_containers: Session, + flask_app_with_containers: Flask, + caplog: pytest.LogCaptureFixture, ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -494,14 +499,12 @@ class TestWebhookServiceRelationshipSyncWithContainers: lock = MagicMock() lock.acquire.return_value = True lock.release.side_effect = RuntimeError("release failed") + caplog.set_level(logging.ERROR, logger="services.trigger.webhook_service") - with ( - patch("services.trigger.webhook_service.redis_client.lock", return_value=lock), - patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception, - ): + with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock): WebhookService.sync_webhook_relationships(app, workflow) - mock_logger_exception.assert_called_once() + assert caplog.messages.count(f"Failed to release lock for webhook sync, app {app.id}") == 1 def _read_cache(cache_key: str) -> dict[str, str] | None: diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index fbbf255c581..cf76afb303c 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -78,8 +78,9 @@ class TestWorkflowAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Import here to avoid circular dependency @@ -126,8 +127,9 @@ class TestWorkflowAppService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant return tenant, account diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index e065e5df1c3..726c360d77e 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -74,8 +74,9 @@ class TestWorkflowRunService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data @@ -530,8 +531,9 @@ class TestWorkflowRunService: name="Test User", password="password123", interface_language="en-US", + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant") + TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers) tenant = account.current_tenant # Create app @@ -581,8 +583,9 @@ class TestWorkflowRunService: name="Test User", password="password123", interface_language="en-US", + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant") + TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers) tenant = account.current_tenant # Create app @@ -632,8 +635,9 @@ class TestWorkflowRunService: name="Test User", password="password123", interface_language="en-US", + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant") + TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers) tenant = account.current_tenant # Create app diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 9b574fe2dfe..6f342e63dc8 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -89,8 +89,9 @@ class TestWorkflowToolManageService: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create app with realistic data diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index ef65b905086..f6e03b84c90 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -90,8 +90,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -211,8 +212,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -255,8 +257,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Test different index types @@ -342,8 +345,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -424,8 +428,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -525,8 +530,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -625,8 +631,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -717,8 +724,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -820,8 +828,11 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, + ) + TenantService.create_owner_tenant_if_not_exist( + account, name=fake.company(), session=db_session_with_containers ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant accounts.append(account) tenants.append(tenant) @@ -926,8 +937,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset @@ -1031,8 +1043,9 @@ class TestCleanNotionDocumentTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant # Create dataset with built-in fields enabled diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index aba2458d55c..fbdd81aed5d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -67,8 +67,9 @@ class TestDealDatasetVectorIndexTask: name=fake.name(), interface_language="en-US", password=generate_valid_password(fake), + session=db_session_with_containers, ) - TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers) tenant = account.current_tenant assert tenant is not None return account, tenant diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index cb5fb5483c8..19611bd6ff2 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -6,9 +6,11 @@ using TestContainers to ensure realistic database interactions and proper isolat The task is responsible for removing document segments from the search index when they are disabled. """ +import logging from unittest.mock import MagicMock, patch from uuid import uuid4 +import pytest from faker import Faker from sqlalchemy import select from sqlalchemy.orm import Session @@ -533,7 +535,9 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_factory.assert_called_with(doc_form) - def test_disable_segments_performance_timing(self, db_session_with_containers: Session): + def test_disable_segments_performance_timing( + self, db_session_with_containers: Session, caplog: pytest.LogCaptureFixture + ): """ Test that the task properly measures and logs performance timing. @@ -562,21 +566,18 @@ class TestDisableSegmentsFromIndexTask: # Mock time.perf_counter to control timing with patch("tasks.disable_segments_from_index_task.time.perf_counter") as mock_perf_counter: mock_perf_counter.side_effect = [1000.0, 1000.5] # 0.5 seconds execution time + caplog.set_level(logging.INFO, logger="tasks.disable_segments_from_index_task") - # Mock logger to capture log messages - with patch("tasks.disable_segments_from_index_task.logger") as mock_logger: - # Act - result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) + # Act + result = disable_segments_from_index_task(segment_ids, dataset.id, document.id) - # Assert - assert result is None # Task should complete without returning a value + # Assert + assert result is None # Task should complete without returning a value - # Verify performance logging - mock_logger.info.assert_called() - log_calls = [call[0][0] for call in mock_logger.info.call_args_list] - performance_log = next((call for call in log_calls if "latency" in call), None) - assert performance_log is not None - assert "0.5" in performance_log # Should log the execution time + # Verify performance logging + performance_log = next((message for message in caplog.messages if "latency" in message), None) + assert performance_log is not None + assert "0.5" in performance_log # Should log the execution time def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session): """ diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index 0eec166fe2f..567bd44b930 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -10,6 +10,7 @@ All tests use the testcontainers infrastructure to ensure proper database isolat and realistic testing scenarios with actual PostgreSQL and Redis instances. """ +import logging from unittest.mock import MagicMock, patch import pytest @@ -543,7 +544,10 @@ class TestSendEmailCodeLoginMailTask: redis_client.delete(cache_key) def test_send_email_code_login_mail_task_error_handling_comprehensive( - self, db_session_with_containers: Session, mock_external_service_dependencies + self, + db_session_with_containers: Session, + mock_external_service_dependencies, + caplog: pytest.LogCaptureFixture, ): """ Test comprehensive error handling for email code login mail task. @@ -559,6 +563,7 @@ class TestSendEmailCodeLoginMailTask: test_email = fake.email() test_code = "123456" test_language = "en-US" + caplog.set_level(logging.ERROR, logger="tasks.mail_email_code_login") # Test different exception types exception_types = [ @@ -571,31 +576,26 @@ class TestSendEmailCodeLoginMailTask: for error_name, exception in exception_types: # Reset mocks for each test case + caplog.clear() mock_email_service_instance = mock_external_service_dependencies["email_service_instance"] mock_email_service_instance.reset_mock() mock_email_service_instance.send_email.side_effect = exception - # Mock logging to capture error messages - with patch("tasks.mail_email_code_login.logger", autospec=True) as mock_logger: - # Act: Execute the task - it should handle the exception gracefully - send_email_code_login_mail_task( - language=test_language, - to=test_email, - code=test_code, - ) + # Act: Execute the task - it should handle the exception gracefully + send_email_code_login_mail_task( + language=test_language, + to=test_email, + code=test_code, + ) - # Assert: Verify error handling - # Verify email service was called (and failed) - mock_email_service_instance.send_email.assert_called_once() + # Assert: Verify error handling + # Verify email service was called (and failed) + mock_email_service_instance.send_email.assert_called_once() - # Verify error was logged - error_calls = [ - call - for call in mock_logger.exception.call_args_list - if f"Send email code login mail to {test_email} failed" in str(call) - ] - # Check if any exception call was made (the exact message format may vary) - assert mock_logger.exception.call_count >= 1, f"Error should be logged for {error_name}" + # Verify error was logged + assert f"Send email code login mail to {test_email} failed" in caplog.messages, ( + f"Error should be logged for {error_name}" + ) # Reset side effect for next iteration mock_email_service_instance.send_email.side_effect = None diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index de15d4cc772..507e674e5af 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -11,6 +11,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances. """ import json +import logging import uuid from datetime import UTC, datetime from unittest.mock import MagicMock, patch @@ -295,7 +296,10 @@ class TestMailInviteMemberTask: mock_email_service.send_email.assert_not_called() def test_send_invite_member_mail_email_service_exception( - self, db_session_with_containers: Session, mock_external_service_dependencies + self, + db_session_with_containers: Session, + mock_external_service_dependencies, + caplog: pytest.LogCaptureFixture, ): """ Test error handling when email service raises an exception. @@ -308,21 +312,19 @@ class TestMailInviteMemberTask: # Arrange: Setup email service to raise exception mock_email_service = mock_external_service_dependencies["email_service"] mock_email_service.send_email.side_effect = Exception("Email service failed") + caplog.set_level(logging.ERROR, logger="tasks.mail_invite_member_task") # Act & Assert: Execute task and verify exception is handled - with patch("tasks.mail_invite_member_task.logger", autospec=True) as mock_logger: - send_invite_member_mail_task( - language="en-US", - to="test@example.com", - token="test-token", - inviter_name="Test User", - workspace_name="Test Workspace", - ) + send_invite_member_mail_task( + language="en-US", + to="test@example.com", + token="test-token", + inviter_name="Test User", + workspace_name="Test Workspace", + ) - # Verify error was logged - mock_logger.exception.assert_called_once() - error_call = mock_logger.exception.call_args[0][0] - assert "Send invite member mail to %s failed" in error_call + # Verify error was logged + assert caplog.messages.count("Send invite member mail to test@example.com failed") == 1 def test_send_invite_member_mail_template_context_validation( self, db_session_with_containers: Session, mock_external_service_dependencies diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py index c74b451b4b6..685c1617bb0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py @@ -5,6 +5,7 @@ This module provides integration tests for email registration tasks using TestContainers to ensure real database and service interactions. """ +import logging from unittest.mock import patch import pytest @@ -68,7 +69,10 @@ class TestMailRegisterTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_email_register_mail_task_exception_handling( - self, db_session_with_containers: Session, mock_mail_dependencies + self, + db_session_with_containers: Session, + mock_mail_dependencies, + caplog: pytest.LogCaptureFixture, ): """Test email registration task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") @@ -76,10 +80,11 @@ class TestMailRegisterTask: fake = Faker() to_email = fake.email() code = fake.numerify("######") + caplog.set_level(logging.ERROR, logger="tasks.mail_register_task") - with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger: - send_email_register_mail_task(language="en-US", to=to_email, code=code) - mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) + send_email_register_mail_task(language="en-US", to=to_email, code=code) + + assert caplog.messages.count(f"Send email register mail to {to_email} failed") == 1 def test_send_email_register_mail_task_when_account_exist_success( self, db_session_with_containers: Session, mock_mail_dependencies @@ -121,7 +126,10 @@ class TestMailRegisterTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_email_register_mail_task_when_account_exist_exception_handling( - self, db_session_with_containers: Session, mock_mail_dependencies + self, + db_session_with_containers: Session, + mock_mail_dependencies, + caplog: pytest.LogCaptureFixture, ): """Test account exist email task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") @@ -129,7 +137,8 @@ class TestMailRegisterTask: fake = Faker() to_email = fake.email() account_name = fake.name() + caplog.set_level(logging.ERROR, logger="tasks.mail_register_task") - with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger: - send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name) - mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) + send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name) + + assert caplog.messages.count(f"Send email register mail to {to_email} failed") == 1 diff --git a/api/tests/unit_tests/controllers/common/test_fields.py b/api/tests/unit_tests/controllers/common/test_fields.py index d4dc13127d6..595b3d93356 100644 --- a/api/tests/unit_tests/controllers/common/test_fields.py +++ b/api/tests/unit_tests/controllers/common/test_fields.py @@ -1,5 +1,4 @@ import builtins -from types import SimpleNamespace from unittest.mock import patch from flask.views import MethodView as FlaskMethodView @@ -22,7 +21,7 @@ def test_parameters_model_round_trip(): def test_site_icon_url_uses_signed_url_for_image_icon(): - site = SimpleNamespace( + site = Site( title="Example", chat_color_theme=None, chat_color_theme_inverted=False, @@ -46,7 +45,7 @@ def test_site_icon_url_uses_signed_url_for_image_icon(): def test_site_icon_url_is_none_for_non_image_icon(): - site = SimpleNamespace( + site = Site( title="Example", chat_color_theme=None, chat_color_theme_inverted=False, diff --git a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py index 27bb75e21f8..3d84f899379 100644 --- a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py +++ b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py @@ -15,6 +15,7 @@ from controllers.console.agent.composer import ( AgentComposerValidateApi, WorkflowAgentComposerApi, WorkflowAgentComposerCandidatesApi, + WorkflowAgentComposerCopyFromRosterApi, WorkflowAgentComposerImpactApi, WorkflowAgentComposerSaveToRosterApi, WorkflowAgentComposerValidateApi, @@ -1017,6 +1018,58 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save( )["save_options"] == ["node_job_only"] +def test_workflow_composer_copy_from_roster(app: Flask, monkeypatch: pytest.MonkeyPatch, account_id: str) -> None: + app_model = SimpleNamespace(id="app-1") + captured: dict[str, object] = {} + + def fake_copy_from_roster(**kwargs): + captured.update(kwargs) + return _workflow_composer_response( + binding={ + "id": "binding-1", + "binding_type": "inline_agent", + "agent_id": "inline-agent-1", + "current_snapshot_id": "inline-version-1", + "workflow_id": "workflow-1", + "node_id": kwargs["node_id"], + }, + agent={ + "id": "inline-agent-1", + "name": "Nadia", + "description": "", + "scope": "workflow_only", + "status": "active", + }, + active_config_snapshot={"id": "inline-version-1", "version": 1}, + ) + + monkeypatch.setattr( + composer_controller.AgentComposerService, "copy_workflow_composer_from_roster", fake_copy_from_roster + ) + + with app.test_request_context( + json={ + "source_agent_id": "roster-agent-1", + "source_snapshot_id": "roster-version-1", + "idempotency_key": "copy-1", + } + ): + result = unwrap(WorkflowAgentComposerCopyFromRosterApi.post)( + WorkflowAgentComposerCopyFromRosterApi(), "tenant-1", account_id, app_model, "node-1" + ) + + assert result["binding"]["binding_type"] == "inline_agent" + assert captured == { + "tenant_id": "tenant-1", + "app_id": "app-1", + "node_id": "node-1", + "account_id": account_id, + "source_agent_id": "roster-agent-1", + "source_snapshot_id": "roster-version-1", + "idempotency_key": "copy-1", + } + + def test_workflow_impact_returns_empty_without_version(app: Flask) -> None: payload = {"variant": ComposerVariant.WORKFLOW.value, "save_strategy": ComposerSaveStrategy.NODE_JOB_ONLY.value} diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_convert_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_convert_api.py index dd254a31f63..942698db3bb 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_convert_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_convert_api.py @@ -2,21 +2,14 @@ from __future__ import annotations +from inspect import unwrap from types import SimpleNamespace import pytest from flask import Flask from controllers.console.app import workflow as workflow_module - - -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func +from controllers.console.app.workflow import ConvertToWorkflowApi class TestConvertToWorkflowApi: @@ -25,9 +18,9 @@ class TestConvertToWorkflowApi: return workflow_module.ConvertToWorkflowApi() def test_convert_to_workflow_attaches_permission_keys_when_rbac_enabled( - self, api, app: Flask, monkeypatch: pytest.MonkeyPatch + self, api: ConvertToWorkflowApi, app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( workflow_module, @@ -46,6 +39,7 @@ class TestConvertToWorkflowApi: json={}, ): response = method( + api, current_tenant_id="tenant-1", current_user=SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"), diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index 6ef7f442591..ebae7de6c15 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -8,7 +8,7 @@ This module tests the account activation mechanism including: - Initial login after activation """ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -41,7 +41,7 @@ class TestActivateCheckApi: } @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation): + def test_check_valid_invitation_token(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock): """ Test checking valid invitation token. @@ -67,7 +67,9 @@ class TestActivateCheckApi: assert response["data"]["email"] == "invitee@example.com" @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_valid_invitation_token_includes_account_status(self, mock_get_invitation, app, mock_invitation): + def test_check_valid_invitation_token_includes_account_status( + self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock + ): mock_account = MagicMock() mock_account.status = AccountStatus.ACTIVE mock_invitation["account"] = mock_account @@ -103,7 +105,9 @@ class TestActivateCheckApi: assert response["is_valid"] is False @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation): + def test_check_token_without_workspace_id( + self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock + ): """ Test checking token without workspace ID. @@ -121,10 +125,10 @@ class TestActivateCheckApi: # Assert assert response["is_valid"] is True - mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token") + mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token", session=ANY) @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation): + def test_check_token_without_email(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation): """ Test checking token without email parameter. @@ -142,10 +146,12 @@ class TestActivateCheckApi: # Assert assert response["is_valid"] is True - mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token") + mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token", session=ANY) @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation): + def test_check_token_normalizes_email_to_lowercase( + self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock + ): """Ensure token validation uses lowercase emails.""" mock_get_invitation.return_value = mock_invitation @@ -156,7 +162,7 @@ class TestActivateCheckApi: response = api.get() assert response["is_valid"] is True - mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY) class TestActivateApi: @@ -554,7 +560,7 @@ class TestActivateApi: response = api.post() assert response["result"] == "success" - mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token") + mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY) mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") @patch("controllers.console.auth.activate.TenantService.create_tenant_member") @@ -593,7 +599,7 @@ class TestActivateApi: mock_create_tenant_member.assert_called_once_with( mock_invitation["tenant"], mock_account, mock_db.session, role=TenantAccountRole.ADMIN ) - mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id) + mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY) mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") @patch("controllers.console.auth.activate.TenantService.create_tenant_member") @@ -628,5 +634,5 @@ class TestActivateApi: assert response["result"] == "success" mock_create_tenant_member.assert_not_called() - mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id) + mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY) mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token") diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py index 7f449bb376e..21d1932f820 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime from inspect import unwrap from types import SimpleNamespace -from unittest.mock import PropertyMock, patch +from unittest.mock import ANY, PropertyMock, patch from controllers.console import console_ns from controllers.console.auth.data_source_bearer_auth import ( @@ -34,13 +34,16 @@ def test_list_data_source_auth_uses_injected_tenant_id() -> None: updated_at=datetime(2026, 1, 2, tzinfo=UTC), ) - with patch( - "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list", - return_value=[binding], - ) as get_provider_auth_list: + with ( + patch("controllers.console.auth.data_source_bearer_auth.db"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list", + return_value=[binding], + ) as get_provider_auth_list, + ): result = method(api, "tenant-1") - get_provider_auth_list.assert_called_once_with("tenant-1") + get_provider_auth_list.assert_called_once_with(ANY, "tenant-1") assert result["sources"][0]["id"] == "binding-1" assert result["sources"][0]["provider"] == "custom" @@ -56,12 +59,13 @@ def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None: with ( _payload_patch(payload), + patch("controllers.console.auth.data_source_bearer_auth.db"), patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth, ): result, status = method(api, "tenant-1") - create_auth.assert_called_once_with("tenant-1", payload) + create_auth.assert_called_once_with(ANY, "tenant-1", payload) assert result == {"result": "success"} assert status == 200 @@ -70,11 +74,14 @@ def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None: api = ApiKeyAuthDataSourceBindingDelete() method = unwrap(api.delete) - with patch( - "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" - ) as delete_provider_auth: + with ( + patch("controllers.console.auth.data_source_bearer_auth.db"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" + ) as delete_provider_auth, + ): result, status = method(api, "tenant-1", "binding-1") - delete_provider_auth.assert_called_once_with("tenant-1", "binding-1") + delete_provider_auth.assert_called_once_with(ANY, "tenant-1", "binding-1") assert result == "" assert status == 204 diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py b/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py index df282880af0..14af718590a 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_register_language.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from pydantic import ValidationError @@ -25,6 +25,7 @@ def test_create_new_account_uses_requested_language(mock_create_account): password="ValidPass123!", interface_language="zh-Hans", timezone="Asia/Shanghai", + session=ANY, ) diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index fa23942c653..e56ed48bcfe 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -9,7 +9,7 @@ This module tests the email code login mechanism including: """ import base64 -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -368,6 +368,7 @@ class TestEmailCodeLoginApi: name="newuser@example.com", interface_language="en-US", timezone="Asia/Shanghai", + session=ANY, ) @patch("controllers.console.wraps.db") diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 92656357d4d..42495d7f63c 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -9,7 +9,8 @@ This module tests the core authentication endpoints including: """ import base64 -from unittest.mock import MagicMock, Mock, patch +import logging +from unittest.mock import ANY, MagicMock, Mock, patch import pytest from flask import Flask @@ -129,7 +130,7 @@ class TestLoginApi: response = login_api.post() # Assert - mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None) + mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None, session=ANY) mock_login.assert_called_once() mock_reset_rate_limit.assert_called_once_with("test@example.com") assert response.json["result"] == "success" @@ -184,14 +185,16 @@ class TestLoginApi: response = login_api.post() # Assert - mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token") + mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token", session=ANY) assert response.json["result"] == "success" @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): + def test_login_fails_when_rate_limited( + self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog: pytest.LogCaptureFixture + ): """ Test login rejection when rate limit is exceeded. @@ -204,22 +207,26 @@ class TestLoginApi: mock_get_invitation.return_value = None # Act & Assert - with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: - with app.test_request_context( - "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} - ): - login_api = LoginApi() - with pytest.raises(EmailPasswordLoginLimitError): - login_api.post() + with app.test_request_context( + "/login", method="POST", json={"email": "test@example.com", "password": encode_password("password")} + ): + login_api = LoginApi() + with pytest.raises(EmailPasswordLoginLimitError): + login_api.post() - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "test@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.LOGIN_RATE_LIMITED + warn_records = [ + r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING + ] + assert len(warn_records) == 1 + assert warn_records[0].args[0] == "test@example.com" + assert warn_records[0].args[1] == LoginFailureReason.LOGIN_RATE_LIMITED @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True) @patch("controllers.console.auth.login.BillingService.is_email_in_freeze") - def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask): + def test_login_fails_when_account_frozen( + self, mock_is_frozen, mock_db, app: Flask, caplog: pytest.LogCaptureFixture + ): """ Test login rejection for frozen accounts. @@ -231,17 +238,19 @@ class TestLoginApi: mock_is_frozen.return_value = True # Act & Assert - with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: - with app.test_request_context( - "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} - ): - login_api = LoginApi() - with pytest.raises(AccountInFreezeError): - login_api.post() + with app.test_request_context( + "/login", method="POST", json={"email": "frozen@example.com", "password": encode_password("password")} + ): + login_api = LoginApi() + with pytest.raises(AccountInFreezeError): + login_api.post() - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "frozen@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_IN_FREEZE + warn_records = [ + r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING + ] + assert len(warn_records) == 1 + assert warn_records[0].args[0] == "frozen@example.com" + assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_IN_FREEZE @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -257,6 +266,7 @@ class TestLoginApi: mock_is_rate_limit, mock_db, app: Flask, + caplog: pytest.LogCaptureFixture, ): """ Test login failure with invalid credentials. @@ -272,20 +282,22 @@ class TestLoginApi: mock_authenticate.side_effect = AccountPasswordError("Invalid password") # Act & Assert - with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: - with app.test_request_context( - "/login", - method="POST", - json={"email": "test@example.com", "password": encode_password("WrongPass123!")}, - ): - login_api = LoginApi() - with pytest.raises(AuthenticationFailedError): - login_api.post() + with app.test_request_context( + "/login", + method="POST", + json={"email": "test@example.com", "password": encode_password("WrongPass123!")}, + ): + login_api = LoginApi() + with pytest.raises(AuthenticationFailedError): + login_api.post() mock_add_rate_limit.assert_called_once_with("test@example.com") - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "test@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS + warn_records = [ + r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING + ] + assert len(warn_records) == 1 + assert warn_records[0].args[0] == "test@example.com" + assert warn_records[0].args[1] == LoginFailureReason.INVALID_CREDENTIALS @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -293,7 +305,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") def test_login_fails_for_banned_account( - self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask + self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask, caplog ): """ Test login rejection for banned accounts. @@ -308,19 +320,21 @@ class TestLoginApi: mock_authenticate.side_effect = AccountLoginError("Account is banned") # Act & Assert - with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: - with app.test_request_context( - "/login", - method="POST", - json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}, - ): - login_api = LoginApi() - with pytest.raises(AccountBannedError): - login_api.post() + with app.test_request_context( + "/login", + method="POST", + json={"email": "banned@example.com", "password": encode_password("ValidPass123!")}, + ): + login_api = LoginApi() + with pytest.raises(AccountBannedError): + login_api.post() - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "banned@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED + warn_records = [ + r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING + ] + assert len(warn_records) == 1 + assert warn_records[0].args[0] == "banned@example.com" + assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_BANNED @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @@ -407,13 +421,13 @@ class TestLoginApi: @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit") def test_login_retries_with_lowercase_email( self, - mock_reset_rate_limit, - mock_login_service, - mock_get_tenants, - mock_add_rate_limit, - mock_authenticate, - mock_get_invitation, - mock_is_rate_limit, + mock_reset_rate_limit: MagicMock, + mock_login_service: MagicMock, + mock_get_tenants: MagicMock, + mock_add_rate_limit: MagicMock, + mock_authenticate: MagicMock, + mock_get_invitation: MagicMock, + mock_is_rate_limit: MagicMock, mock_db, app: Flask, mock_account, @@ -435,8 +449,8 @@ class TestLoginApi: assert response.json["result"] == "success" assert mock_authenticate.call_args_list == [ - (("Upper@Example.com", "ValidPass123!", None), {}), - (("upper@example.com", "ValidPass123!", None), {}), + (("Upper@Example.com", "ValidPass123!", None), {"session": ANY}), + (("upper@example.com", "ValidPass123!", None), {"session": ANY}), ] mock_add_rate_limit.assert_not_called() mock_reset_rate_limit.assert_called_once_with("upper@example.com") @@ -447,28 +461,31 @@ class TestLoginApi: @patch("controllers.console.auth.login._get_account_with_case_fallback") def test_email_code_login_logs_banned_account( self, - mock_get_account, - mock_revoke_token, - mock_get_token_data, - mock_db, + mock_get_account: MagicMock, + mock_revoke_token: MagicMock, + mock_get_token_data: MagicMock, + mock_db: MagicMock, app: Flask, + caplog: pytest.LogCaptureFixture, ): mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} mock_get_account.side_effect = Unauthorized("Account is banned.") - with patch("controllers.console.auth.login.logger.warning") as mock_log_warning: - with app.test_request_context( - "/email-code-login/validity", - method="POST", - json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, - ): - with pytest.raises(AccountBannedError): - EmailCodeLoginApi().post() + with app.test_request_context( + "/email-code-login/validity", + method="POST", + json={"email": "User@Example.com", "code": encode_code("123456"), "token": "token-123"}, + ): + with pytest.raises(AccountBannedError): + EmailCodeLoginApi().post() mock_revoke_token.assert_called_once_with("token-123") - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "user@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED + warn_records = [ + r for r in caplog.records if r.name == "controllers.console.auth.login" and r.levelno == logging.WARNING + ] + assert len(warn_records) == 1 + assert warn_records[0].args[0] == "user@example.com" + assert warn_records[0].args[1] == LoginFailureReason.ACCOUNT_BANNED class TestLogoutApi: @@ -491,7 +508,9 @@ class TestLogoutApi: @patch("controllers.console.auth.login.AccountService.logout") @patch("controllers.console.auth.login.flask_login.logout_user") - def test_successful_logout(self, mock_logout_user, mock_service_logout, app: Flask, mock_account): + def test_successful_logout( + self, mock_logout_user: MagicMock, mock_service_logout: MagicMock, app: Flask, mock_account + ): """ Test successful logout flow. diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py index 36c707dbf9b..8b3de6a39e5 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_timezone.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -66,8 +66,9 @@ def test_generate_account_registers_with_browser_timezone( provider="github", language="zh-Hans", timezone="Asia/Shanghai", + session=ANY, ) - mock_link_account.assert_called_once_with("github", "github-123", account) + mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY) @patch("controllers.console.auth.oauth.AccountService.link_account_integrate") @@ -97,8 +98,9 @@ def test_generate_account_prefers_state_language_over_accept_language( provider="github", language="zh-Hans", timezone=None, + session=ANY, ) - mock_link_account.assert_called_once_with("github", "github-123", account) + mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY) @patch("controllers.console.auth.oauth.dify_config") diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index ba69f4d6a78..34fff57b0ad 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -8,7 +8,7 @@ This module tests the token refresh mechanism including: - Error handling for invalid tokens """ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -70,7 +70,7 @@ class TestRefreshTokenApi: # Assert mock_extract_token.assert_called_once() - mock_refresh_token.assert_called_once_with("valid_refresh_token") + mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY) assert response.json["result"] == "success" @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @@ -191,7 +191,7 @@ class TestRefreshTokenApi: # Assert assert response.json["result"] == "success" # Verify new token pair was generated - mock_refresh_token.assert_called_once_with("valid_refresh_token") + mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY) # In real implementation, cookies would be set with new values assert mock_token_pair.access_token == "new_access_token" assert mock_token_pair.refresh_token == "new_refresh_token" diff --git a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py index e6bee6fe1d3..03a6fdb0d60 100644 --- a/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/snippets/test_snippet_workflow_draft_variable.py @@ -1,3 +1,4 @@ +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock @@ -10,12 +11,6 @@ from models.account import Account, AccountStatus from services.workflow_draft_variable_service import WorkflowDraftVariableList -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _make_account() -> Account: account = Account( name="tester", @@ -66,7 +61,7 @@ def test_ensure_snippet_draft_variable_row_allowed_accepts_canvas_node_variable( def test_conversation_variables_returns_empty_list(app: Flask): api = module.SnippetConversationVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/"): result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) @@ -76,7 +71,7 @@ def test_conversation_variables_returns_empty_list(app: Flask): def test_system_variables_returns_empty_list(app: Flask): api = module.SnippetSystemVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/"): result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) @@ -91,7 +86,7 @@ def test_delete_variable_collection_deletes_current_user_variables(app: Flask, m db_session.return_value = SimpleNamespace() monkeypatch.setattr(module.db, "session", db_session) api = module.SnippetWorkflowVariableCollectionApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) with app.test_request_context("/", method="DELETE"): response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) @@ -109,7 +104,7 @@ def test_variable_collection_get_raises_when_draft_workflow_missing(app: Flask, ) api = module.SnippetWorkflowVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/?page=1&limit=20"): with pytest.raises(module.DraftWorkflowNotExist): @@ -140,7 +135,7 @@ def test_node_variable_collection_get_lists_node_variables(app: Flask, monkeypat ) api = module.SnippetNodeVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/"): result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1") @@ -158,7 +153,7 @@ def test_node_variable_collection_delete_deletes_node_variables(app: Flask, monk monkeypatch.setattr(module.db, "session", db_session) api = module.SnippetNodeVariableCollectionApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) with app.test_request_context("/", method="DELETE"): response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), node_id="llm-1") @@ -177,7 +172,7 @@ def test_variable_patch_returns_variable_when_no_changes(app: Flask, monkeypatch monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) api = module.SnippetVariableApi() - handler = _unwrap(api.patch) + handler = unwrap(api.patch) with app.test_request_context("/", method="PATCH", json={}): result = handler( @@ -202,7 +197,7 @@ def test_variable_delete_deletes_variable(app: Flask, monkeypatch: pytest.Monkey monkeypatch.setattr(module, "WorkflowDraftVariableService", Mock(return_value=draft_var_service)) api = module.SnippetVariableApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) with app.test_request_context("/", method="DELETE"): response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1") @@ -230,7 +225,7 @@ def test_variable_reset_returns_no_content_when_reset_result_is_none(app: Flask, ) api = module.SnippetVariableResetApi() - handler = _unwrap(api.put) + handler = unwrap(api.put) with app.test_request_context("/", method="PUT"): response = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1"), variable_id="var-1") @@ -260,7 +255,7 @@ def test_environment_variables_returns_workflow_environment_variables(app: Flask ) api = module.SnippetEnvironmentVariableCollectionApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/"): result = handler(api, _make_account(), snippet=SimpleNamespace(id="snippet-1")) diff --git a/api/tests/unit_tests/controllers/console/test_feature.py b/api/tests/unit_tests/controllers/console/test_feature.py index 3e804583a6e..5b524cf9c64 100644 --- a/api/tests/unit_tests/controllers/console/test_feature.py +++ b/api/tests/unit_tests/controllers/console/test_feature.py @@ -94,7 +94,7 @@ class TestSystemFeatureApi: "controllers.console.feature.current_account_with_tenant_optional", return_value=(account, "tenant-123"), ) - system_features = SystemFeatureModel(is_allow_register=True) + system_features = SystemFeatureModel(is_allow_register=True, enable_learn_app=True) get_system_features = mocker.patch( "controllers.console.feature.FeatureService.get_system_features", return_value=system_features, @@ -104,6 +104,7 @@ class TestSystemFeatureApi: result = api.get() assert result == system_features.model_dump() + assert result["enable_learn_app"] is True current_account.assert_called_once_with() get_system_features.assert_called_once_with(is_authenticated=True) diff --git a/api/tests/unit_tests/controllers/console/test_init_validate.py b/api/tests/unit_tests/controllers/console/test_init_validate.py index 4954e0dc96a..88d41fa2bd0 100644 --- a/api/tests/unit_tests/controllers/console/test_init_validate.py +++ b/api/tests/unit_tests/controllers/console/test_init_validate.py @@ -38,7 +38,7 @@ def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None: def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") - monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1) + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 1) app.secret_key = "test-secret" with app.test_request_context("/console/api/init", method="POST"): @@ -48,7 +48,7 @@ def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.Mo def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") - monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0) monkeypatch.setenv("INIT_PASSWORD", "expected") app.secret_key = "test-secret" @@ -60,7 +60,7 @@ def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.M def test_validate_init_password_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") - monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0) monkeypatch.setenv("INIT_PASSWORD", "expected") app.secret_key = "test-secret" diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index e419428ca66..5f36e805baa 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -1,6 +1,6 @@ import inspect from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -394,12 +394,12 @@ class TestChangeEmailReset: @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_normalize_new_email_before_update( self, - mock_is_freeze, - mock_check_unique, - mock_get_data, - mock_revoke_token, - mock_update_account, - mock_send_notify, + mock_is_freeze: MagicMock, + mock_check_unique: MagicMock, + mock_get_data: MagicMock, + mock_revoke_token: MagicMock, + mock_update_account: MagicMock, + mock_send_notify: MagicMock, app: Flask, ): current_user = _build_account("old@example.com", "acc3") @@ -424,9 +424,9 @@ class TestChangeEmailReset: method(api, current_user) mock_is_freeze.assert_called_once_with("new@example.com") - mock_check_unique.assert_called_once_with("new@example.com") + mock_check_unique.assert_called_once_with("new@example.com", session=ANY) mock_revoke_token.assert_called_once_with("token-123") - mock_update_account.assert_called_once_with(current_user, email="new@example.com") + mock_update_account.assert_called_once_with(current_user, email="new@example.com", session=ANY) mock_send_notify.assert_called_once_with(email="new@example.com") @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") @@ -437,12 +437,12 @@ class TestChangeEmailReset: @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_reject_reset_when_token_phase_is_not_new_verified( self, - mock_is_freeze, - mock_check_unique, - mock_get_data, - mock_revoke_token, - mock_update_account, - mock_send_notify, + mock_is_freeze: MagicMock, + mock_check_unique: MagicMock, + mock_get_data: MagicMock, + mock_revoke_token: MagicMock, + mock_update_account: MagicMock, + mock_send_notify: MagicMock, app: Flask, ): """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" @@ -480,12 +480,12 @@ class TestChangeEmailReset: @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_reject_reset_when_token_email_differs_from_payload_new_email( self, - mock_is_freeze, - mock_check_unique, - mock_get_data, - mock_revoke_token, - mock_update_account, - mock_send_notify, + mock_is_freeze: MagicMock, + mock_check_unique: MagicMock, + mock_get_data: MagicMock, + mock_revoke_token: MagicMock, + mock_update_account: MagicMock, + mock_send_notify: MagicMock, app: Flask, ): """A verified token for address A must not be replayed to change to address B.""" @@ -523,12 +523,12 @@ class TestChangeEmailReset: @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") def test_should_reject_reset_when_token_account_id_does_not_match_current_user( self, - mock_is_freeze, - mock_check_unique, - mock_get_data, - mock_revoke_token, - mock_update_account, - mock_send_notify, + mock_is_freeze: MagicMock, + mock_check_unique: MagicMock, + mock_get_data: MagicMock, + mock_revoke_token: MagicMock, + mock_update_account: MagicMock, + mock_send_notify: MagicMock, app: Flask, ): from controllers.console.auth.error import InvalidTokenError @@ -575,9 +575,9 @@ class TestAccountServiceSendChangeEmailEmail: @patch("services.account_service.AccountService.generate_change_email_token") def test_should_bind_account_id_and_target_email_into_generated_token( self, - mock_generate_token, - mock_rate_limiter, - mock_mail_task, + mock_generate_token: MagicMock, + mock_rate_limiter: MagicMock, + mock_mail_task: MagicMock, ): mock_rate_limiter.is_rate_limited.return_value = False mock_generate_token.return_value = "the-token" @@ -665,7 +665,7 @@ class TestAccountDeletionFeedback: class TestCheckEmailUnique: @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, app: Flask): + def test_should_normalize_email(self, mock_is_freeze: MagicMock, mock_check_unique: MagicMock, app: Flask): mock_is_freeze.return_value = False mock_check_unique.return_value = True @@ -680,7 +680,7 @@ class TestCheckEmailUnique: assert response == {"result": "success"} mock_is_freeze.assert_called_once_with("case@test.com") - mock_check_unique.assert_called_once_with("case@test.com") + mock_check_unique.assert_called_once_with("case@test.com", session=ANY) def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): @@ -692,12 +692,7 @@ def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): second.scalar_one_or_none.return_value = expected_account mock_session.execute.side_effect = [first, second] - mock_factory = MagicMock() - mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False) - - with patch("services.account_service.session_factory", mock_factory): - result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com") + result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com") assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/workspace/test_rbac.py b/api/tests/unit_tests/controllers/console/workspace/test_rbac.py index d78bc1fc6dd..2960bfef324 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_rbac.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_rbac.py @@ -201,10 +201,10 @@ class TestPaginationMapping: }, ] assert response["pagination"] == { - "total_count": 5, + "total_count": 4, "per_page": 2, "current_page": 1, - "total_pages": 3, + "total_pages": 2, } mock_list.assert_not_called() diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py index 25ae0778d4b..a6626adc420 100644 --- a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -8,7 +8,7 @@ handler tests use inspect.unwrap() to bypass them and focus on business logic. import inspect from datetime import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -115,7 +115,7 @@ class TestEnterpriseWorkspace: assert result["message"] == "enterprise workspace created." assert result["tenant"]["id"] == "tenant-id" assert result["tenant"]["name"] == "My Workspace" - mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY) mock_tenant_svc.create_tenant_member.assert_called_once_with( mock_tenant, mock_account, mock_db.session, role="owner" ) @@ -183,5 +183,5 @@ class TestEnterpriseWorkspaceNoOwnerEmail: assert result["tenant"]["id"] == "tenant-id" assert result["tenant"]["encrypt_public_key"] == "pub-key" assert result["tenant"]["custom_config"] == {} - mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY) mock_event.send.assert_called_once_with(mock_tenant) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 1cfe152c864..52d050ff55a 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -176,6 +176,7 @@ class TestAudioServiceMockedBehavior: result = AudioService.transcript_tts( app_model=mock_app, + session=Mock(), text="Hello world", voice="nova", end_user="user_123", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py index 7a9978e742a..b77c783ae16 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py @@ -17,7 +17,7 @@ Decorator strategy: import uuid from inspect import unwrap -from unittest.mock import Mock, patch +from unittest.mock import ANY, Mock, patch import pytest from flask import Flask @@ -408,7 +408,7 @@ class TestDatasetMetadataBuiltInFieldAction: assert status == 200 assert response["result"] == "success" - mock_meta_svc.enable_built_in_field.assert_called_once_with(mock_dataset) + mock_meta_svc.enable_built_in_field.assert_called_once_with(ANY, mock_dataset) @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") @@ -439,7 +439,7 @@ class TestDatasetMetadataBuiltInFieldAction: ) assert status == 200 - mock_meta_svc.disable_built_in_field.assert_called_once_with(mock_dataset) + mock_meta_svc.disable_built_in_field.assert_called_once_with(ANY, mock_dataset) @patch("controllers.service_api.dataset.metadata.DatasetService") def test_action_dataset_not_found( diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index 839939367c4..bfffd5cbb2c 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -1,4 +1,5 @@ import base64 +import logging from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -16,6 +17,13 @@ def encode_code(code: str) -> str: return base64.b64encode(code.encode("utf-8")).decode() +def assert_login_failure_logged(caplog: pytest.LogCaptureFixture, email: str, reason: LoginFailureReason) -> None: + records = [record for record in caplog.records if record.name == "controllers.web.login"] + assert len(records) == 1 + assert records[0].args[0] == email + assert records[0].args[1] == reason + + @pytest.fixture def app(): flask_app = Flask(__name__) @@ -114,10 +122,10 @@ class TestLoginApi: "controllers.web.login.WebAppAuthService.authenticate", side_effect=services.errors.account.AccountLoginError(), ) - def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None: + def test_login_banned_account(self, mock_auth: MagicMock, app: Flask, caplog: pytest.LogCaptureFixture) -> None: from controllers.console.error import AccountBannedError - with patch("controllers.web.login.logger.warning") as mock_log_warning: + with caplog.at_level(logging.WARNING, logger="controllers.web.login"): with app.test_request_context( "/web/login", method="POST", @@ -126,18 +134,16 @@ class TestLoginApi: with pytest.raises(AccountBannedError): LoginApi().post() - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "user@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED + assert_login_failure_logged(caplog, "user@example.com", LoginFailureReason.ACCOUNT_BANNED) @patch( "controllers.web.login.WebAppAuthService.authenticate", side_effect=services.errors.account.AccountPasswordError(), ) - def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None: + def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask, caplog: pytest.LogCaptureFixture) -> None: from controllers.console.auth.error import AuthenticationFailedError - with patch("controllers.web.login.logger.warning") as mock_log_warning: + with caplog.at_level(logging.WARNING, logger="controllers.web.login"): with app.test_request_context( "/web/login", method="POST", @@ -146,18 +152,16 @@ class TestLoginApi: with pytest.raises(AuthenticationFailedError): LoginApi().post() - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "user@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_CREDENTIALS + assert_login_failure_logged(caplog, "user@example.com", LoginFailureReason.INVALID_CREDENTIALS) @patch( "controllers.web.login.WebAppAuthService.authenticate", side_effect=services.errors.account.AccountNotFoundError(), ) - def test_login_account_not_found(self, mock_auth: MagicMock, app: Flask) -> None: + def test_login_account_not_found(self, mock_auth: MagicMock, app: Flask, caplog: pytest.LogCaptureFixture) -> None: from controllers.console.auth.error import AuthenticationFailedError - with patch("controllers.web.login.logger.warning") as mock_log_warning: + with caplog.at_level(logging.WARNING, logger="controllers.web.login"): with app.test_request_context( "/web/login", method="POST", @@ -166,13 +170,13 @@ class TestLoginApi: with pytest.raises(AuthenticationFailedError): LoginApi().post() - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "missing@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_NOT_FOUND + assert_login_failure_logged(caplog, "missing@example.com", LoginFailureReason.ACCOUNT_NOT_FOUND) @patch("controllers.web.login.WebAppAuthService.get_email_code_login_data", return_value=None) - def test_email_code_login_logs_invalid_token(self, mock_get_token_data: MagicMock, app: Flask) -> None: - with patch("controllers.web.login.logger.warning") as mock_log_warning: + def test_email_code_login_logs_invalid_token( + self, mock_get_token_data: MagicMock, app: Flask, caplog: pytest.LogCaptureFixture + ) -> None: + with caplog.at_level(logging.WARNING, logger="controllers.web.login"): with app.test_request_context( "/web/email-code-login/validity", method="POST", @@ -182,9 +186,7 @@ class TestLoginApi: EmailCodeLoginApi().post() mock_get_token_data.assert_called_once_with("token-123") - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "user@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.INVALID_EMAIL_CODE_TOKEN + assert_login_failure_logged(caplog, "user@example.com", LoginFailureReason.INVALID_EMAIL_CODE_TOKEN) @patch("controllers.web.login.WebAppAuthService.revoke_email_code_login_token") @patch( @@ -201,10 +203,11 @@ class TestLoginApi: mock_get_user: MagicMock, mock_revoke_token: MagicMock, app: Flask, + caplog: pytest.LogCaptureFixture, ) -> None: from controllers.console.error import AccountBannedError - with patch("controllers.web.login.logger.warning") as mock_log_warning: + with caplog.at_level(logging.WARNING, logger="controllers.web.login"): with app.test_request_context( "/web/email-code-login/validity", method="POST", @@ -215,9 +218,7 @@ class TestLoginApi: mock_get_token_data.assert_called_once_with("token-123") mock_revoke_token.assert_called_once_with("token-123") - assert mock_log_warning.call_count == 1 - assert mock_log_warning.call_args.args[1] == "user@example.com" - assert mock_log_warning.call_args.args[2] == LoginFailureReason.ACCOUNT_BANNED + assert_login_failure_logged(caplog, "user@example.com", LoginFailureReason.ACCOUNT_BANNED) class TestLoginStatusApi: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 28f416ac27f..4285fe088c6 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -29,6 +29,7 @@ from core.app.entities.queue_entities import ( QueueNodeExceptionEvent, QueueNodeFailedEvent, QueuePingEvent, + QueueReasoningChunkEvent, QueueRetrieverResourcesEvent, QueueStopEvent, QueueTextChunkEvent, @@ -46,6 +47,7 @@ from core.app.entities.task_entities import ( MessageAudioStreamResponse, MessageEndStreamResponse, PingStreamResponse, + ReasoningChunkStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk from core.workflow.system_variables import build_system_variables @@ -196,6 +198,42 @@ class TestAdvancedChatGenerateTaskPipeline: assert pipeline._task_state.answer == "hi" assert responses + def test_handle_reasoning_chunk_event_emits_on_nonempty(self): + pipeline = _make_pipeline() + event = QueueReasoningChunkEvent(reasoning="pondering", from_node_id="llm-1", is_final=False) + + responses = list(pipeline._handle_reasoning_chunk_event(event)) + + assert len(responses) == 1 + response = responses[0] + assert isinstance(response, ReasoningChunkStreamResponse) + assert response.data.message_id == pipeline._message_id + assert response.data.reasoning == "pondering" + assert response.data.node_id == "llm-1" + assert response.data.is_final is False + # reasoning never touches the answer stream + assert pipeline._task_state.answer == "" + + def test_handle_reasoning_chunk_event_drops_empty_nonfinal(self): + pipeline = _make_pipeline() + event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm-1", is_final=False) + + responses = list(pipeline._handle_reasoning_chunk_event(event)) + + assert responses == [] + + def test_handle_reasoning_chunk_event_emits_empty_final_marker(self): + pipeline = _make_pipeline() + event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm-1", is_final=True) + + responses = list(pipeline._handle_reasoning_chunk_event(event)) + + assert len(responses) == 1 + response = responses[0] + assert isinstance(response, ReasoningChunkStreamResponse) + assert response.data.reasoning == "" + assert response.data.is_final is True + def test_listen_audio_msg_returns_audio_stream(self): pipeline = _make_pipeline() publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data")) @@ -319,6 +357,43 @@ class TestAdvancedChatGenerateTaskPipeline: assert responses == ["done"] assert pipeline._recorded_files + def test_handle_node_succeeded_event_records_llm_reasoning(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [] + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + pipeline._save_output_for_event = lambda event, node_execution_id: None + + event = SimpleNamespace( + node_type=BuiltinNodeTypes.LLM, + outputs={"reasoning_content": "first pass "}, + node_execution_id="exec", + node_id="llm-1", + ) + + list(pipeline._handle_node_succeeded_event(event)) + + assert pipeline._task_state.metadata.reasoning == {"llm-1": "first pass "} + + def test_handle_node_succeeded_event_accumulates_reasoning_across_passes(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [] + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + pipeline._save_output_for_event = lambda event, node_execution_id: None + + def _llm_event(reasoning: str): + return SimpleNamespace( + node_type=BuiltinNodeTypes.LLM, + outputs={"reasoning_content": reasoning}, + node_execution_id="exec", + node_id="llm-1", + ) + + # Same node id across iteration/loop passes must accumulate, not overwrite. + list(pipeline._handle_node_succeeded_event(_llm_event("pass one "))) + list(pipeline._handle_node_succeeded_event(_llm_event("pass two"))) + + assert pipeline._task_state.metadata.reasoning == {"llm-1": "pass one pass two"} + def test_iteration_and_loop_handlers(self): pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" diff --git a/api/tests/unit_tests/core/app/apps/agent_app/test_runtime_request_builder.py b/api/tests/unit_tests/core/app/apps/agent_app/test_runtime_request_builder.py index 0d1483e1b79..4f292d90bb4 100644 --- a/api/tests/unit_tests/core/app/apps/agent_app/test_runtime_request_builder.py +++ b/api/tests/unit_tests/core/app/apps/agent_app/test_runtime_request_builder.py @@ -144,6 +144,22 @@ class TestAgentAppRuntimeRequestBuilder: assert result.redacted_request["composition"]["layers"][-1]["config"]["credentials"] == "[REDACTED]" assert result.metadata["conversation_id"] == "conv-1" + def test_build_normalizes_marketplace_model_plugin_id(self): + soul = _soul_with_model() + soul.model.plugin_id = ( + "langgenius/openai:0.4.2@21195ee1321849e0a7d4b3f6b2fd8c2be23ea6c7182e1b444ecc4c1711b52468" + ) + builder = AgentAppRuntimeRequestBuilder( + credentials_provider=_FakeCredentialsProvider(), + plugin_tools_builder=_NoToolsBuilder(), # type: ignore[arg-type] + ) + + result = builder.build(_ctx(soul)) + + llm = next(layer for layer in result.request.composition.layers if layer.name == "llm") + assert llm.config.plugin_id == "langgenius/openai" + assert llm.config.model_provider == "openai" + def test_build_maps_agent_soul_knowledge_to_knowledge_layer(self): soul = AgentSoulConfig.model_validate( { diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index c463c155a52..69ed5919d27 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -16,6 +16,7 @@ from core.app.entities.queue_entities import ( QueueNodeFailedEvent, QueueNodeRetryEvent, QueueNodeSucceededEvent, + QueueReasoningChunkEvent, QueueTextChunkEvent, QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, @@ -34,6 +35,7 @@ from graphon.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunIterationSucceededEvent, NodeRunLoopFailedEvent, + NodeRunReasoningChunkEvent, NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, @@ -395,6 +397,17 @@ class TestWorkflowBasedAppRunner: is_final=False, ), ) + runner._handle_event( + workflow_entry, + NodeRunReasoningChunkEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + selector=["node", "reasoning_content"], + chunk="thinking", + is_final=False, + ), + ) runner._handle_event( workflow_entry, NodeRunAgentLogEvent( @@ -442,6 +455,7 @@ class TestWorkflowBasedAppRunner: ) assert any(isinstance(event, QueueTextChunkEvent) for event in published) + assert any(isinstance(event, QueueReasoningChunkEvent) for event in published) assert any(isinstance(event, QueueAgentLogEvent) for event in published) assert any(isinstance(event, QueueIterationCompletedEvent) for event in published) assert any(isinstance(event, QueueLoopCompletedEvent) for event in published) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index 0aaee900e37..9a04014d620 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -26,6 +26,7 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, QueuePingEvent, + QueueReasoningChunkEvent, QueueStopEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, @@ -40,6 +41,7 @@ from core.app.entities.task_entities import ( MessageAudioEndStreamResponse, MessageAudioStreamResponse, PingStreamResponse, + ReasoningChunkStreamResponse, WorkflowAppPausedBlockingResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, @@ -265,6 +267,41 @@ class TestWorkflowGenerateTaskPipeline: assert responses[0].data.text == "hi" assert published == [queue_message] + def test_handle_reasoning_chunk_event_emits_on_nonempty(self): + pipeline = _make_pipeline() + event = QueueReasoningChunkEvent(reasoning="pondering", from_node_id="llm-1", is_final=False) + + responses = list(pipeline._handle_reasoning_chunk_event(event)) + + assert len(responses) == 1 + response = responses[0] + assert isinstance(response, ReasoningChunkStreamResponse) + # workflow runs have no message, so the id is omitted + assert response.data.message_id is None + assert response.data.reasoning == "pondering" + assert response.data.node_id == "llm-1" + assert response.data.is_final is False + + def test_handle_reasoning_chunk_event_drops_empty_nonfinal(self): + pipeline = _make_pipeline() + event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm-1", is_final=False) + + responses = list(pipeline._handle_reasoning_chunk_event(event)) + + assert responses == [] + + def test_handle_reasoning_chunk_event_emits_empty_final_marker(self): + pipeline = _make_pipeline() + event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm-1", is_final=True) + + responses = list(pipeline._handle_reasoning_chunk_event(event)) + + assert len(responses) == 1 + response = responses[0] + assert isinstance(response, ReasoningChunkStreamResponse) + assert response.data.reasoning == "" + assert response.data.is_final is True + def test_dispatch_event_handles_node_failed(self): pipeline = _make_pipeline() pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" diff --git a/api/tests/unit_tests/core/app/entities/test_queue_entities.py b/api/tests/unit_tests/core/app/entities/test_queue_entities.py index 7c21b00966e..a930d7325d4 100644 --- a/api/tests/unit_tests/core/app/entities/test_queue_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_queue_entities.py @@ -1,4 +1,4 @@ -from core.app.entities.queue_entities import QueueStopEvent +from core.app.entities.queue_entities import QueueEvent, QueueReasoningChunkEvent, QueueStopEvent class TestQueueEntities: @@ -10,3 +10,17 @@ class TestQueueEntities: event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) event.stopped_by = "unknown" assert event.get_stop_reason() == "Stopped by unknown reason." + + def test_reasoning_chunk_event_defaults(self): + event = QueueReasoningChunkEvent(reasoning="thinking", from_node_id="llm") + assert event.event == QueueEvent.REASONING_CHUNK + assert event.reasoning == "thinking" + assert event.from_node_id == "llm" + assert event.is_final is False + assert event.in_iteration_id is None + assert event.in_loop_id is None + + def test_reasoning_chunk_event_terminal_marker_allows_empty_reasoning(self): + event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm", is_final=True) + assert event.reasoning == "" + assert event.is_final is True diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py index 7c797806411..19d49a28877 100644 --- a/api/tests/unit_tests/core/app/entities/test_task_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -1,10 +1,15 @@ +import json + from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, + ReasoningChunkStreamResponse, StreamEvent, + TaskStateMetadata, ) from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder class TestTaskEntities: @@ -76,3 +81,41 @@ class TestTaskEntities: assert payload["event"] == StreamEvent.NODE_RETRY.value assert payload["data"]["retry_index"] == 2 assert payload["data"]["outputs"] is None + + def test_reasoning_chunk_stream_response_shape(self): + response = ReasoningChunkStreamResponse( + task_id="task-1", + data=ReasoningChunkStreamResponse.Data( + message_id="msg-1", + reasoning="let me think", + node_id="llm", + is_final=False, + ), + ) + + payload = response.model_dump() + + assert payload["event"] == StreamEvent.REASONING_CHUNK + assert payload["task_id"] == "task-1" + assert payload["data"]["message_id"] == "msg-1" + assert payload["data"]["reasoning"] == "let me think" + assert payload["data"]["node_id"] == "llm" + assert payload["data"]["is_final"] is False + + def test_task_state_metadata_reasoning_round_trips(self): + # The persistence path serializes the whole metadata to message_metadata via + # model_dump -> jsonable_encoder -> json.dumps, then reads back with json.loads. + metadata = TaskStateMetadata() + metadata.reasoning["llm"] = "first" + metadata.reasoning["llm2"] = "second" + + serialized = json.dumps(jsonable_encoder(metadata.model_dump())) + restored = json.loads(serialized) + + assert restored["reasoning"] == {"llm": "first", "llm2": "second"} + + def test_task_state_metadata_reasoning_defaults_empty(self): + # Old rows / runs without reasoning serialize to an empty dict, never null. + metadata = TaskStateMetadata() + restored = json.loads(json.dumps(jsonable_encoder(metadata.model_dump()))) + assert restored["reasoning"] == {} diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 07ba9314977..dbc8f4969dc 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from contextlib import contextmanager from types import SimpleNamespace from typing import Any @@ -1704,13 +1705,14 @@ def test_get_specific_provider_credential_decrypts_and_obfuscates_credentials() assert credentials == {"openai_api_key": "raw-secret", "region": "us"} -def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: +def test_get_specific_provider_credential_logs_when_decrypt_fails(caplog: pytest.LogCaptureFixture) -> None: configuration = _build_provider_configuration() configuration.provider.provider_credential_schema = _build_secret_provider_schema() session = Mock() session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( encrypted_config='{"openai_api_key":"enc-secret"}' ) + caplog.set_level(logging.ERROR, logger="core.entities.provider_configuration") with _patched_session(session): with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): @@ -1718,16 +1720,15 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: "core.entities.provider_configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom"), ): - with patch("core.entities.provider_configuration.logger.exception") as mock_logger: - with patch.object( - ProviderConfiguration, - "obfuscated_credentials", - side_effect=lambda credentials, credential_form_schemas: credentials, - ): - credentials = configuration._get_specific_provider_credential("cred-1") + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + credentials = configuration._get_specific_provider_credential("cred-1") assert credentials == {"openai_api_key": "enc-secret"} - mock_logger.assert_called_once() + assert caplog.messages.count("Failed to decrypt credential secret variable openai_api_key") == 1 def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: @@ -1831,7 +1832,7 @@ def test_switch_active_provider_credential_rolls_back_on_error() -> None: session.rollback.assert_called_once() -def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: +def test_get_specific_custom_model_credential_logs_when_decrypt_fails(caplog: pytest.LogCaptureFixture) -> None: configuration = _build_provider_configuration() configuration.provider.model_credential_schema = _build_secret_model_schema() session = Mock() @@ -1840,19 +1841,19 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: credential_name="Main", encrypted_config='{"openai_api_key":"enc-secret"}', ) + caplog.set_level(logging.ERROR, logger="core.entities.provider_configuration") with _patched_session(session): with patch("core.entities.provider_configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")): - with patch("core.entities.provider_configuration.logger.exception") as mock_logger: - with patch.object( - ProviderConfiguration, - "obfuscated_credentials", - side_effect=lambda credentials, credential_form_schemas: credentials, - ): - result = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + result = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") assert result["credentials"] == {"openai_api_key": "enc-secret"} - mock_logger.assert_called_once() + assert caplog.messages.count("Failed to decrypt model credential secret variable openai_api_key") == 1 def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 42d5ea4a393..168d0e466d4 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -407,7 +407,7 @@ class TestCacheEmbeddingDocuments: assert len(calls[1].kwargs["texts"]) == 10 assert len(calls[2].kwargs["texts"]) == 5 - def test_embed_documents_nan_handling(self, mock_model_instance, caplog): + def test_embed_documents_nan_handling(self, mock_model_instance, caplog: pytest.LogCaptureFixture): """Test handling of NaN values in embeddings. Verifies: diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index d2154f138a7..302ababb48f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -385,7 +385,7 @@ class TestParagraphIndexProcessor: with pytest.raises(ValueError, match="model_name and model_provider_name"): ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True}) - def test_generate_summary_text_only_flow(self, caplog) -> None: + def test_generate_summary_text_only_flow(self, caplog: pytest.LogCaptureFixture) -> None: model_instance = Mock() model_instance.credentials = {"k": "v"} model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) @@ -459,7 +459,7 @@ class TestParagraphIndexProcessor: assert summary == "vision summary" mock_extract_text.assert_not_called() - def test_generate_summary_fallbacks_for_prompt_and_result_types(self, caplog) -> None: + def test_generate_summary_fallbacks_for_prompt_and_result_types(self, caplog: pytest.LogCaptureFixture) -> None: model_instance = Mock() model_instance.credentials = {"k": "v"} model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( @@ -503,7 +503,7 @@ class TestParagraphIndexProcessor: "Failed to convert image file to prompt message content" in record.message for record in caplog.records ) - def test_extract_images_from_text_handles_patterns_and_build_errors(self, caplog) -> None: + def test_extract_images_from_text_handles_patterns_and_build_errors(self, caplog: pytest.LogCaptureFixture) -> None: text = ( "![img](/files/11111111-1111-1111-1111-111111111111/image-preview) " "![img2](/files/22222222-2222-2222-2222-222222222222/file-preview) " @@ -554,7 +554,7 @@ class TestParagraphIndexProcessor: session.scalars.return_value = scalars_result assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here", session) == [] - def test_extract_images_from_text_logs_when_build_fails(self, caplog) -> None: + def test_extract_images_from_text_logs_when_build_fails(self, caplog: pytest.LogCaptureFixture) -> None: text = "![img](/files/11111111-1111-1111-1111-111111111111/image-preview)" image_upload = SimpleNamespace( id="11111111-1111-1111-1111-111111111111", @@ -583,7 +583,7 @@ class TestParagraphIndexProcessor: assert files == [] assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1 - def test_extract_images_from_segment_attachments(self, caplog) -> None: + def test_extract_images_from_segment_attachments(self, caplog: pytest.LogCaptureFixture) -> None: image_upload = SimpleNamespace( id="file-1", name="image", diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py index 4ffd0a76433..6e5a4fabbb0 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -351,7 +351,9 @@ class TestQAIndexProcessor: assert all_qa_documents[0].metadata["answer"] == "A test." assert all_qa_documents[1].metadata["answer"] == "Coverage." - def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app, caplog) -> None: + def test_format_qa_document_logs_errors( + self, processor: QAIndexProcessor, fake_flask_app, caplog: pytest.LogCaptureFixture + ) -> None: all_qa_documents: list[Document] = [] source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 12c7f8113c7..97d7e4a937b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -4,6 +4,8 @@ from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest + from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus @@ -122,7 +124,7 @@ def test_precheck_ignores_non_quota_node() -> None: mock_check.assert_not_called() -def test_quota_error_is_handled_in_layer(caplog) -> None: +def test_quota_error_is_handled_in_layer(caplog: pytest.LogCaptureFixture) -> None: layer = LLMQuotaLayer(tenant_id="tenant-id") stop_event = threading.Event() layer.command_channel = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py index 78e49769159..ffa7ccdbca7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py +++ b/api/tests/unit_tests/core/workflow/nodes/agent_v2/test_runtime_request_builder.py @@ -189,7 +189,7 @@ def test_normalizes_langgenius_model_provider_for_agent_backend_transport(): context.snapshot.config_snapshot = AgentSoulConfig( prompt={"system_prompt": "You are careful."}, model=AgentSoulModelConfig( - plugin_id="langgenius/openai/openai", + plugin_id="langgenius/openai:0.4.2@21195ee1321849e0a7d4b3f6b2fd8c2be23ea6c7182e1b444ecc4c1711b52468", model_provider="langgenius/openai/openai", model="gpt-test", ), diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 79581cc8811..d437c565949 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -47,7 +47,12 @@ from graphon.model_runtime.entities.model_entities import ( ParameterType, ) from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from graphon.node_events import ModelInvokeCompletedEvent, RunRetrieverResourceEvent, StreamChunkEvent +from graphon.node_events import ( + ModelInvokeCompletedEvent, + RunRetrieverResourceEvent, + StreamChunkEvent, + StreamReasoningEvent, +) from graphon.nodes.base.entities import VariableSelector from graphon.nodes.llm import llm_utils from graphon.nodes.llm.entities import ( @@ -1576,9 +1581,13 @@ def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_out assert events[0] == first_chunk - assert events[1] == StreamChunkEvent(selector=["node-1", "text"], chunk="answer", is_final=False) + assert events[1] == StreamReasoningEvent(selector=["node-1", "reasoning_content"], chunk="plan", is_final=False) - completed = events[2] + assert events[2] == StreamChunkEvent(selector=["node-1", "text"], chunk="answer", is_final=False) + + assert events[3] == StreamReasoningEvent(selector=["node-1", "reasoning_content"], chunk="", is_final=True) + + completed = events[4] assert isinstance(completed, ModelInvokeCompletedEvent) assert completed.text == "answer" assert completed.reasoning_content == "plan" diff --git a/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py index 1ce9581aa1c..24c905c75c2 100644 --- a/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py +++ b/api/tests/unit_tests/enterprise/telemetry/test_enterprise_trace.py @@ -21,6 +21,7 @@ from core.ops.entities.trace_entity import ( WorkflowNodeTraceInfo, WorkflowTraceInfo, ) +from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace from enterprise.telemetry.entities import ( EnterpriseTelemetryCounter, EnterpriseTelemetryEvent, @@ -297,43 +298,43 @@ def test_init_succeeds_with_valid_exporter(mock_exporter): class TestSafePayloadValue: - def test_string_passthrough(self, trace_handler): + def test_string_passthrough(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._safe_payload_value("hello") == "hello" - def test_dict_passthrough(self, trace_handler): + def test_dict_passthrough(self, trace_handler: EnterpriseOtelTrace): d = {"key": "val"} assert trace_handler._safe_payload_value(d) == d - def test_list_passthrough(self, trace_handler): + def test_list_passthrough(self, trace_handler: EnterpriseOtelTrace): lst = [1, 2, 3] assert trace_handler._safe_payload_value(lst) == lst - def test_none_returns_none(self, trace_handler): + def test_none_returns_none(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._safe_payload_value(None) is None - def test_int_returns_none(self, trace_handler): + def test_int_returns_none(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._safe_payload_value(42) is None - def test_bool_returns_none(self, trace_handler): + def test_bool_returns_none(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._safe_payload_value(True) is None class TestMaybeJson: - def test_none_returns_none(self, trace_handler): + def test_none_returns_none(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._maybe_json(None) is None - def test_string_passthrough(self, trace_handler): + def test_string_passthrough(self, trace_handler: EnterpriseOtelTrace): assert trace_handler._maybe_json("hello") == "hello" - def test_dict_serialised(self, trace_handler): + def test_dict_serialised(self, trace_handler: EnterpriseOtelTrace): result = trace_handler._maybe_json({"a": 1}) assert result == json.dumps({"a": 1}) - def test_list_serialised(self, trace_handler): + def test_list_serialised(self, trace_handler: EnterpriseOtelTrace): result = trace_handler._maybe_json([1, 2]) assert result == "[1, 2]" - def test_non_serialisable_falls_back_to_str(self, trace_handler): + def test_non_serialisable_falls_back_to_str(self, trace_handler: EnterpriseOtelTrace): class Unserializable: def __repr__(self): return "Unserializable()" @@ -344,22 +345,22 @@ class TestMaybeJson: class TestContentOrRef: - def test_returns_content_when_include_content_true(self, trace_handler, mock_exporter): + def test_returns_content_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True result = trace_handler._content_or_ref("actual content", "ref:x=1") assert result == "actual content" - def test_returns_ref_when_include_content_false(self, trace_handler, mock_exporter): + def test_returns_ref_when_include_content_false(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False result = trace_handler._content_or_ref("actual content", "ref:x=1") assert result == "ref:x=1" - def test_dict_serialised_when_include_content_true(self, trace_handler, mock_exporter): + def test_dict_serialised_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True result = trace_handler._content_or_ref({"key": "val"}, "ref:x=1") assert result == json.dumps({"key": "val"}) - def test_none_returns_none_when_include_content_true(self, trace_handler, mock_exporter): + def test_none_returns_none_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True result = trace_handler._content_or_ref(None, "ref:x=1") assert result is None @@ -371,67 +372,67 @@ class TestContentOrRef: class TestTraceDispatcher: - def test_dispatches_workflow_trace(self, trace_handler): + def test_dispatches_workflow_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_workflow_trace") as mock_method: info = make_workflow_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_message_trace(self, trace_handler): + def test_dispatches_message_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_message_trace") as mock_method: info = make_message_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_tool_trace(self, trace_handler): + def test_dispatches_tool_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_tool_trace") as mock_method: info = make_tool_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_draft_node_execution_trace(self, trace_handler): + def test_dispatches_draft_node_execution_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_draft_node_execution_trace") as mock_method: info = make_draft_node_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_node_execution_trace(self, trace_handler): + def test_dispatches_node_execution_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_node_execution_trace") as mock_method: info = make_node_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_moderation_trace(self, trace_handler): + def test_dispatches_moderation_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_moderation_trace") as mock_method: info = make_moderation_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_suggested_question_trace(self, trace_handler): + def test_dispatches_suggested_question_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_suggested_question_trace") as mock_method: info = make_suggested_question_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_dataset_retrieval_trace(self, trace_handler): + def test_dispatches_dataset_retrieval_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_dataset_retrieval_trace") as mock_method: info = make_dataset_retrieval_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_generate_name_trace(self, trace_handler): + def test_dispatches_generate_name_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_generate_name_trace") as mock_method: info = make_generate_name_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_dispatches_prompt_generation_trace(self, trace_handler): + def test_dispatches_prompt_generation_trace(self, trace_handler: EnterpriseOtelTrace): with patch.object(trace_handler, "_prompt_generation_trace") as mock_method: info = make_prompt_generation_info() trace_handler.trace(info) mock_method.assert_called_once_with(info) - def test_draft_node_dispatched_before_node(self, trace_handler): + def test_draft_node_dispatched_before_node(self, trace_handler: EnterpriseOtelTrace): """DraftNodeExecutionTrace is a subclass of WorkflowNodeTraceInfo; it must be dispatched to _draft_node_execution_trace, not _node_execution_trace.""" with ( @@ -450,7 +451,7 @@ class TestTraceDispatcher: class TestWorkflowTrace: - def test_emits_correct_span_attributes(self, trace_handler, mock_exporter): + def test_emits_correct_span_attributes(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: info = make_workflow_info() trace_handler._workflow_trace(info) @@ -465,7 +466,7 @@ class TestWorkflowTrace: assert attrs["dify.workflow.status"] == "succeeded" assert attrs["gen_ai.usage.total_tokens"] == 100 - def test_span_timing_passed_correctly(self, trace_handler, mock_exporter): + def test_span_timing_passed_correctly(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info() trace_handler._workflow_trace(info) @@ -474,7 +475,7 @@ class TestWorkflowTrace: assert span_call[1]["start_time"] == _T0 assert span_call[1]["end_time"] == _T1 - def test_emits_companion_log_with_event_name(self, trace_handler, mock_exporter): + def test_emits_companion_log_with_event_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._workflow_trace(make_workflow_info()) @@ -482,7 +483,7 @@ class TestWorkflowTrace: assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetryEvent.WORKFLOW_RUN assert mock_log.call_args[1]["tenant_id"] == "tenant-abc" - def test_companion_log_includes_content_when_enabled(self, trace_handler, mock_exporter): + def test_companion_log_includes_content_when_enabled(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._workflow_trace(make_workflow_info()) @@ -491,7 +492,7 @@ class TestWorkflowTrace: assert log_attrs["dify.workflow.inputs"] == json.dumps({"query": "hello"}) assert log_attrs["dify.workflow.outputs"] == json.dumps({"answer": "world"}) - def test_companion_log_uses_ref_when_content_disabled(self, trace_handler, mock_exporter): + def test_companion_log_uses_ref_when_content_disabled(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._workflow_trace(make_workflow_info()) @@ -500,7 +501,7 @@ class TestWorkflowTrace: assert log_attrs["dify.workflow.inputs"].startswith("ref:workflow_run_id=") assert log_attrs["dify.workflow.outputs"].startswith("ref:workflow_run_id=") - def test_increments_token_counter(self, trace_handler, mock_exporter): + def test_increments_token_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._workflow_trace(make_workflow_info()) @@ -510,7 +511,7 @@ class TestWorkflowTrace: assert len(token_calls) == 1 assert token_calls[0][0][1] == 100 - def test_increments_input_and_output_token_counters(self, trace_handler, mock_exporter): + def test_increments_input_and_output_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._workflow_trace(make_workflow_info()) @@ -519,7 +520,7 @@ class TestWorkflowTrace: assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names - def test_no_input_token_counter_when_prompt_tokens_zero(self, trace_handler, mock_exporter): + def test_no_input_token_counter_when_prompt_tokens_zero(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info(prompt_tokens=0) trace_handler._workflow_trace(info) @@ -528,7 +529,7 @@ class TestWorkflowTrace: counter_names = [c[0][0] for c in all_calls] assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names - def test_records_workflow_duration_histogram(self, trace_handler, mock_exporter): + def test_records_workflow_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._workflow_trace(make_workflow_info()) @@ -537,7 +538,9 @@ class TestWorkflowTrace: assert hist_call[0][0] == EnterpriseTelemetryHistogram.WORKFLOW_DURATION assert hist_call[0][1] == pytest.approx(5.0) - def test_duration_falls_back_to_elapsed_time_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_duration_falls_back_to_elapsed_time_when_timestamps_missing( + self, trace_handler: EnterpriseOtelTrace, mock_exporter + ): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=7.3) trace_handler._workflow_trace(info) @@ -545,7 +548,7 @@ class TestWorkflowTrace: hist_call = mock_exporter.record_histogram.call_args assert hist_call[0][1] == pytest.approx(7.3) - def test_duration_defaults_to_zero_when_no_timing(self, trace_handler, mock_exporter): + def test_duration_defaults_to_zero_when_no_timing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info(start_time=None, end_time=None, workflow_run_elapsed_time=0) trace_handler._workflow_trace(info) @@ -553,7 +556,7 @@ class TestWorkflowTrace: hist_call = mock_exporter.record_histogram.call_args assert hist_call[0][1] == pytest.approx(0.0) - def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info(error="Something went wrong", workflow_run_status="failed") trace_handler._workflow_trace(info) @@ -563,7 +566,7 @@ class TestWorkflowTrace: ] assert len(error_calls) == 1 - def test_no_error_counter_on_success(self, trace_handler, mock_exporter): + def test_no_error_counter_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._workflow_trace(make_workflow_info()) @@ -572,7 +575,7 @@ class TestWorkflowTrace: ] assert len(error_calls) == 0 - def test_parent_trace_context_injected_into_span_attrs(self, trace_handler, mock_exporter): + def test_parent_trace_context_injected_into_span_attrs(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_workflow_info( metadata={ @@ -601,14 +604,14 @@ class TestWorkflowTrace: class TestNodeExecutionTrace: - def test_emits_span_with_node_execution_span_name(self, trace_handler, mock_exporter): + def test_emits_span_with_node_execution_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info()) span_call = mock_exporter.export_span.call_args assert span_call[0][0] == EnterpriseTelemetrySpan.NODE_EXECUTION - def test_span_contains_core_node_attributes(self, trace_handler, mock_exporter): + def test_span_contains_core_node_attributes(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info()) @@ -620,7 +623,7 @@ class TestNodeExecutionTrace: assert attrs["gen_ai.request.model"] == "gpt-4" assert attrs["gen_ai.provider.name"] == "openai" - def test_increments_token_counters_when_tokens_present(self, trace_handler, mock_exporter): + def test_increments_token_counters_when_tokens_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info()) @@ -629,7 +632,7 @@ class TestNodeExecutionTrace: assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names - def test_no_token_counters_when_total_tokens_zero(self, trace_handler, mock_exporter): + def test_no_token_counters_when_total_tokens_zero(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info(total_tokens=0)) @@ -637,7 +640,7 @@ class TestNodeExecutionTrace: assert EnterpriseTelemetryCounter.TOKENS not in counter_names assert EnterpriseTelemetryCounter.INPUT_TOKENS not in counter_names - def test_records_node_duration_histogram(self, trace_handler, mock_exporter): + def test_records_node_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info()) @@ -645,7 +648,7 @@ class TestNodeExecutionTrace: assert hist_call[0][0] == EnterpriseTelemetryHistogram.NODE_DURATION assert hist_call[0][1] == pytest.approx(2.5) - def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._node_execution_trace(make_node_info(error="Node failed", status="failed")) @@ -654,14 +657,16 @@ class TestNodeExecutionTrace: ] assert len(error_calls) == 1 - def test_emits_companion_log_with_span_name_as_event(self, trace_handler, mock_exporter): + def test_emits_companion_log_with_span_name_as_event(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._node_execution_trace(make_node_info()) mock_log.assert_called_once() assert mock_log.call_args[1]["event_name"] == EnterpriseTelemetrySpan.NODE_EXECUTION.value - def test_plugin_name_added_to_duration_labels_for_tool_node(self, trace_handler, mock_exporter): + def test_plugin_name_added_to_duration_labels_for_tool_node( + self, trace_handler: EnterpriseOtelTrace, mock_exporter + ): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_node_info( node_type="tool", @@ -677,7 +682,7 @@ class TestNodeExecutionTrace: duration_labels = hist_call[0][2] assert duration_labels.get("plugin_name") == "my-plugin" - def test_plugin_name_not_added_for_non_tool_node(self, trace_handler, mock_exporter): + def test_plugin_name_not_added_for_non_tool_node(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_node_info( node_type="llm", @@ -693,7 +698,9 @@ class TestNodeExecutionTrace: duration_labels = hist_call[0][2] assert "plugin_name" not in duration_labels - def test_companion_log_inputs_use_ref_when_content_disabled(self, trace_handler, mock_exporter): + def test_companion_log_inputs_use_ref_when_content_disabled( + self, trace_handler: EnterpriseOtelTrace, mock_exporter + ): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._node_execution_trace( @@ -711,14 +718,14 @@ class TestNodeExecutionTrace: class TestDraftNodeExecutionTrace: - def test_uses_draft_span_name(self, trace_handler, mock_exporter): + def test_uses_draft_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): trace_handler._draft_node_execution_trace(make_draft_node_info()) span_call = mock_exporter.export_span.call_args assert span_call[0][0] == EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION - def test_correlation_id_is_node_execution_id(self, trace_handler, mock_exporter): + def test_correlation_id_is_node_execution_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_draft_node_info() trace_handler._draft_node_execution_trace(info) @@ -726,7 +733,7 @@ class TestDraftNodeExecutionTrace: span_call = mock_exporter.export_span.call_args assert span_call[1]["correlation_id"] == "ne-draft-001" - def test_trace_correlation_override_is_workflow_run_id(self, trace_handler, mock_exporter): + def test_trace_correlation_override_is_workflow_run_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log"): info = make_draft_node_info() trace_handler._draft_node_execution_trace(info) @@ -734,7 +741,7 @@ class TestDraftNodeExecutionTrace: span_call = mock_exporter.export_span.call_args assert span_call[1]["trace_correlation_override"] == "run-draft-001" - def test_companion_log_uses_draft_span_name(self, trace_handler, mock_exporter): + def test_companion_log_uses_draft_span_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_telemetry_log") as mock_log: trace_handler._draft_node_execution_trace(make_draft_node_info()) @@ -747,34 +754,36 @@ class TestDraftNodeExecutionTrace: class TestMessageTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info()) mock_emit.assert_called_once() assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MESSAGE_RUN - def test_emits_correct_tenant_and_user(self, trace_handler, mock_exporter): + def test_emits_correct_tenant_and_user(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info()) assert mock_emit.call_args[1]["tenant_id"] == "tenant-abc" - def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.message.duration"] == pytest.approx(5.0) - def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_no_duration_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info(start_time=None, end_time=None)) attrs = mock_emit.call_args[1]["attributes"] assert "dify.message.duration" not in attrs - def test_records_duration_histogram_when_timestamps_present(self, trace_handler, mock_exporter): + def test_records_duration_histogram_when_timestamps_present( + self, trace_handler: EnterpriseOtelTrace, mock_exporter + ): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info()) @@ -786,14 +795,14 @@ class TestMessageTrace: assert len(hist_calls) == 1 assert hist_calls[0][0][1] == pytest.approx(5.0) - def test_no_duration_histogram_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_no_duration_histogram_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info(start_time=None, end_time=None)) hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list] assert EnterpriseTelemetryHistogram.MESSAGE_DURATION not in hist_names - def test_records_ttft_histogram_when_present(self, trace_handler, mock_exporter): + def test_records_ttft_histogram_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=0.42)) @@ -805,14 +814,14 @@ class TestMessageTrace: assert len(ttft_calls) == 1 assert ttft_calls[0][0][1] == pytest.approx(0.42) - def test_no_ttft_histogram_when_not_present(self, trace_handler, mock_exporter): + def test_no_ttft_histogram_when_not_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info(gen_ai_server_time_to_first_token=None)) hist_names = [c[0][0] for c in mock_exporter.record_histogram.call_args_list] assert EnterpriseTelemetryHistogram.MESSAGE_TTFT not in hist_names - def test_increments_token_counters(self, trace_handler, mock_exporter): + def test_increments_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info()) @@ -821,7 +830,7 @@ class TestMessageTrace: assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names - def test_error_path_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_path_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._message_trace(make_message_info(error="LLM failed")) @@ -830,7 +839,7 @@ class TestMessageTrace: ] assert len(error_calls) == 1 - def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._message_trace(make_message_info()) @@ -846,27 +855,27 @@ class TestMessageTrace: class TestToolTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.TOOL_EXECUTION - def test_status_is_succeeded_on_success(self, trace_handler, mock_exporter): + def test_status_is_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.tool.status"] == "succeeded" - def test_status_is_failed_on_error(self, trace_handler, mock_exporter): + def test_status_is_failed_on_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info(error="Tool error")) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.tool.status"] == "failed" - def test_records_tool_duration_histogram(self, trace_handler, mock_exporter): + def test_records_tool_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._tool_trace(make_tool_info()) @@ -874,7 +883,7 @@ class TestToolTrace: assert hist_call[0][0] == EnterpriseTelemetryHistogram.TOOL_DURATION assert hist_call[0][1] == pytest.approx(1.5) - def test_error_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._tool_trace(make_tool_info(error="Tool crashed")) @@ -883,7 +892,7 @@ class TestToolTrace: ] assert len(error_calls) == 1 - def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info()) @@ -892,7 +901,7 @@ class TestToolTrace: assert attrs["dify.tool.inputs"].startswith("ref:message_id=") assert attrs["dify.tool.outputs"].startswith("ref:message_id=") - def test_inputs_present_when_include_content_true(self, trace_handler, mock_exporter): + def test_inputs_present_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._tool_trace(make_tool_info()) @@ -901,7 +910,7 @@ class TestToolTrace: assert attrs["dify.tool.inputs"] == json.dumps({"query": "test"}) assert attrs["dify.tool.outputs"] == "search results" - def test_increments_requests_counter(self, trace_handler, mock_exporter): + def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._tool_trace(make_tool_info()) @@ -918,27 +927,27 @@ class TestToolTrace: class TestModerationTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.MODERATION_CHECK - def test_flagged_true_sets_attribute(self, trace_handler, mock_exporter): + def test_flagged_true_sets_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info(flagged=True)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.moderation.flagged"] is True - def test_flagged_false_sets_attribute(self, trace_handler, mock_exporter): + def test_flagged_false_sets_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info(flagged=False)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.moderation.flagged"] is False - def test_query_gated_by_include_content(self, trace_handler, mock_exporter): + def test_query_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info()) @@ -946,7 +955,7 @@ class TestModerationTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.moderation.query"].startswith("ref:message_id=") - def test_query_present_when_include_content_true(self, trace_handler, mock_exporter): + def test_query_present_when_include_content_true(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = True with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._moderation_trace(make_moderation_info()) @@ -954,7 +963,7 @@ class TestModerationTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.moderation.query"] == "is this ok?" - def test_increments_requests_counter(self, trace_handler, mock_exporter): + def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._moderation_trace(make_moderation_info()) @@ -971,48 +980,48 @@ class TestModerationTrace: class TestSuggestedQuestionTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION - def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.duration"] == pytest.approx(5.0) - def test_duration_is_none_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_duration_is_none_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info(start_time=None, end_time=None)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.duration"] is None - def test_status_is_failed_when_error_present(self, trace_handler, mock_exporter): + def test_status_is_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info(error="Generation failed")) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.status"] == "failed" - def test_status_falls_back_to_succeeded_when_no_error(self, trace_handler, mock_exporter): + def test_status_falls_back_to_succeeded_when_no_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info(status=None, error=None)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.status"] == "succeeded" - def test_question_count_attribute(self, trace_handler, mock_exporter): + def test_question_count_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.count"] == 2 - def test_questions_gated_by_include_content(self, trace_handler, mock_exporter): + def test_questions_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._suggested_question_trace(make_suggested_question_info()) @@ -1020,7 +1029,7 @@ class TestSuggestedQuestionTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.suggested_question.questions"].startswith("ref:message_id=") - def test_increments_requests_counter(self, trace_handler, mock_exporter): + def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._suggested_question_trace(make_suggested_question_info()) @@ -1037,48 +1046,48 @@ class TestSuggestedQuestionTrace: class TestDatasetRetrievalTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.DATASET_RETRIEVAL - def test_document_count_attribute(self, trace_handler, mock_exporter): + def test_document_count_attribute(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.retrieval.document_count"] == 1 - def test_dataset_ids_extracted(self, trace_handler, mock_exporter): + def test_dataset_ids_extracted(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) attrs = mock_emit.call_args[1]["attributes"] assert "ds-001" in attrs["dify.dataset.id"] - def test_empty_documents_has_zero_count(self, trace_handler, mock_exporter): + def test_empty_documents_has_zero_count(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[])) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.retrieval.document_count"] == 0 - def test_status_succeeded_when_no_error(self, trace_handler, mock_exporter): + def test_status_succeeded_when_no_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.retrieval.status"] == "succeeded" - def test_status_failed_when_error_present(self, trace_handler, mock_exporter): + def test_status_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(error="DB error")) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.retrieval.status"] == "failed" - def test_embedding_model_attributes_set_when_present(self, trace_handler, mock_exporter): + def test_embedding_model_attributes_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) @@ -1086,7 +1095,7 @@ class TestDatasetRetrievalTrace: assert "dify.dataset.embedding_providers" in attrs assert "dify.dataset.embedding_models" in attrs - def test_no_embedding_model_attributes_when_not_provided(self, trace_handler, mock_exporter): + def test_no_embedding_model_attributes_when_not_provided(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace( make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"}) @@ -1096,7 +1105,7 @@ class TestDatasetRetrievalTrace: assert "dify.dataset.embedding_providers" not in attrs assert "dify.dataset.embedding_models" not in attrs - def test_rerank_attributes_set_when_present(self, trace_handler, mock_exporter): + def test_rerank_attributes_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace( make_dataset_retrieval_info( @@ -1113,7 +1122,7 @@ class TestDatasetRetrievalTrace: assert attrs["dify.retrieval.rerank_provider"] == "cohere" assert attrs["dify.retrieval.rerank_model"] == "rerank-english" - def test_no_rerank_attributes_when_not_present(self, trace_handler, mock_exporter): + def test_no_rerank_attributes_when_not_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace( make_dataset_retrieval_info(metadata={"app_id": "app-001", "tenant_id": "tenant-abc"}) @@ -1123,7 +1132,7 @@ class TestDatasetRetrievalTrace: assert "dify.retrieval.rerank_provider" not in attrs assert "dify.retrieval.rerank_model" not in attrs - def test_dataset_retrieval_counter_incremented_per_dataset(self, trace_handler, mock_exporter): + def test_dataset_retrieval_counter_incremented_per_dataset(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) @@ -1135,7 +1144,7 @@ class TestDatasetRetrievalTrace: assert len(ds_calls) == 1 assert ds_calls[0][0][2]["dataset_id"] == "ds-001" - def test_no_dataset_retrieval_counter_when_no_documents(self, trace_handler, mock_exporter): + def test_no_dataset_retrieval_counter_when_no_documents(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info(documents=[])) @@ -1146,7 +1155,7 @@ class TestDatasetRetrievalTrace: ] assert len(ds_calls) == 0 - def test_query_gated_by_include_content(self, trace_handler, mock_exporter): + def test_query_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._dataset_retrieval_trace(make_dataset_retrieval_info()) @@ -1161,34 +1170,34 @@ class TestDatasetRetrievalTrace: class TestGenerateNameTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION - def test_duration_computed_from_timestamps(self, trace_handler, mock_exporter): + def test_duration_computed_from_timestamps(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.generate_name.duration"] == pytest.approx(5.0) - def test_no_duration_when_timestamps_missing(self, trace_handler, mock_exporter): + def test_no_duration_when_timestamps_missing(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info(start_time=None, end_time=None)) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.generate_name.duration"] is None - def test_status_succeeded_on_success(self, trace_handler, mock_exporter): + def test_status_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.generate_name.status"] == "succeeded" - def test_status_failed_when_metadata_has_error(self, trace_handler, mock_exporter): + def test_status_failed_when_metadata_has_error(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace( make_generate_name_info( @@ -1203,7 +1212,7 @@ class TestGenerateNameTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.generate_name.status"] == "failed" - def test_inputs_and_outputs_gated_by_include_content(self, trace_handler, mock_exporter): + def test_inputs_and_outputs_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._generate_name_trace(make_generate_name_info()) @@ -1212,7 +1221,7 @@ class TestGenerateNameTrace: assert attrs["dify.generate_name.inputs"].startswith("ref:conversation_id=") assert attrs["dify.generate_name.outputs"].startswith("ref:conversation_id=") - def test_increments_requests_counter(self, trace_handler, mock_exporter): + def test_increments_requests_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._generate_name_trace(make_generate_name_info()) @@ -1229,27 +1238,27 @@ class TestGenerateNameTrace: class TestPromptGenerationTrace: - def test_emits_event_with_correct_name(self, trace_handler, mock_exporter): + def test_emits_event_with_correct_name(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info()) assert mock_emit.call_args[1]["event_name"] == EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION - def test_status_succeeded_on_success(self, trace_handler, mock_exporter): + def test_status_succeeded_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info()) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.prompt_generation.status"] == "succeeded" - def test_status_failed_when_error_present(self, trace_handler, mock_exporter): + def test_status_failed_when_error_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Generation error")) attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.prompt_generation.status"] == "failed" - def test_token_counters_incremented(self, trace_handler, mock_exporter): + def test_token_counters_incremented(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info()) @@ -1258,7 +1267,7 @@ class TestPromptGenerationTrace: assert EnterpriseTelemetryCounter.INPUT_TOKENS in counter_names assert EnterpriseTelemetryCounter.OUTPUT_TOKENS in counter_names - def test_records_duration_histogram(self, trace_handler, mock_exporter): + def test_records_duration_histogram(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info()) @@ -1270,7 +1279,7 @@ class TestPromptGenerationTrace: assert len(hist_calls) == 1 assert hist_calls[0][0][1] == pytest.approx(3.2) - def test_total_price_attribute_set_when_present(self, trace_handler, mock_exporter): + def test_total_price_attribute_set_when_present(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=0.05, currency="USD")) @@ -1278,14 +1287,14 @@ class TestPromptGenerationTrace: assert attrs["dify.prompt_generation.total_price"] == pytest.approx(0.05) assert attrs["dify.prompt_generation.currency"] == "USD" - def test_no_total_price_attribute_when_none(self, trace_handler, mock_exporter): + def test_no_total_price_attribute_when_none(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info(total_price=None)) attrs = mock_emit.call_args[1]["attributes"] assert "dify.prompt_generation.total_price" not in attrs - def test_error_increments_error_counter(self, trace_handler, mock_exporter): + def test_error_increments_error_counter(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info(error="Prompt failed")) @@ -1294,7 +1303,7 @@ class TestPromptGenerationTrace: ] assert len(error_calls) == 1 - def test_no_error_counter_on_success(self, trace_handler, mock_exporter): + def test_no_error_counter_on_success(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info()) @@ -1303,7 +1312,7 @@ class TestPromptGenerationTrace: ] assert len(error_calls) == 0 - def test_instruction_gated_by_include_content(self, trace_handler, mock_exporter): + def test_instruction_gated_by_include_content(self, trace_handler: EnterpriseOtelTrace, mock_exporter): mock_exporter.include_content = False with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info()) @@ -1311,7 +1320,7 @@ class TestPromptGenerationTrace: attrs = mock_emit.call_args[1]["attributes"] assert attrs["dify.prompt_generation.instruction"].startswith("ref:trace_id=") - def test_operation_type_label_used_in_token_counters(self, trace_handler, mock_exporter): + def test_operation_type_label_used_in_token_counters(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event"): trace_handler._prompt_generation_trace(make_prompt_generation_info(operation_type="code_generate")) @@ -1321,7 +1330,7 @@ class TestPromptGenerationTrace: assert len(token_calls) == 1 assert token_calls[0][0][2]["operation_type"] == "code_generate" - def test_emits_correct_tenant_id(self, trace_handler, mock_exporter): + def test_emits_correct_tenant_id(self, trace_handler: EnterpriseOtelTrace, mock_exporter): with patch("enterprise.telemetry.enterprise_trace.emit_metric_only_event") as mock_emit: trace_handler._prompt_generation_trace(make_prompt_generation_info()) diff --git a/api/tests/unit_tests/enterprise/telemetry/test_exporter.py b/api/tests/unit_tests/enterprise/telemetry/test_exporter.py index 674a2026131..e04046a1f0d 100644 --- a/api/tests/unit_tests/enterprise/telemetry/test_exporter.py +++ b/api/tests/unit_tests/enterprise/telemetry/test_exporter.py @@ -7,6 +7,8 @@ from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch +import pytest + from configs.enterprise import EnterpriseTelemetryConfig from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram from enterprise.telemetry.exporter import EnterpriseExporter, _datetime_to_ns, _parse_otlp_headers @@ -533,7 +535,7 @@ def test_export_span_cross_workflow_parent_context() -> None: assert kwargs["context"] is not None -def test_export_span_logs_exception_on_error(caplog) -> None: +def test_export_span_logs_exception_on_error(caplog: pytest.LogCaptureFixture) -> None: """If the span block raises, the exception is logged and context is still cleared.""" exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() @@ -546,7 +548,7 @@ def test_export_span_logs_exception_on_error(caplog) -> None: assert "bad.span" in caplog.text -def test_export_span_invalid_trace_correlation_logs_warning(caplog) -> None: +def test_export_span_invalid_trace_correlation_logs_warning(caplog: pytest.LogCaptureFixture) -> None: """Invalid UUID for trace_correlation_override triggers a warning log.""" exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer() diff --git a/api/tests/unit_tests/events/test_update_provider_when_message_created.py b/api/tests/unit_tests/events/test_update_provider_when_message_created.py index 4b2a6438f44..f9ac5d9678e 100644 --- a/api/tests/unit_tests/events/test_update_provider_when_message_created.py +++ b/api/tests/unit_tests/events/test_update_provider_when_message_created.py @@ -4,6 +4,7 @@ from types import SimpleNamespace from unittest.mock import patch from uuid import uuid4 +import pytest from sqlalchemy import create_engine, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -122,7 +123,9 @@ def test_message_created_paid_credit_accounting_uses_paid_pool() -> None: ) -def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None: +def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted( + caplog: pytest.LogCaptureFixture, +) -> None: with patch( "services.credit_pool_service.CreditPoolService.deduct_credits_capped", return_value=3, diff --git a/api/tests/unit_tests/extensions/test_ext_request_logging.py b/api/tests/unit_tests/extensions/test_ext_request_logging.py index 3d2f8541f63..70e80707882 100644 --- a/api/tests/unit_tests/extensions/test_ext_request_logging.py +++ b/api/tests/unit_tests/extensions/test_ext_request_logging.py @@ -56,18 +56,19 @@ def mock_response_receiver(monkeypatch: pytest.MonkeyPatch) -> mock.Mock: return mock_log_request_finished -@pytest.fixture -def mock_logger(monkeypatch: pytest.MonkeyPatch) -> logging.Logger: - _logger = mock.MagicMock(spec=logging.Logger) - monkeypatch.setattr(ext_request_logging, "logger", _logger) - return _logger - - @pytest.fixture def enable_request_logging(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", True) +def _captured_records(caplog: pytest.LogCaptureFixture, level: int) -> list[logging.LogRecord]: + return [ + record + for record in caplog.records + if record.name == ext_request_logging.logger.name and record.levelno == level + ] + + class TestRequestLoggingExtension: def test_receiver_should_not_be_invoked_if_configuration_is_disabled( self, @@ -108,67 +109,74 @@ class TestRequestLoggingExtension: class TestLoggingLevel: @pytest.mark.usefixtures("enable_request_logging") - def test_logging_should_be_skipped_if_level_is_above_debug(self, enable_request_logging, mock_logger): - mock_logger.isEnabledFor.return_value = False + def test_logging_should_be_skipped_if_level_is_above_debug( + self, enable_request_logging, caplog: pytest.LogCaptureFixture + ): + caplog.set_level(logging.INFO, logger=ext_request_logging.logger.name) app = _get_test_app() init_app(app) with app.test_client() as client: client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE}) - mock_logger.debug.assert_not_called() + assert not _captured_records(caplog, logging.DEBUG) class TestRequestReceiverLogging: @pytest.mark.usefixtures("enable_request_logging") - def test_non_json_request(self, enable_request_logging, mock_logger, mock_response_receiver): - mock_logger.isEnabledFor.return_value = True + def test_non_json_request(self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver): + caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name) app = _get_test_app() init_app(app) with app.test_client() as client: client.post("/", data="plain text") - assert mock_logger.debug.call_count == 1 - call_args = mock_logger.debug.call_args[0] - assert "Received Request" in call_args[0] - assert call_args[1] == "POST" - assert call_args[2] == "/" - assert "Request Body" not in call_args[0] + debug_records = _captured_records(caplog, logging.DEBUG) + assert len(debug_records) == 1 + record = debug_records[0] + assert "Received Request" in record.msg + assert record.args == ("POST", "/") + assert "Request Body" not in record.msg @pytest.mark.usefixtures("enable_request_logging") - def test_json_request(self, enable_request_logging, mock_logger, mock_response_receiver): - mock_logger.isEnabledFor.return_value = True + def test_json_request(self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver): + caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name) app = _get_test_app() init_app(app) with app.test_client() as client: client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE}) - assert mock_logger.debug.call_count == 1 - call_args = mock_logger.debug.call_args[0] - assert "Received Request" in call_args[0] - assert "Request Body" in call_args[0] - assert call_args[1] == "POST" - assert call_args[2] == "/" - assert _KEY_NEEDLE in call_args[3] + debug_records = _captured_records(caplog, logging.DEBUG) + assert len(debug_records) == 1 + record = debug_records[0] + assert "Received Request" in record.msg + assert "Request Body" in record.msg + assert record.args[0] == "POST" + assert record.args[1] == "/" + assert _KEY_NEEDLE in record.args[2] @pytest.mark.usefixtures("enable_request_logging") - def test_json_request_with_empty_body(self, enable_request_logging, mock_logger, mock_response_receiver): - mock_logger.isEnabledFor.return_value = True + def test_json_request_with_empty_body( + self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver + ): + caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name) app = _get_test_app() init_app(app) with app.test_client() as client: client.post("/", headers={"Content-Type": "application/json"}) - assert mock_logger.debug.call_count == 1 - call_args = mock_logger.debug.call_args[0] - assert "Received Request" in call_args[0] - assert "Request Body" not in call_args[0] - assert call_args[1] == "POST" - assert call_args[2] == "/" + debug_records = _captured_records(caplog, logging.DEBUG) + assert len(debug_records) == 1 + record = debug_records[0] + assert "Received Request" in record.msg + assert "Request Body" not in record.msg + assert record.args == ("POST", "/") @pytest.mark.usefixtures("enable_request_logging") - def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver): - mock_logger.isEnabledFor.return_value = True + def test_json_request_with_invalid_json_as_body( + self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver + ): + caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name) app = _get_test_app() init_app(app) @@ -178,50 +186,53 @@ class TestRequestReceiverLogging: headers={"Content-Type": "application/json"}, data="{", ) - assert mock_logger.debug.call_count == 0 - assert mock_logger.exception.call_count == 1 - - exception_call_args = mock_logger.exception.call_args[0] - assert exception_call_args[0] == "Failed to parse JSON request" + assert not _captured_records(caplog, logging.DEBUG) + error_records = _captured_records(caplog, logging.ERROR) + assert len(error_records) == 1 + assert error_records[0].message == "Failed to parse JSON request" class TestResponseReceiverLogging: @pytest.mark.usefixtures("enable_request_logging") - def test_non_json_response(self, enable_request_logging, mock_logger): - mock_logger.isEnabledFor.return_value = True + def test_non_json_response(self, enable_request_logging, caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name) app = _get_test_app() response = Response( "OK", headers={"Content-Type": "text/plain"}, ) _log_request_finished(app, response) - assert mock_logger.debug.call_count == 1 - call_args = mock_logger.debug.call_args[0] - assert "Response" in call_args[0] - assert "200" in call_args[1] - assert call_args[2] == "text/plain" - assert "Response Body" not in call_args[0] + debug_records = _captured_records(caplog, logging.DEBUG) + assert len(debug_records) == 1 + record = debug_records[0] + assert "Response" in record.msg + assert "200" in record.args[0] + assert record.args[1] == "text/plain" + assert "Response Body" not in record.msg @pytest.mark.usefixtures("enable_request_logging") - def test_json_response(self, enable_request_logging, mock_logger, mock_response_receiver): - mock_logger.isEnabledFor.return_value = True + def test_json_response(self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver): + caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name) app = _get_test_app() response = Response( json.dumps({_KEY_NEEDLE: _VALUE_NEEDLE}), headers={"Content-Type": "application/json"}, ) _log_request_finished(app, response) - assert mock_logger.debug.call_count == 1 - call_args = mock_logger.debug.call_args[0] - assert "Response" in call_args[0] - assert "Response Body" in call_args[0] - assert "200" in call_args[1] - assert call_args[2] == "application/json" - assert _KEY_NEEDLE in call_args[3] + debug_records = _captured_records(caplog, logging.DEBUG) + assert len(debug_records) == 1 + record = debug_records[0] + assert "Response" in record.msg + assert "Response Body" in record.msg + assert "200" in record.args[0] + assert record.args[1] == "application/json" + assert _KEY_NEEDLE in record.args[2] @pytest.mark.usefixtures("enable_request_logging") - def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver): - mock_logger.isEnabledFor.return_value = True + def test_json_request_with_invalid_json_as_body( + self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver + ): + caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name) app = _get_test_app() response = Response( @@ -229,11 +240,10 @@ class TestResponseReceiverLogging: headers={"Content-Type": "application/json"}, ) _log_request_finished(app, response) - assert mock_logger.debug.call_count == 0 - assert mock_logger.exception.call_count == 1 - - exception_call_args = mock_logger.exception.call_args[0] - assert exception_call_args[0] == "Failed to parse JSON response" + assert not _captured_records(caplog, logging.DEBUG) + error_records = _captured_records(caplog, logging.ERROR) + assert len(error_records) == 1 + assert error_records[0].message == "Failed to parse JSON response" class TestResponseUnmodified: @@ -267,7 +277,7 @@ class TestResponseUnmodified: class TestRequestFinishedInfoAccessLine: def test_info_access_log_includes_method_path_status_duration_trace_id( - self, monkeypatch: pytest.MonkeyPatch, caplog + self, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture ): """Ensure INFO access line contains expected fields with computed duration and trace id.""" app = _get_test_app() diff --git a/api/tests/unit_tests/fields/test_message_fields.py b/api/tests/unit_tests/fields/test_message_fields.py new file mode 100644 index 00000000000..8a4eadaf744 --- /dev/null +++ b/api/tests/unit_tests/fields/test_message_fields.py @@ -0,0 +1,36 @@ +from fields.message_fields import ExploreMessageListItem, MessageListItem + + +def _base_kwargs(): + return { + "id": "m1", + "conversation_id": "c1", + "inputs": {}, + "query": "hi", + "answer": "answer", + "retriever_resources": [], + "agent_thoughts": [], + "message_files": [], + "status": "normal", + "extra_contents": [], + } + + +class TestExploreMessageListItem: + def test_exposes_metadata_for_history_rehydration(self): + # The Explore/installed-app surface must surface message_metadata (incl. reasoning) + # so the chat-with-history client can rehydrate the thinking panel on reload. + item = ExploreMessageListItem(**_base_kwargs(), metadata={"reasoning": {"llm": "thinking..."}}) + + payload = item.model_dump(mode="json") + + assert payload["metadata"] == {"reasoning": {"llm": "thinking..."}} + + def test_metadata_defaults_to_none(self): + item = ExploreMessageListItem(**_base_kwargs()) + assert item.model_dump(mode="json")["metadata"] is None + + def test_base_message_list_item_has_no_metadata(self): + # Guard the public service-API contract: the base item must not leak metadata. + payload = MessageListItem(**_base_kwargs()).model_dump(mode="json") + assert "metadata" not in payload diff --git a/api/tests/unit_tests/services/agent/test_agent_composer_entities.py b/api/tests/unit_tests/services/agent/test_agent_composer_entities.py index e82ba92029b..23988c2ec20 100644 --- a/api/tests/unit_tests/services/agent/test_agent_composer_entities.py +++ b/api/tests/unit_tests/services/agent/test_agent_composer_entities.py @@ -105,6 +105,28 @@ def test_agent_app_soul_allows_app_features_and_variables(): assert payload.agent_soul.app_variables[0].name == "company_name" +def test_composer_save_payload_accepts_new_roster_metadata(): + payload = ComposerSavePayload.model_validate( + { + "variant": ComposerVariant.WORKFLOW, + "save_strategy": ComposerSaveStrategy.SAVE_TO_ROSTER, + "new_agent_name": "Research Agent", + "description": "Finds relevant sources.", + "role": "Research Assistant", + "icon_type": "emoji", + "icon": "search", + "icon_background": "#E0F2FE", + } + ) + + assert payload.new_agent_name == "Research Agent" + assert payload.description == "Finds relevant sources." + assert payload.role == "Research Assistant" + assert payload.icon_type == "emoji" + assert payload.icon == "search" + assert payload.icon_background == "#E0F2FE" + + def test_knowledge_query_mode_uses_stable_backend_enums(): config = AgentSoulConfig.model_validate( { diff --git a/api/tests/unit_tests/services/agent/test_agent_services.py b/api/tests/unit_tests/services/agent/test_agent_services.py index 9ba62d60375..1bac183c39f 100644 --- a/api/tests/unit_tests/services/agent/test_agent_services.py +++ b/api/tests/unit_tests/services/agent/test_agent_services.py @@ -3,6 +3,7 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest +from sqlalchemy.exc import IntegrityError from core.workflow.nodes.agent_v2.validators import WorkflowAgentNodeValidationError from models.agent import ( @@ -10,6 +11,7 @@ from models.agent import ( AgentConfigRevisionOperation, AgentConfigSnapshot, AgentDebugConversation, + AgentDriveFile, AgentKind, AgentScope, AgentSource, @@ -31,7 +33,12 @@ from services.agent import composer_service, roster_service from services.agent.agent_soul_state import agent_soul_has_model from services.agent.composer_service import AgentComposerService from services.agent.composer_validator import ComposerConfigValidator -from services.agent.errors import InvalidComposerConfigError +from services.agent.errors import ( + AgentNameConflictError, + AgentNotFoundError, + AgentVersionConflictError, + InvalidComposerConfigError, +) from services.agent.roster_service import AgentRosterService from services.agent.workflow_publish_service import WorkflowAgentPublishService from services.app_service import AppListParams, AppService @@ -415,9 +422,34 @@ def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.Monk fake_session = FakeSession() monkeypatch.setattr(composer_service.db, "session", fake_session) workflow_agent = SimpleNamespace(id="inline-agent-1", active_config_snapshot_id="inline-version-1") - roster_agent = SimpleNamespace(id="roster-agent-1", active_config_snapshot_id="roster-version-1", name="Roster") + roster_agent = SimpleNamespace( + id="roster-agent-1", + active_config_snapshot_id="roster-version-1", + name="Roster", + description="Source description", + role="Source role", + icon_type="emoji", + icon="source", + icon_background="#FFFFFF", + ) + create_roster_calls = [] + copy_drive_calls = [] monkeypatch.setattr(AgentComposerService, "_create_workflow_only_agent", lambda **kwargs: workflow_agent) - monkeypatch.setattr(AgentComposerService, "_create_roster_agent_for_composer", lambda **kwargs: roster_agent) + + def fake_create_roster_agent_for_composer(**kwargs): + create_roster_calls.append(kwargs) + return roster_agent + + monkeypatch.setattr( + AgentComposerService, + "_create_roster_agent_for_composer", + fake_create_roster_agent_for_composer, + ) + monkeypatch.setattr( + AgentComposerService, + "_copy_agent_drive_rows", + lambda **kwargs: copy_drive_calls.append(kwargs), + ) monkeypatch.setattr(AgentComposerService, "_require_agent", lambda **kwargs: roster_agent) monkeypatch.setattr( AgentComposerService, @@ -443,6 +475,11 @@ def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.Monk "agent_soul": {"prompt": {"system_prompt": "new"}}, "node_job": {"workflow_prompt": "use prior output"}, "new_agent_name": "Copied Agent", + "description": "Copied description", + "role": "Copied role", + "icon_type": "emoji", + "icon": "copied", + "icon_background": "#E0F2FE", } ) existing_binding = WorkflowAgentNodeBinding(agent_id="inline-agent-1", current_snapshot_id="inline-version-1") @@ -500,6 +537,24 @@ def test_composer_save_helpers_create_and_rebind_agents(monkeypatch: pytest.Monk assert new_agent_binding.binding_type == WorkflowAgentBindingType.ROSTER_AGENT assert save_to_roster_binding.agent_id == "roster-agent-1" assert new_version_binding.current_snapshot_id == "new-version-1" + assert create_roster_calls[0]["description"] == "Copied description" + assert create_roster_calls[0]["role"] == "Copied role" + assert create_roster_calls[0]["icon"] == "copied" + assert create_roster_calls[0]["icon_background"] == "#E0F2FE" + assert create_roster_calls[1]["description"] == "Copied description" + assert create_roster_calls[1]["role"] == "Copied role" + assert create_roster_calls[1]["icon"] == "copied" + assert create_roster_calls[1]["icon_background"] == "#E0F2FE" + assert copy_drive_calls == [ + { + "tenant_id": "tenant-1", + "source_agent_id": "roster-agent-1", + "target_agent_id": "roster-agent-1", + "account_id": "account-1", + "agent_soul": payload.agent_soul, + "node_job": payload.node_job, + } + ] def test_node_job_only_updates_inline_agent_soul(monkeypatch: pytest.MonkeyPatch): @@ -715,9 +770,464 @@ def test_node_job_only_rejects_inline_binding_pointing_to_roster_agent(monkeypat ) +def test_copy_workflow_composer_from_roster_creates_inline_agent_and_preserves_node_job( + monkeypatch: pytest.MonkeyPatch, +): + fake_session = FakeSession() + monkeypatch.setattr(composer_service.db, "session", fake_session) + workflow = SimpleNamespace(id="workflow-1") + node_job = WorkflowNodeJobConfig(workflow_prompt="keep this node task") + binding = WorkflowAgentNodeBinding( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version="draft", + node_id="node-1", + binding_type=WorkflowAgentBindingType.ROSTER_AGENT, + agent_id="roster-agent-1", + current_snapshot_id="old-roster-version", + node_job_config=node_job, + ) + roster_agent = Agent( + id="roster-agent-1", + tenant_id="tenant-1", + name="Nadia", + description="Clarification Drafter", + role="Clarifies tenders", + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="roster-version-2", + ) + source_version = AgentConfigSnapshot( + id="roster-version-2", + tenant_id="tenant-1", + agent_id="roster-agent-1", + version=2, + config_snapshot='{"prompt":{"system_prompt":"copy me"}}', + ) + inline_agent = Agent( + id="inline-agent-1", + tenant_id="tenant-1", + name="Nadia", + description="Clarification Drafter", + role="Clarifies tenders", + scope=AgentScope.WORKFLOW_ONLY, + source=AgentSource.WORKFLOW, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="inline-version-1", + ) + captured: dict[str, object] = {} + + monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: workflow) + monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: binding) + monkeypatch.setattr(AgentComposerService, "_require_agent", lambda **kwargs: roster_agent) + monkeypatch.setattr(AgentComposerService, "_require_version", lambda **kwargs: source_version) + + def fake_create_workflow_only_agent(**kwargs): + captured["create"] = kwargs + return inline_agent + + def fake_copy_drive_rows(**kwargs): + captured["drive"] = kwargs + + monkeypatch.setattr(AgentComposerService, "_create_workflow_only_agent", fake_create_workflow_only_agent) + monkeypatch.setattr(AgentComposerService, "_copy_agent_drive_rows", fake_copy_drive_rows) + monkeypatch.setattr( + AgentComposerService, + "_serialize_workflow_state", + lambda **kwargs: { + "binding": { + "binding_type": kwargs["binding"].binding_type.value, + "agent_id": kwargs["binding"].agent_id, + "current_snapshot_id": kwargs["binding"].current_snapshot_id, + }, + "node_job": kwargs["binding"].node_job_config_dict, + }, + ) + + state = AgentComposerService.copy_workflow_composer_from_roster( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + account_id="account-1", + source_agent_id="roster-agent-1", + source_snapshot_id="roster-version-2", + ) + + assert state["binding"]["binding_type"] == WorkflowAgentBindingType.INLINE_AGENT.value + assert state["binding"]["agent_id"] == "inline-agent-1" + assert state["node_job"]["workflow_prompt"] == "keep this node task" + assert binding.node_job_config is node_job + create_kwargs = captured["create"] + assert create_kwargs["agent_soul"].prompt.system_prompt == "copy me" + assert create_kwargs["name"] == "Nadia" + assert create_kwargs["role"] == "Clarifies tenders" + drive_kwargs = captured["drive"] + assert drive_kwargs["source_agent_id"] == "roster-agent-1" + assert drive_kwargs["target_agent_id"] == "inline-agent-1" + assert fake_session.commits == 1 + + +def test_copy_workflow_composer_from_roster_rejects_stale_source_snapshot(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1")) + monkeypatch.setattr( + AgentComposerService, + "_get_workflow_binding", + lambda **kwargs: WorkflowAgentNodeBinding( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version="draft", + node_id="node-1", + binding_type=WorkflowAgentBindingType.ROSTER_AGENT, + agent_id="roster-agent-1", + current_snapshot_id="roster-version-1", + node_job_config=WorkflowNodeJobConfig(), + ), + ) + roster_agent = Agent( + id="roster-agent-1", + tenant_id="tenant-1", + name="Nadia", + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="roster-version-2", + ) + source_version = AgentConfigSnapshot( + id="roster-version-2", + tenant_id="tenant-1", + agent_id="roster-agent-1", + version=2, + config_snapshot='{"prompt":{"system_prompt":"copy me"}}', + ) + monkeypatch.setattr(AgentComposerService, "_require_agent", lambda **kwargs: roster_agent) + monkeypatch.setattr(AgentComposerService, "_require_version", lambda **kwargs: source_version) + + with pytest.raises(AgentVersionConflictError): + AgentComposerService.copy_workflow_composer_from_roster( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + account_id="account-1", + source_agent_id="roster-agent-1", + source_snapshot_id="roster-version-1", + ) + + +def test_copy_workflow_composer_from_roster_is_idempotent_when_already_inline(monkeypatch: pytest.MonkeyPatch): + inline_binding = WorkflowAgentNodeBinding( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version="draft", + node_id="node-1", + binding_type=WorkflowAgentBindingType.INLINE_AGENT, + agent_id="inline-agent-1", + current_snapshot_id="inline-version-1", + ) + inline_agent = Agent( + id="inline-agent-1", + tenant_id="tenant-1", + name="Inline", + scope=AgentScope.WORKFLOW_ONLY, + source=AgentSource.WORKFLOW, + status=AgentStatus.ACTIVE, + active_config_snapshot_id="inline-version-1", + ) + inline_version = AgentConfigSnapshot( + id="inline-version-1", + tenant_id="tenant-1", + agent_id="inline-agent-1", + version=1, + config_snapshot='{"prompt":{"system_prompt":"inline"}}', + ) + monkeypatch.setattr(composer_service.db, "session", FakeSession()) + monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1")) + monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: inline_binding) + monkeypatch.setattr(AgentComposerService, "_get_agent_if_present", lambda **kwargs: inline_agent) + monkeypatch.setattr(AgentComposerService, "_get_version_if_present", lambda **kwargs: inline_version) + monkeypatch.setattr( + AgentComposerService, + "_serialize_workflow_state", + lambda **kwargs: {"binding_type": kwargs["binding"].binding_type.value}, + ) + + state = AgentComposerService.copy_workflow_composer_from_roster( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + account_id="account-1", + source_agent_id="roster-agent-1", + idempotency_key="same-click", + ) + + assert state == {"binding_type": WorkflowAgentBindingType.INLINE_AGENT.value} + + +@pytest.mark.parametrize( + ("binding_agent_id", "binding_type", "source_scope", "source_status", "expected_message"), + [ + ( + "roster-agent-1", + WorkflowAgentBindingType.INLINE_AGENT, + AgentScope.ROSTER, + AgentStatus.ACTIVE, + "must be bound to a roster agent", + ), + ( + "other-agent", + WorkflowAgentBindingType.ROSTER_AGENT, + AgentScope.ROSTER, + AgentStatus.ACTIVE, + "does not match", + ), + ( + "roster-agent-1", + WorkflowAgentBindingType.ROSTER_AGENT, + AgentScope.WORKFLOW_ONLY, + AgentStatus.ACTIVE, + "must be an active roster agent", + ), + ( + "roster-agent-1", + WorkflowAgentBindingType.ROSTER_AGENT, + AgentScope.ROSTER, + AgentStatus.ARCHIVED, + "must be an active roster agent", + ), + ], +) +def test_copy_workflow_composer_from_roster_rejects_invalid_source_binding( + monkeypatch: pytest.MonkeyPatch, + binding_agent_id: str, + binding_type: WorkflowAgentBindingType, + source_scope: AgentScope, + source_status: AgentStatus, + expected_message: str, +): + binding = WorkflowAgentNodeBinding( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="workflow-1", + workflow_version="draft", + node_id="node-1", + binding_type=binding_type, + agent_id=binding_agent_id, + current_snapshot_id="version-1", + node_job_config=WorkflowNodeJobConfig(), + ) + source_agent = Agent( + id="roster-agent-1", + tenant_id="tenant-1", + name="Source", + scope=source_scope, + source=AgentSource.AGENT_APP, + status=source_status, + active_config_snapshot_id="version-1", + ) + monkeypatch.setattr(AgentComposerService, "_get_draft_workflow", lambda **kwargs: SimpleNamespace(id="workflow-1")) + monkeypatch.setattr(AgentComposerService, "_get_workflow_binding", lambda **kwargs: binding) + monkeypatch.setattr(AgentComposerService, "_require_agent", lambda **kwargs: source_agent) + + with pytest.raises(InvalidComposerConfigError, match=expected_message): + AgentComposerService.copy_workflow_composer_from_roster( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + account_id="account-1", + source_agent_id="roster-agent-1", + ) + + +def test_copy_agent_drive_rows_copies_skill_prefix_and_files(monkeypatch: pytest.MonkeyPatch): + skill_row = AgentDriveFile( + tenant_id="tenant-1", + agent_id="roster-agent-1", + key="tender-analyzer/SKILL.md", + file_kind="tool_file", + file_id="tool-file-1", + value_owned_by_drive=True, + is_skill=True, + skill_metadata='{"name":"Tender Analyzer"}', + size=10, + mime_type="text/markdown", + ) + script_row = AgentDriveFile( + tenant_id="tenant-1", + agent_id="roster-agent-1", + key="tender-analyzer/scripts/run.sh", + file_kind="tool_file", + file_id="tool-file-2", + value_owned_by_drive=True, + size=20, + mime_type="text/x-shellscript", + ) + file_row = AgentDriveFile( + tenant_id="tenant-1", + agent_id="roster-agent-1", + key="files/qna.pdf", + file_kind="upload_file", + file_id="upload-file-1", + value_owned_by_drive=False, + size=30, + mime_type="application/pdf", + ) + fake_session = FakeSession(scalars=[[skill_row, script_row, file_row], []]) + monkeypatch.setattr(composer_service.db, "session", fake_session) + agent_soul = AgentSoulConfig.model_validate( + { + "prompt": { + "system_prompt": "[§skill:tender-analyzer/SKILL.md:Tender Analyzer§]", + }, + } + ) + node_job = WorkflowNodeJobConfig.model_validate( + {"metadata": {"file_refs": [{"name": "qna.pdf", "drive_key": "files/qna.pdf"}]}} + ) + + AgentComposerService._copy_agent_drive_rows( + tenant_id="tenant-1", + source_agent_id="roster-agent-1", + target_agent_id="inline-agent-1", + account_id="account-1", + agent_soul=agent_soul, + node_job=node_job, + ) + + copied = [row for row in fake_session.added if isinstance(row, AgentDriveFile)] + assert [row.key for row in copied] == [ + "tender-analyzer/SKILL.md", + "tender-analyzer/scripts/run.sh", + "files/qna.pdf", + ] + assert {row.agent_id for row in copied} == {"inline-agent-1"} + assert copied[0].file_id == "tool-file-1" + assert copied[0].is_skill is True + assert copied[2].value_owned_by_drive is False + + +def test_copy_agent_drive_rows_skips_when_no_referenced_drive_keys(monkeypatch: pytest.MonkeyPatch): + fake_session = FakeSession() + monkeypatch.setattr(composer_service.db, "session", fake_session) + agent_soul = AgentSoulConfig.model_validate({"prompt": {"system_prompt": "No drive mentions."}}) + + AgentComposerService._copy_agent_drive_rows( + tenant_id="tenant-1", + source_agent_id="roster-agent-1", + target_agent_id="inline-agent-1", + account_id="account-1", + agent_soul=agent_soul, + ) + + assert fake_session.added == [] + + +def test_copy_agent_drive_rows_skips_existing_target_keys(monkeypatch: pytest.MonkeyPatch): + source_row = AgentDriveFile( + tenant_id="tenant-1", + agent_id="roster-agent-1", + key="files/qna.pdf", + file_kind="upload_file", + file_id="upload-file-1", + value_owned_by_drive=False, + size=30, + mime_type="application/pdf", + ) + fake_session = FakeSession(scalars=[[source_row], ["files/qna.pdf"]]) + monkeypatch.setattr(composer_service.db, "session", fake_session) + agent_soul = AgentSoulConfig.model_validate({"prompt": {"system_prompt": "[§file:files/qna.pdf:qna.pdf§]"}}) + + AgentComposerService._copy_agent_drive_rows( + tenant_id="tenant-1", + source_agent_id="roster-agent-1", + target_agent_id="inline-agent-1", + account_id="account-1", + agent_soul=agent_soul, + ) + + assert [row for row in fake_session.added if isinstance(row, AgentDriveFile)] == [] + + +def test_drive_copy_scopes_include_declared_output_benchmark_files(): + agent_soul = AgentSoulConfig.model_validate( + { + "prompt": { + "system_prompt": ( + "[§file:files/source.pdf:source.pdf§] " + "[§knowledge:dataset-1:Docs§] " + "[§skill:tender-analyzer/SKILL.md:Tender Analyzer§]" + ) + }, + } + ) + node_job = WorkflowNodeJobConfig.model_validate( + { + "declared_outputs": [ + { + "name": "qna_report", + "type": "file", + "check": { + "enabled": True, + "prompt": "Compare the generated file with the benchmark.", + "benchmark_file_ref": {"name": "expected.pdf", "drive_key": "files/expected.pdf"}, + }, + }, + { + "name": "summary", + "type": "string", + "check": {"enabled": False, "benchmark_file_ref": {"drive_key": "files/ignored.pdf"}}, + }, + ], + } + ) + + exact_keys, prefixes = AgentComposerService._drive_copy_scopes_from_agent_configs( + agent_soul=agent_soul, + node_job=node_job, + ) + + assert exact_keys == {"files/source.pdf", "files/expected.pdf"} + assert prefixes == {"tender-analyzer/"} + + def test_composer_create_agents_syncs_active_config_has_model(monkeypatch: pytest.MonkeyPatch): fake_session = FakeSession() monkeypatch.setattr(composer_service.db, "session", fake_session) + created_apps = [] + backing_agent = Agent( + id="roster-agent-1", + tenant_id="tenant-1", + name="Ready Agent", + scope=AgentScope.ROSTER, + source=AgentSource.AGENT_APP, + app_id="app-agent-1", + active_config_snapshot_id="empty-version-1", + ) + + class FakeAppService: + def create_app(self, tenant_id, params, account): + created_apps.append((tenant_id, params, account)) + return SimpleNamespace(id="app-agent-1") + + class FakeAgentRosterService: + def __init__(self, session): + self.session = session + + def get_app_backing_agent(self, *, tenant_id, app_id): + assert tenant_id == "tenant-1" + assert app_id == "app-agent-1" + return backing_agent + + monkeypatch.setattr(composer_service, "AppService", FakeAppService) + monkeypatch.setattr(composer_service, "AgentRosterService", FakeAgentRosterService) + monkeypatch.setattr(AgentComposerService, "_require_account", lambda **kwargs: SimpleNamespace(id="account-1")) + monkeypatch.setattr( + AgentComposerService, + "_require_version", + lambda **kwargs: SimpleNamespace(id="empty-version-1", tenant_id="tenant-1", agent_id="roster-agent-1"), + ) monkeypatch.setattr( AgentComposerService, "_create_config_version", @@ -745,6 +1255,81 @@ def test_composer_create_agents_syncs_active_config_has_model(monkeypatch: pytes assert workflow_agent.active_config_has_model is True assert roster_agent.active_config_snapshot_id == "version-with-model" assert roster_agent.active_config_has_model is True + assert roster_agent.source == AgentSource.AGENT_APP + assert roster_agent.app_id == "app-agent-1" + created_tenant_id, created_params, created_account = created_apps[0] + assert created_tenant_id == "tenant-1" + assert created_params.mode == "agent" + assert created_params.name == "Ready Agent" + assert created_account.id == "account-1" + + +def test_composer_require_account(monkeypatch: pytest.MonkeyPatch): + account = SimpleNamespace(id="account-1") + monkeypatch.setattr(composer_service.db, "session", SimpleNamespace(get=lambda model, account_id: account)) + + assert AgentComposerService._require_account(account_id="account-1") is account + + +def test_composer_require_account_raises_when_missing(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(composer_service.db, "session", SimpleNamespace(get=lambda model, account_id: None)) + + with pytest.raises(ValueError, match="Account not found"): + AgentComposerService._require_account(account_id="missing-account") + + +def test_composer_create_roster_agent_rolls_back_name_conflict(monkeypatch: pytest.MonkeyPatch): + fake_session = FakeSession() + monkeypatch.setattr(composer_service.db, "session", fake_session) + + class FakeAppService: + def create_app(self, tenant_id, params, account): + raise IntegrityError("insert apps", params, Exception("duplicate")) + + monkeypatch.setattr(composer_service, "AppService", FakeAppService) + monkeypatch.setattr(AgentComposerService, "_require_account", lambda **kwargs: SimpleNamespace(id="account-1")) + + with pytest.raises(AgentNameConflictError): + AgentComposerService._create_roster_agent_for_composer( + tenant_id="tenant-1", + account_id="account-1", + name="Duplicate Agent", + agent_soul=_agent_soul_with_model(), + operation=AgentConfigRevisionOperation.CREATE_VERSION, + version_note=None, + ) + + assert fake_session.rollbacks == 1 + + +def test_composer_create_roster_agent_raises_when_backing_agent_missing(monkeypatch: pytest.MonkeyPatch): + fake_session = FakeSession() + monkeypatch.setattr(composer_service.db, "session", fake_session) + + class FakeAppService: + def create_app(self, tenant_id, params, account): + return SimpleNamespace(id="app-agent-1") + + class FakeAgentRosterService: + def __init__(self, session): + self.session = session + + def get_app_backing_agent(self, *, tenant_id, app_id): + return None + + monkeypatch.setattr(composer_service, "AppService", FakeAppService) + monkeypatch.setattr(composer_service, "AgentRosterService", FakeAgentRosterService) + monkeypatch.setattr(AgentComposerService, "_require_account", lambda **kwargs: SimpleNamespace(id="account-1")) + + with pytest.raises(AgentNotFoundError): + AgentComposerService._create_roster_agent_for_composer( + tenant_id="tenant-1", + account_id="account-1", + name="Missing Backing Agent", + agent_soul=_agent_soul_with_model(), + operation=AgentConfigRevisionOperation.CREATE_VERSION, + version_note=None, + ) def test_composer_version_helpers_and_lookup_errors(monkeypatch: pytest.MonkeyPatch): @@ -1318,6 +1903,7 @@ def test_agent_app_visible_versions_exclude_draft_saves(): assert agent_app_operations == { AgentConfigRevisionOperation.SAVE_NEW_VERSION, + AgentConfigRevisionOperation.SAVE_TO_ROSTER, AgentConfigRevisionOperation.RESTORE_VERSION, } assert AgentConfigRevisionOperation.SAVE_CURRENT_VERSION not in agent_app_operations diff --git a/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py index 759d9079348..b0db3bc248b 100644 --- a/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py +++ b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py @@ -5,6 +5,7 @@ This module covers the pre-uninstall plugin hook behavior: - API failure: soft-fail (logs and does not re-raise) """ +import logging from unittest.mock import patch import pytest @@ -43,18 +44,16 @@ class TestTryPreUninstallPlugin: timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, ) - def test_try_pre_uninstall_plugin_http_error_soft_fails(self): + def test_try_pre_uninstall_plugin_http_error_soft_fails(self, caplog: pytest.LogCaptureFixture): body = PreUninstallPluginRequest( tenant_id="tenant-456", plugin_unique_identifier="com.example.other_plugin", ) + caplog.set_level(logging.ERROR, logger="services.enterprise.plugin_manager_service") - with ( - patch( - "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" - ) as mock_send_request, - patch("services.enterprise.plugin_manager_service.logger") as mock_logger, - ): + with patch( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" + ) as mock_send_request: mock_send_request.side_effect = HTTPStatusError( "502 Bad Gateway", request=None, @@ -69,20 +68,22 @@ class TestTryPreUninstallPlugin: json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"}, timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, ) - mock_logger.exception.assert_called_once() + assert len(caplog.records) == 1 + assert caplog.messages[0] == ( + "failed to perform pre uninstall plugin hook. tenant_id: tenant-456, " + "plugin_unique_identifier: com.example.other_plugin" + ) - def test_try_pre_uninstall_plugin_generic_exception_soft_fails(self): + def test_try_pre_uninstall_plugin_generic_exception_soft_fails(self, caplog: pytest.LogCaptureFixture): body = PreUninstallPluginRequest( tenant_id="tenant-789", plugin_unique_identifier="com.example.failing_plugin", ) + caplog.set_level(logging.ERROR, logger="services.enterprise.plugin_manager_service") - with ( - patch( - "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" - ) as mock_send_request, - patch("services.enterprise.plugin_manager_service.logger") as mock_logger, - ): + with patch( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" + ) as mock_send_request: mock_send_request.side_effect = ConnectionError("network unreachable") PluginManagerService.try_pre_uninstall_plugin(body) @@ -93,7 +94,11 @@ class TestTryPreUninstallPlugin: json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"}, timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, ) - mock_logger.exception.assert_called_once() + assert len(caplog.records) == 1 + assert caplog.messages[0] == ( + "failed to perform pre uninstall plugin hook. tenant_id: tenant-789, " + "plugin_unique_identifier: com.example.failing_plugin" + ) class TestCheckCredentialPolicyCompliance: diff --git a/api/tests/unit_tests/services/enterprise/test_rbac_service.py b/api/tests/unit_tests/services/enterprise/test_rbac_service.py index 5dc68008840..ef786c944e1 100644 --- a/api/tests/unit_tests/services/enterprise/test_rbac_service.py +++ b/api/tests/unit_tests/services/enterprise/test_rbac_service.py @@ -86,21 +86,26 @@ class TestRoles: call = _call_args(mock_send) assert call.method == "GET" assert call.endpoint == "/rbac/roles" - assert call.params == {"page_number": 2, "results_per_page": 50, "reverse": "true"} + assert call.params == { + "dataset_operator_enabled": False, + "page_number": 2, + "results_per_page": 50, + "reverse": "true", + } assert out.pagination assert out.pagination.total_count == 1 def test_list_omits_params_when_default(self, mock_send: MagicMock): mock_send.return_value = {"data": [], "pagination": None} svc.RBACService.Roles.list("tenant-1") - assert _call_args(mock_send).params is None + assert _call_args(mock_send).params is not None def test_list_forwards_include_owner(self, mock_send: MagicMock): mock_send.return_value = {"data": [], "pagination": None} svc.RBACService.Roles.list("tenant-1", include_owner=1) - assert _call_args(mock_send).params == {"include_owner": 1} + assert _call_args(mock_send).params == {"dataset_operator_enabled": False, "include_owner": 1} def test_list_coerces_null_permission_keys(self, mock_send: MagicMock): mock_send.return_value = { diff --git a/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py index f5f27f7296f..e8fcbaf96ad 100644 --- a/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py +++ b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py @@ -42,6 +42,16 @@ class TestBuildInRecommendAppRetrieval: mock_fetch.assert_called_once_with("en-US") assert result == {"apps": []} + @patch("services.recommend_app.buildin.buildin_retrieval.DatabaseRecommendAppRetrieval") + def test_get_learn_dify_apps_delegates_to_database(self, mock_database_retrieval): + expected = {"recommended_apps": [{"id": "learn-dify-app"}]} + mock_database_retrieval.fetch_learn_dify_apps_from_db.return_value = expected + + result = BuildInRecommendAppRetrieval().get_learn_dify_apps("en-US") + + assert result == expected + mock_database_retrieval.fetch_learn_dify_apps_from_db.assert_called_once_with("en-US") + def test_get_recommend_app_detail_delegates(self): with patch.object( BuildInRecommendAppRetrieval, diff --git a/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py index c7b86e5743d..55165deec25 100644 --- a/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py +++ b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from services.recommend_app.recommend_app_type import RecommendAppType from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval @@ -58,6 +59,32 @@ class TestRemoteRecommendAppRetrieval: result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") assert result == {"recommended_apps": [{"id": "builtin"}]} + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_learn_dify_apps_from_dify_official", + return_value={"recommended_apps": [{"id": "learn-dify-app"}]}, + ) + def test_get_learn_dify_apps_success(self, mock_fetch): + result = RemoteRecommendAppRetrieval().get_learn_dify_apps("en-US") + + assert result == {"recommended_apps": [{"id": "learn-dify-app"}]} + mock_fetch.assert_called_once_with("en-US") + + @patch( + "services.recommend_app.remote.remote_retrieval.DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db", + return_value={"recommended_apps": [{"id": "db-fallback"}]}, + ) + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_learn_dify_apps_from_dify_official", + side_effect=ValueError("server error"), + ) + def test_get_learn_dify_apps_falls_back_to_database_on_error(self, mock_fetch, mock_database): + result = RemoteRecommendAppRetrieval().get_learn_dify_apps("en-US") + + assert result == {"recommended_apps": [{"id": "db-fallback"}]} + mock_database.assert_called_once_with("en-US") + class TestFetchFromDifyOfficial: @patch("services.recommend_app.remote.remote_retrieval.dify_config") @@ -118,3 +145,84 @@ class TestFetchFromDifyOfficial: result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") assert "categories" not in result + assert mock_get.call_args.kwargs["headers"] == {} + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_forwards_request_origin_header(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_config.CONSOLE_WEB_URL = "https://saas.dify.dev" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"recommended_apps": []} + mock_get.return_value = mock_response + + flask_app = Flask(__name__) + with flask_app.test_request_context(headers={"Origin": "https://cloud.example.com"}): + RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert mock_get.call_args.kwargs["headers"] == {"Origin": "https://cloud.example.com"} + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_falls_back_to_console_web_url_origin(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_config.CONSOLE_WEB_URL = "https://saas.dify.dev/console" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"recommended_apps": []} + mock_get.return_value = mock_response + + flask_app = Flask(__name__) + with flask_app.test_request_context(): + RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert mock_get.call_args.kwargs["headers"] == {"Origin": "https://saas.dify.dev/console"} + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_falls_back_to_console_web_url_without_request_context(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_config.CONSOLE_WEB_URL = "http://localhost:3000/console" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"recommended_apps": []} + mock_get.return_value = mock_response + + RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert mock_get.call_args.kwargs["headers"] == {"Origin": "http://localhost:3000/console"} + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_uses_console_web_url_without_scheme(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_config.CONSOLE_WEB_URL = "saas.dify.dev" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"recommended_apps": []} + mock_get.return_value = mock_response + + flask_app = Flask(__name__) + with flask_app.test_request_context(): + RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert mock_get.call_args.kwargs["headers"] == {"Origin": "saas.dify.dev"} + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_learn_dify_apps_returns_json_on_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"recommended_apps": [{"id": "learn-dify-app"}]} + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_learn_dify_apps_from_dify_official("en-US") + + assert result == {"recommended_apps": [{"id": "learn-dify-app"}]} + assert mock_get.call_args.args[0] == "https://example.com/apps/learn-dify?language=en-US" + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_learn_dify_apps_raises_on_non_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_get.return_value = MagicMock(status_code=500) + + with pytest.raises(ValueError, match="fetch learn dify apps failed"): + RemoteRecommendAppRetrieval.fetch_learn_dify_apps_from_dify_official("en-US") diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py index 628e4e594dd..4768215210e 100644 --- a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py +++ b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py @@ -8,6 +8,7 @@ and both positive and negative test scenarios. import io import json +import logging import zipfile from datetime import datetime from unittest.mock import Mock, create_autospec, patch @@ -312,16 +313,16 @@ class TestGetSchemaVersion: result = restore._get_schema_version(manifest) assert result == "1.0" - def test_missing_schema_version_defaults_to_1_0(self): + def test_missing_schema_version_defaults_to_1_0(self, caplog: pytest.LogCaptureFixture): """Should default to 1.0 when schema_version is missing.""" restore = WorkflowRunRestore() manifest = {"tables": {}} + caplog.set_level(logging.WARNING, logger="services.retention.workflow_run.restore_archived_workflow_run") - with patch("services.retention.workflow_run.restore_archived_workflow_run.logger") as mock_logger: - result = restore._get_schema_version(manifest) + result = restore._get_schema_version(manifest) assert result == "1.0" - mock_logger.warning.assert_called_once_with("Manifest missing schema_version; defaulting to 1.0") + assert "Manifest missing schema_version; defaulting to 1.0" in caplog.messages def test_unsupported_schema_version_raises_error(self): """Should raise ValueError for unsupported schema version.""" @@ -492,19 +493,19 @@ class TestRestoreTableRecords: """Tests for WorkflowRunRestore._restore_table_records method.""" @patch("services.retention.workflow_run.restore_archived_workflow_run.TABLE_MODELS") - def test_unknown_table_returns_zero(self, mock_table_models): + def test_unknown_table_returns_zero(self, mock_table_models, caplog: pytest.LogCaptureFixture): """Should return 0 for unknown table.""" restore = WorkflowRunRestore() mock_table_models.get.return_value = None mock_session = Mock() records = [{"id": "test"}] + caplog.set_level(logging.WARNING, logger="services.retention.workflow_run.restore_archived_workflow_run") - with patch("services.retention.workflow_run.restore_archived_workflow_run.logger") as mock_logger: - result = restore._restore_table_records(mock_session, "unknown_table", records, schema_version="1.0") + result = restore._restore_table_records(mock_session, "unknown_table", records, schema_version="1.0") assert result == 0 - mock_logger.warning.assert_called_once_with("Unknown table: %s", "unknown_table") + assert "Unknown table: unknown_table" in caplog.messages def test_empty_records_returns_zero(self): """Should return 0 for empty records list.""" diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 3b5c6cc9bd6..214344e0e57 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -179,7 +179,7 @@ class TestAccountService: mock_password_dependencies["compare_password"].return_value = True # Execute test - result = AccountService.authenticate("test@example.com", "password") + result = AccountService.authenticate("test@example.com", "password", session=mock_db_dependencies["db"].session) # Verify results assert result == mock_account @@ -191,7 +191,11 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - AccountPasswordError, AccountService.authenticate, "notfound@example.com", "password" + AccountPasswordError, + AccountService.authenticate, + "notfound@example.com", + "password", + session=mock_db_dependencies["db"].session, ) def test_authenticate_account_banned(self, mock_db_dependencies): @@ -202,7 +206,13 @@ class TestAccountService: mock_db_dependencies["db"].session.scalar.return_value = mock_account # Execute test and verify exception - self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password") + self._assert_exception_raised( + AccountLoginError, + AccountService.authenticate, + "banned@example.com", + "password", + session=mock_db_dependencies["db"].session, + ) def test_authenticate_password_error(self, mock_db_dependencies, mock_password_dependencies): """Test authentication with wrong password.""" @@ -215,7 +225,11 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - AccountPasswordError, AccountService.authenticate, "test@example.com", "wrongpassword" + AccountPasswordError, + AccountService.authenticate, + "test@example.com", + "wrongpassword", + session=mock_db_dependencies["db"].session, ) def test_authenticate_pending_account_activates(self, mock_db_dependencies, mock_password_dependencies): @@ -228,7 +242,9 @@ class TestAccountService: mock_password_dependencies["compare_password"].return_value = True # Execute test - result = AccountService.authenticate("pending@example.com", "password") + result = AccountService.authenticate( + "pending@example.com", "password", session=mock_db_dependencies["db"].session + ) # Verify results assert result == mock_account @@ -253,6 +269,7 @@ class TestAccountService: interface_language="en-US", password="password123", interface_theme="light", + session=mock_db_dependencies["db"].session, ) # Verify results @@ -290,6 +307,7 @@ class TestAccountService: interface_language="en-US", password="password123", timezone="Asia/Shanghai", + session=mock_db_dependencies["db"].session, ) assert result.timezone == "Asia/Shanghai" @@ -309,6 +327,7 @@ class TestAccountService: email="test@example.com", name="Test User", interface_language="en-US", + session=MagicMock(), ) def test_create_account_email_frozen(self, mock_db_dependencies, mock_external_service_dependencies): @@ -325,6 +344,7 @@ class TestAccountService: email="frozen@example.com", name="Test User", interface_language="en-US", + session=mock_db_dependencies["db"].session, ) dify_config.BILLING_ENABLED = False @@ -341,6 +361,7 @@ class TestAccountService: interface_language="zh-CN", password=None, interface_theme="dark", + session=mock_db_dependencies["db"].session, ) # Verify results @@ -375,7 +396,9 @@ class TestAccountService: mock_password_dependencies["hash_password"].return_value = b"new_hashed_password" # Execute test - result = AccountService.update_account_password(mock_account, "old_password", "new_password123") + result = AccountService.update_account_password( + mock_account, "old_password", "new_password123", session=mock_db_dependencies["db"].session + ) # Verify results assert result == mock_account @@ -391,7 +414,7 @@ class TestAccountService: # Verify database operations self._assert_database_operations_called(mock_db_dependencies["db"]) - def test_update_account_password_current_password_incorrect(self, mock_password_dependencies): + def test_update_account_password_current_password_incorrect(self, mock_db_dependencies, mock_password_dependencies): """Test password update with incorrect current password.""" # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock() @@ -404,6 +427,7 @@ class TestAccountService: mock_account, "wrong_password", "new_password123", + session=mock_db_dependencies["db"].session, ) # Verify password comparison was called @@ -411,7 +435,7 @@ class TestAccountService: "wrong_password", "hashed_password", "salt" ) - def test_update_account_password_invalid_new_password(self, mock_password_dependencies): + def test_update_account_password_invalid_new_password(self, mock_db_dependencies, mock_password_dependencies): """Test password update with invalid new password.""" # Setup test data mock_account = TestAccountAssociatedDataFactory.create_account_mock() @@ -420,7 +444,12 @@ class TestAccountService: # Execute test and verify exception self._assert_exception_raised( - ValueError, AccountService.update_account_password, mock_account, "old_password", "short" + ValueError, + AccountService.update_account_password, + mock_account, + "old_password", + "short", + session=mock_db_dependencies["db"].session, ) # Verify password validation was called @@ -447,19 +476,19 @@ class TestAccountService: mock_datetime.UTC = "UTC" # Execute test - result = AccountService.load_user("user-123") + result = AccountService.load_user("user-123", mock_db_dependencies["db"].session) # Verify results assert result == mock_account assert mock_account.set_tenant_id.called - mock_refresh_last_active.assert_called_once_with(mock_account) + mock_refresh_last_active.assert_called_once_with(mock_account, mock_db_dependencies["db"].session) def test_load_user_not_found(self, mock_db_dependencies): """Test user loading when user does not exist.""" mock_db_dependencies["db"].session.get.return_value = None # Execute test - result = AccountService.load_user("non-existent-user") + result = AccountService.load_user("non-existent-user", mock_db_dependencies["db"].session) # Verify results assert result is None @@ -500,14 +529,14 @@ class TestAccountService: mock_naive_utc_now.return_value = mock_now # Execute test - result = AccountService.load_user("user-123") + result = AccountService.load_user("user-123", mock_db_dependencies["db"].session) # Verify results assert result == mock_account assert mock_available_tenant.current is True assert mock_available_tenant.last_opened_at == mock_now self._assert_database_operations_called(mock_db_dependencies["db"]) - mock_refresh_last_active.assert_called_once_with(mock_account) + mock_refresh_last_active.assert_called_once_with(mock_account, mock_db_dependencies["db"].session) def test_load_user_no_tenants(self, mock_db_dependencies): """Test user loading when user has no tenants at all.""" @@ -525,7 +554,7 @@ class TestAccountService: mock_datetime.UTC = "UTC" # Execute test - result = AccountService.load_user("user-123") + result = AccountService.load_user("user-123", mock_db_dependencies["db"].session) # Verify results assert result is None @@ -542,7 +571,7 @@ class TestAccountService: ): mock_redis_client.set.return_value = True - AccountService._refresh_account_last_active(mock_account) + AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session) mock_redis_client.set.assert_called_once_with( "account_last_active_refresh:user-123", @@ -565,7 +594,7 @@ class TestAccountService: ): mock_redis_client.set.return_value = None - AccountService._refresh_account_last_active(mock_account) + AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session) mock_redis_client.set.assert_called_once_with( "account_last_active_refresh:user-123", @@ -586,7 +615,7 @@ class TestAccountService: patch("services.account_service.naive_utc_now", return_value=now), patch("services.account_service.redis_client") as mock_redis_client, ): - AccountService._refresh_account_last_active(mock_account) + AccountService._refresh_account_last_active(mock_account, mock_db_dependencies["db"].session) mock_redis_client.set.assert_not_called() mock_db_dependencies["db"].session.execute.assert_not_called() @@ -736,7 +765,9 @@ class TestTenantService: mock_credit_pool_db.session.commit = MagicMock() # Execute test - TenantService.create_owner_tenant_if_not_exist(mock_account) + TenantService.create_owner_tenant_if_not_exist( + mock_account, session=mock_db_dependencies["db"].session + ) # Verify tenant was created with correct parameters mock_db_dependencies["db"].session.add.assert_called() @@ -838,7 +869,9 @@ class TestTenantService: mock_sync.return_value = True # Act - TenantService.remove_member_from_tenant(mock_tenant, mock_pending_member, mock_operator) + TenantService.remove_member_from_tenant( + mock_tenant, mock_pending_member, mock_operator, session=mock_db.session + ) # Assert: enterprise sync still receives the correct member ID mock_sync.assert_called_once_with( @@ -878,7 +911,9 @@ class TestTenantService: mock_sync.return_value = True # Act - TenantService.remove_member_from_tenant(mock_tenant, mock_pending_member, mock_operator) + TenantService.remove_member_from_tenant( + mock_tenant, mock_pending_member, mock_operator, session=mock_db.session + ) # Assert: only the join record should be deleted, not the account mock_db.session.delete.assert_called_once_with(mock_ta) @@ -909,7 +944,9 @@ class TestTenantService: mock_sync.return_value = True # Act - TenantService.remove_member_from_tenant(mock_tenant, mock_active_member, mock_operator) + TenantService.remove_member_from_tenant( + mock_tenant, mock_active_member, mock_operator, session=mock_db.session + ) # Assert: only the join record should be deleted mock_db.session.delete.assert_called_once_with(mock_ta) @@ -934,7 +971,7 @@ class TestTenantService: mock_naive_utc_now.return_value = mock_now # Execute test - TenantService.switch_tenant(mock_account, "tenant-456") + TenantService.switch_tenant(mock_account, "tenant-456", session=mock_db.session) # Verify tenant was switched assert mock_tenant_join.current is True @@ -947,7 +984,7 @@ class TestTenantService: mock_account = TestAccountAssociatedDataFactory.create_account_mock() # Execute test and verify exception - self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None) + self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None, session=MagicMock()) # ==================== Role Management Tests ==================== @@ -971,7 +1008,7 @@ class TestTenantService: mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join] # Execute test - TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator) + TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator, session=mock_db.session) # Verify role was updated assert mock_target_join.role == "admin" @@ -1005,7 +1042,9 @@ class TestTenantService: ): mock_db_dependencies["db"].session.scalar.return_value = None - TenantService.create_owner_tenant_if_not_exist(mock_account, is_setup=True) + TenantService.create_owner_tenant_if_not_exist( + mock_account, is_setup=True, session=mock_db_dependencies["db"].session + ) mock_rbac_service.MemberRoles.replace.assert_called_once_with( tenant_id="tenant-rbac", @@ -1030,7 +1069,7 @@ class TestTenantService: with patch("services.account_service.db") as mock_db: mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join] - TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator) + TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator, session=mock_db.session) assert mock_target_join.role == "editor" self._assert_database_operations_called(mock_db) @@ -1052,7 +1091,9 @@ class TestTenantService: mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join] with pytest.raises(NoPermissionError): - TenantService.update_member_role(mock_tenant, mock_member, "editor", mock_operator) + TenantService.update_member_role( + mock_tenant, mock_member, "editor", mock_operator, session=mock_db.session + ) def test_admin_cannot_promote_member_to_owner(self): """Test admin cannot promote a non-owner member to owner.""" @@ -1071,7 +1112,9 @@ class TestTenantService: mock_db.session.scalar.side_effect = [mock_operator_join, mock_target_join, mock_operator_join] with pytest.raises(NoPermissionError): - TenantService.update_member_role(mock_tenant, mock_member, "owner", mock_operator) + TenantService.update_member_role( + mock_tenant, mock_member, "owner", mock_operator, session=mock_db.session + ) # ==================== Permission Check Tests ==================== @@ -1089,7 +1132,9 @@ class TestTenantService: mock_db_dependencies["db"].session.scalar.return_value = mock_operator_join # Execute test - should not raise exception - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "add", session=mock_db_dependencies["db"].session + ) def test_check_member_permission_operate_self(self): """Test member permission check when operator tries to operate self.""" @@ -1108,6 +1153,7 @@ class TestTenantService: mock_operator, mock_operator, # Same as operator "add", + session=MagicMock(), ) def test_admin_can_remove_non_owner_member(self, mock_db_dependencies): @@ -1124,7 +1170,9 @@ class TestTenantService: ) mock_db_dependencies["db"].session.scalar.side_effect = [mock_operator_join, mock_member_join] - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=mock_db_dependencies["db"].session + ) def test_admin_cannot_remove_owner_member(self, mock_db_dependencies): """Test admin cannot remove an owner member.""" @@ -1141,7 +1189,9 @@ class TestTenantService: mock_db_dependencies["db"].session.scalar.side_effect = [mock_operator_join, mock_member_join] with pytest.raises(NoPermissionError): - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=MagicMock() + ) def test_rbac_member_can_remove_non_owner_member(self): """Test RBAC workspace.member.manage allows removing a non-owner member.""" @@ -1158,7 +1208,9 @@ class TestTenantService: patch("services.account_service.RBACService.MyPermissions.get", return_value=mock_permissions), patch("services.account_service.AccountService.is_rbac_workspace_owner", return_value=False), ): - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=MagicMock() + ) def test_rbac_member_cannot_remove_without_permission(self): """Test RBAC permission check rejects removal without workspace.member.manage.""" @@ -1175,7 +1227,9 @@ class TestTenantService: patch("services.account_service.RBACService.MyPermissions.get", return_value=mock_permissions), ): with pytest.raises(NoPermissionError): - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=MagicMock() + ) def test_rbac_member_cannot_remove_owner_member(self): """Test RBAC permission check rejects removing an owner member.""" @@ -1193,7 +1247,9 @@ class TestTenantService: patch("services.account_service.AccountService.is_rbac_workspace_owner", return_value=True), ): with pytest.raises(NoPermissionError): - TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "remove") + TenantService.check_member_permission( + mock_tenant, mock_operator, mock_member, "remove", session=MagicMock() + ) def test_get_rbac_workspace_owner_account_id(self): mock_roles = MagicMock() @@ -1304,7 +1360,14 @@ class TestRegisterService: mock_dify_setup.return_value = mock_dify_setup_instance # Execute test - RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1", "en-US") + RegisterService.setup( + "admin@example.com", + "Admin User", + "password123", + "192.168.1.1", + "en-US", + session=mock_db_dependencies["db"].session, + ) # Verify results mock_create_account.assert_called_once_with( @@ -1313,8 +1376,11 @@ class TestRegisterService: interface_language="en-US", password="password123", is_setup=True, + session=mock_db_dependencies["db"].session, + ) + mock_create_tenant.assert_called_once_with( + account=mock_account, is_setup=True, session=mock_db_dependencies["db"].session ) - mock_create_tenant.assert_called_once_with(account=mock_account, is_setup=True) mock_dify_setup.assert_called_once() self._assert_database_operations_called(mock_db_dependencies["db"]) @@ -1337,6 +1403,7 @@ class TestRegisterService: "password123", "192.168.1.1", "en-US", + session=mock_db_dependencies["db"].session, ) # Verify rollback operations were called @@ -1369,10 +1436,13 @@ class TestRegisterService: name="Test User", interface_language="en-US", password=None, + session=mock_db_dependencies["db"].session, ) assert result == mock_account - mock_create_workspace.assert_called_once_with(account=mock_account) + mock_create_workspace.assert_called_once_with( + account=mock_account, session=mock_db_dependencies["db"].session + ) mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled( @@ -1400,9 +1470,12 @@ class TestRegisterService: name="Test User", interface_language="en-US", password=None, + session=mock_db_dependencies["db"].session, ) - mock_create_workspace.assert_called_once_with(account=mock_account) + mock_create_workspace.assert_called_once_with( + account=mock_account, session=mock_db_dependencies["db"].session + ) mock_join_default_workspace.assert_not_called() def test_create_account_and_tenant_still_calls_default_workspace_join_when_workspace_creation_fails( @@ -1433,6 +1506,7 @@ class TestRegisterService: name="Test User", interface_language="en-US", password=None, + session=mock_db_dependencies["db"].session, ) mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) @@ -1470,6 +1544,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) # Verify results @@ -1483,8 +1558,11 @@ class TestRegisterService: password="password123", is_setup=False, timezone=None, + session=mock_db_dependencies["db"].session, + ) + mock_create_tenant.assert_called_once_with( + "Test User's Workspace", session=mock_db_dependencies["db"].session ) - mock_create_tenant.assert_called_once_with("Test User's Workspace") mock_create_member.assert_called_once_with( mock_tenant, mock_account, mock_db_dependencies["db"].session, role="owner" ) @@ -1516,6 +1594,7 @@ class TestRegisterService: password="password123", language="en-US", create_workspace_required=False, + session=mock_db_dependencies["db"].session, ) assert result == mock_account @@ -1546,6 +1625,7 @@ class TestRegisterService: password="password123", language="en-US", create_workspace_required=False, + session=mock_db_dependencies["db"].session, ) mock_join_default_workspace.assert_not_called() @@ -1584,6 +1664,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) @@ -1623,6 +1704,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) @@ -1665,11 +1747,14 @@ class TestRegisterService: open_id="oauth123", provider="google", language="en-US", + session=mock_db_dependencies["db"].session, ) # Verify results assert result == mock_account - mock_link_account.assert_called_once_with("google", "oauth123", mock_account) + mock_link_account.assert_called_once_with( + "google", "oauth123", mock_account, session=mock_db_dependencies["db"].session + ) self._assert_database_operations_called(mock_db_dependencies["db"]) def test_register_with_pending_status(self, mock_db_dependencies, mock_external_service_dependencies): @@ -1707,6 +1792,7 @@ class TestRegisterService: password="password123", language="en-US", status=AccountStatus.PENDING, + session=mock_db_dependencies["db"].session, ) # Verify results @@ -1744,6 +1830,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) # Verify rollback was called @@ -1767,6 +1854,7 @@ class TestRegisterService: name="Test User", password="password123", language="en-US", + session=mock_db_dependencies["db"].session, ) # Verify rollback was called @@ -1810,6 +1898,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) # Verify results @@ -1820,8 +1909,9 @@ class TestRegisterService: language="en-US", status=AccountStatus.PENDING, is_setup=True, + session=mock_db_dependencies["db"].session, ) - mock_lookup.assert_called_once_with("newuser@example.com") + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "newuser@example.com") def test_invite_new_member_normalizes_new_account_email( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies @@ -1856,6 +1946,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) mock_register.assert_called_once_with( @@ -1864,13 +1955,22 @@ class TestRegisterService: language="en-US", status=AccountStatus.PENDING, is_setup=True, + session=mock_db_dependencies["db"].session, + ) + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, mixed_email) + mock_check_permission.assert_called_once_with( + mock_tenant, + mock_inviter, + None, + "add", + session=mock_db_dependencies["db"].session, ) - mock_lookup.assert_called_once_with(mixed_email) - mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, None, "add") mock_create_member.assert_called_once_with( mock_tenant, mock_new_account, mock_db_dependencies["db"].session, "normal" ) - mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) + mock_switch_tenant.assert_called_once_with( + mock_new_account, mock_tenant.id, session=mock_db_dependencies["db"].session + ) mock_generate_token.assert_called_once_with( mock_tenant, mock_new_account, "normal", requires_setup=True ) @@ -1912,6 +2012,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) # Verify results @@ -1923,7 +2024,7 @@ class TestRegisterService: mock_tenant, mock_existing_account, "normal", requires_setup=True ) mock_task_dependencies.delay.assert_called_once() - mock_lookup.assert_called_once_with("existing@example.com") + mock_lookup.assert_called_once_with(mock_db_dependencies["db"].session, "existing@example.com") def test_invite_existing_active_account_requires_acceptance_before_joining( self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies @@ -1954,10 +2055,17 @@ class TestRegisterService: language="en-US", role="admin", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) assert result == "invite-token-123" - mock_check_permission.assert_called_once_with(mock_tenant, mock_inviter, mock_existing_account, "add") + mock_check_permission.assert_called_once_with( + mock_tenant, + mock_inviter, + mock_existing_account, + "add", + session=mock_db_dependencies["db"].session, + ) mock_create_member.assert_not_called() mock_generate_token.assert_called_once_with( mock_tenant, mock_existing_account, "admin", requires_setup=False @@ -1993,6 +2101,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) mock_lookup.assert_called_once() @@ -2010,6 +2119,7 @@ class TestRegisterService: language="en-US", role="normal", inviter=None, + session=MagicMock(), ) # ==================== RBAC Member Invitation Tests ==================== @@ -2048,6 +2158,7 @@ class TestRegisterService: language="en-US", role="rbac-role-id-123", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) assert result == "rbac-token" @@ -2093,6 +2204,7 @@ class TestRegisterService: language="en-US", role="rbac-role-id-456", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) assert result == "rbac-token" @@ -2141,6 +2253,7 @@ class TestRegisterService: language="en-US", role="rbac-role-id-456", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) mock_create_member.assert_called_once_with( @@ -2191,6 +2304,7 @@ class TestRegisterService: language="en-US", role="editor", inviter=mock_inviter, + session=mock_db_dependencies["db"].session, ) assert result == "legacy-token" @@ -2300,7 +2414,9 @@ class TestRegisterService: mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, mock_account] # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session + ) # Verify results assert result is not None @@ -2314,7 +2430,9 @@ class TestRegisterService: mock_redis_dependencies.get.return_value = None # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=MagicMock() + ) # Verify results assert result is None @@ -2333,7 +2451,9 @@ class TestRegisterService: mock_db_dependencies["db"].session.scalar.return_value = None # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session + ) # Verify results assert result is None @@ -2357,7 +2477,9 @@ class TestRegisterService: mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, None] # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session + ) # Verify results assert result is None @@ -2384,7 +2506,9 @@ class TestRegisterService: mock_db_dependencies["db"].session.scalar.side_effect = [mock_tenant, mock_account] # Execute test - result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + result = RegisterService.get_invitation_if_token_valid( + "tenant-456", "test@example.com", "token-123", session=mock_db_dependencies["db"].session + ) # Verify results assert result is None @@ -2395,22 +2519,28 @@ class TestRegisterService: with patch( "services.account_service.RegisterService.get_invitation_if_token_valid", return_value=invitation ) as mock_get: - result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + result = RegisterService.get_invitation_with_case_fallback( + "tenant-456", "User@Test.com", "token-123", session=MagicMock() + ) assert result == invitation - mock_get.assert_called_once_with("tenant-456", "User@Test.com", "token-123") + mock_get.assert_called_once_with( + "tenant-456", "User@Test.com", "token-123", session=mock_get.call_args.kwargs["session"] + ) def test_get_invitation_with_case_fallback_retries_with_lowercase(self): """Fallback helper should retry with lowercase email when needed.""" invitation = {"workspace_id": "tenant-456"} with patch("services.account_service.RegisterService.get_invitation_if_token_valid") as mock_get: mock_get.side_effect = [None, invitation] - result = RegisterService.get_invitation_with_case_fallback("tenant-456", "User@Test.com", "token-123") + result = RegisterService.get_invitation_with_case_fallback( + "tenant-456", "User@Test.com", "token-123", session=MagicMock() + ) assert result == invitation assert mock_get.call_args_list == [ - (("tenant-456", "User@Test.com", "token-123"),), - (("tenant-456", "user@test.com", "token-123"),), + (("tenant-456", "User@Test.com", "token-123"), {"session": mock_get.call_args_list[0].kwargs["session"]}), + (("tenant-456", "user@test.com", "token-123"), {"session": mock_get.call_args_list[1].kwargs["session"]}), ] # ==================== Helper Method Tests ==================== diff --git a/api/tests/unit_tests/services/test_annotation_service.py b/api/tests/unit_tests/services/test_annotation_service.py index 55912cc1c1f..7574c13342a 100644 --- a/api/tests/unit_tests/services/test_annotation_service.py +++ b/api/tests/unit_tests/services/test_annotation_service.py @@ -546,7 +546,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(NotFound): - AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1") + AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1", mock_db.session) def test_update_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None: """Test missing app raises NotFound in update path.""" @@ -562,7 +562,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(NotFound): - AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1") + AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1", mock_db.session) def test_update_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None: """Test missing question raises ValueError.""" @@ -581,7 +581,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(ValueError): - AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session) def test_update_app_annotation_directly_should_update_annotation_and_index(self) -> None: """Test update changes fields and triggers index update.""" @@ -602,7 +602,7 @@ class TestAppAnnotationServiceDirectManipulation: mock_db.session.get.return_value = annotation # Act - result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id, mock_db.session) # Assert assert result == annotation @@ -640,7 +640,7 @@ class TestAppAnnotationServiceDirectManipulation: mock_db.session.scalars.return_value = scalars_result # Act - AppAnnotationService.delete_app_annotation(app.id, annotation.id) + AppAnnotationService.delete_app_annotation(app.id, annotation.id, mock_db.session) # Assert mock_db.session.delete.assert_any_call(annotation) @@ -667,7 +667,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(NotFound): - AppAnnotationService.delete_app_annotation("app-1", "ann-1") + AppAnnotationService.delete_app_annotation("app-1", "ann-1", mock_db.session) def test_delete_app_annotation_should_raise_not_found_when_annotation_missing(self) -> None: """Test delete raises NotFound when annotation is missing.""" @@ -684,7 +684,7 @@ class TestAppAnnotationServiceDirectManipulation: # Act & Assert with pytest.raises(NotFound): - AppAnnotationService.delete_app_annotation(app.id, "ann-1") + AppAnnotationService.delete_app_annotation(app.id, "ann-1", mock_db.session) def test_delete_app_annotations_in_batch_should_return_zero_when_none_found(self) -> None: """Test batch delete returns zero when no annotations found.""" diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 5d148974f87..788a47c5c31 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -398,6 +398,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Hello world", voice="en-US-Neural", end_user="user-123", @@ -432,6 +433,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Test", ) @@ -465,6 +467,7 @@ class TestAudioServiceTTS: # Act result = AudioService.transcript_tts( app_model=app, + session=MagicMock(), text="Test", ) @@ -496,17 +499,52 @@ class TestAudioServiceTTS: mock_model_instance = MagicMock() mock_model_instance.invoke_tts.return_value = b"draft audio" mock_model_manager.get_default_model_instance.return_value = mock_model_instance + session = MagicMock() # Act result = AudioService.transcript_tts( app_model=app, + session=session, text="Draft test", is_draft=True, ) # Assert assert result == b"draft audio" - mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app) + mock_workflow_service.get_draft_workflow.assert_called_once_with(app_model=app, session=session) + + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) + def test_transcript_tts_message_id_uses_provided_session( + self, mock_model_manager_class, factory: AudioServiceTestDataFactory + ): + """Test TTS message lookup uses the injected session.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.CHAT) + message_id = "00000000-0000-0000-0000-000000000001" + message = factory.create_message_mock(message_id=message_id, answer="Message answer") + session = MagicMock() + session.get.return_value = message + + mock_model_manager = mock_model_manager_class.return_value + mock_model_instance = MagicMock() + mock_model_instance.invoke_tts.return_value = b"message audio" + mock_model_manager.get_default_model_instance.return_value = mock_model_instance + + # Act + result = AudioService.transcript_tts( + app_model=app, + session=session, + message_id=message_id, + voice="message-voice", + ) + + # Assert + assert result == b"message audio" + session.get.assert_called_once_with(Message, message_id) + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="Message answer", + voice="message-voice", + ) def test_transcript_tts_raises_error_when_text_missing(self, factory: AudioServiceTestDataFactory): """Test that TTS raises error when text is missing.""" @@ -515,7 +553,7 @@ class TestAudioServiceTTS: # Act & Assert with pytest.raises(ValueError, match="Text is required"): - AudioService.transcript_tts(app_model=app, text=None) + AudioService.transcript_tts(app_model=app, session=MagicMock(), text=None) @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_raises_error_when_no_voices_available( @@ -539,7 +577,7 @@ class TestAudioServiceTTS: # Act & Assert with pytest.raises(ValueError, match="Sorry, no voice available"): - AudioService.transcript_tts(app_model=app, text="Test") + AudioService.transcript_tts(app_model=app, session=MagicMock(), text="Test") class TestAudioServiceTTSVoices: diff --git a/api/tests/unit_tests/services/test_feature_service_learn_app.py b/api/tests/unit_tests/services/test_feature_service_learn_app.py new file mode 100644 index 00000000000..ed64c4d08dc --- /dev/null +++ b/api/tests/unit_tests/services/test_feature_service_learn_app.py @@ -0,0 +1,17 @@ +import pytest + +from services import feature_service as feature_service_module +from services.feature_service import FeatureService, SystemFeatureModel + + +def test_system_feature_model_defaults_enable_learn_app(): + assert SystemFeatureModel().enable_learn_app is True + + +@pytest.mark.parametrize("enabled", [True, False]) +def test_get_system_features_reads_enable_learn_app(monkeypatch: pytest.MonkeyPatch, enabled: bool): + monkeypatch.setattr(feature_service_module.dify_config, "ENABLE_LEARN_APP", enabled) + + result = FeatureService.get_system_features() + + assert result.enable_learn_app is enabled diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index d9d81d66566..5c26a9ef57d 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -673,7 +673,7 @@ def test_enqueue_resume_workflow_not_found(mocker: MockerFixture, mock_session_f assert "WorkflowRun not found" in str(excinfo.value) -def test_enqueue_resume_app_not_found(mocker, mock_session_factory, caplog): +def test_enqueue_resume_app_not_found(mocker, mock_session_factory, caplog: pytest.LogCaptureFixture): session_factory, session = mock_session_factory service = HumanInputService(session_factory) diff --git a/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py b/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py index 287d787ad70..7a8efe85f13 100644 --- a/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py +++ b/api/tests/unit_tests/services/test_knowledge_retrieval_inner_service.py @@ -74,14 +74,14 @@ def _build_source() -> Source: class TestInnerKnowledgeRetrievalService: @patch("services.knowledge_retrieval_inner_service.DatasetRetrieval") - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_db, mock_rag_cls): + def test_retrieve_maps_multiple_request_and_skips_enable_api_check(self, mock_rag_cls): request = _build_request() + mock_session = MagicMock() mock_app = MagicMock(id="app-1", tenant_id="tenant-1") dataset_1 = MagicMock(id="dataset-1", tenant_id="tenant-1", enable_api=False) dataset_2 = MagicMock(id="dataset-2", tenant_id="tenant-1", enable_api=True) - mock_db.session.scalar.return_value = mock_app - mock_db.session.scalars.return_value.all.return_value = [dataset_1, dataset_2] + mock_session.scalar.return_value = mock_app + mock_session.scalars.return_value.all.return_value = [dataset_1, dataset_2] rag = MagicMock() rag.knowledge_retrieval.return_value = [_build_source()] @@ -101,7 +101,7 @@ class TestInnerKnowledgeRetrievalService: } mock_rag_cls.return_value = rag - response = InnerKnowledgeRetrievalService().retrieve(request) + response = InnerKnowledgeRetrievalService().retrieve(request, mock_session) rag_request = rag.knowledge_retrieval.call_args.kwargs["request"] assert rag_request.tenant_id == "tenant-1" @@ -127,8 +127,7 @@ class TestInnerKnowledgeRetrievalService: assert response.usage.currency == "USD" @patch("services.knowledge_retrieval_inner_service.DatasetRetrieval") - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_maps_single_request(self, mock_db, mock_rag_cls): + def test_retrieve_maps_single_request(self, mock_rag_cls): request = _build_request( dataset_ids=["dataset-1"], retrieval={ @@ -151,8 +150,9 @@ class TestInnerKnowledgeRetrievalService: }, attachment_ids=[], ) - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] rag = MagicMock() rag.knowledge_retrieval.return_value = [] @@ -172,7 +172,7 @@ class TestInnerKnowledgeRetrievalService: } mock_rag_cls.return_value = rag - InnerKnowledgeRetrievalService().retrieve(request) + InnerKnowledgeRetrievalService().retrieve(request, mock_session) rag_request = rag.knowledge_retrieval.call_args.kwargs["request"] assert rag_request.retrieval_mode == "single" @@ -184,35 +184,35 @@ class TestInnerKnowledgeRetrievalService: assert rag_request.metadata_model_config is not None assert rag_request.metadata_model_config.provider == "openai" - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_app_missing(self, mock_db): - mock_db.session.scalar.return_value = None + def test_retrieve_raises_when_app_missing(self): + mock_session = MagicMock() + mock_session.scalar.return_value = None with pytest.raises(InnerKnowledgeRetrieveAppNotFoundError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_app_belongs_to_other_tenant(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2") + def test_retrieve_raises_when_app_belongs_to_other_tenant(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-2") with pytest.raises(InnerKnowledgeRetrieveAppTenantMismatchError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_dataset_missing(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] + def test_retrieve_raises_when_dataset_missing(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [MagicMock(id="dataset-1", tenant_id="tenant-1")] with pytest.raises(InnerKnowledgeRetrieveDatasetNotFoundError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) - @patch("services.knowledge_retrieval_inner_service.db") - def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self, mock_db): - mock_db.session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") - mock_db.session.scalars.return_value.all.return_value = [ + def test_retrieve_raises_when_dataset_belongs_to_other_tenant(self): + mock_session = MagicMock() + mock_session.scalar.return_value = MagicMock(id="app-1", tenant_id="tenant-1") + mock_session.scalars.return_value.all.return_value = [ MagicMock(id="dataset-1", tenant_id="tenant-1"), MagicMock(id="dataset-2", tenant_id="tenant-2"), ] with pytest.raises(InnerKnowledgeRetrieveDatasetTenantMismatchError): - InnerKnowledgeRetrievalService().retrieve(_build_request()) + InnerKnowledgeRetrievalService().retrieve(_build_request(), mock_session) diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index 36ea1fac1a4..6792243e9d0 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -48,13 +48,15 @@ class TestMetadataBugCompleteValidation: account = _make_account() # Should crash with TypeError with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") + MetadataService.create_metadata(Mock(), "dataset-123", mock_metadata_args, account, "tenant-123") # Test update method as well account = _make_account() none_name = cast(str, None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123") + MetadataService.update_metadata_name( + Mock(), "dataset-123", "metadata-456", none_name, account, "tenant-123" + ) def test_3_database_constraints_verification(self) -> None: """Test Layer 3: Verify database model has nullable=False constraints.""" @@ -97,7 +99,7 @@ class TestMetadataBugCompleteValidation: account = _make_account() with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") + MetadataService.create_metadata(Mock(), "dataset-123", mock_metadata_args, account, "tenant-123") def test_7_end_to_end_validation_layers(self) -> None: """Test all validation layers work together correctly.""" diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index 27570a86f1a..ae93fe5ef51 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -37,7 +37,7 @@ class TestMetadataNullableBug: account = _make_account() # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") + MetadataService.create_metadata(Mock(), "dataset-123", mock_metadata_args, account, "tenant-123") def test_metadata_service_update_with_none_name_crashes(self) -> None: """Test that MetadataService.update_metadata_name crashes when name is None.""" @@ -45,7 +45,9 @@ class TestMetadataNullableBug: none_name = cast(str, None) # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123") + MetadataService.update_metadata_name( + Mock(), "dataset-123", "metadata-456", none_name, account, "tenant-123" + ) def test_api_layer_now_uses_pydantic_validation(self) -> None: """Verify that API layer relies on Pydantic validation instead of reqparse.""" diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py index a2b56fe7771..d36c45d7777 100644 --- a/api/tests/unit_tests/services/test_webhook_service.py +++ b/api/tests/unit_tests/services/test_webhook_service.py @@ -1,3 +1,4 @@ +import logging from io import BytesIO from unittest.mock import MagicMock, patch @@ -159,7 +160,9 @@ class TestWebhookServiceUnit: assert result == "application/octet-stream" - def test_detect_binary_mimetype_handles_magic_exception(self, monkeypatch: pytest.MonkeyPatch): + def test_detect_binary_mimetype_handles_magic_exception( + self, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture + ): """Fallback MIME type should be used when python-magic raises an exception.""" try: import magic as real_magic @@ -169,12 +172,12 @@ class TestWebhookServiceUnit: fake_magic = MagicMock() fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error") monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic) + caplog.set_level(logging.DEBUG, logger="services.trigger.webhook_service") - with patch("services.trigger.webhook_service.logger", autospec=True) as mock_logger: - result = WebhookService._detect_binary_mimetype(b"binary data") + result = WebhookService._detect_binary_mimetype(b"binary data") - assert result == "application/octet-stream" - mock_logger.debug.assert_called_once() + assert result == "application/octet-stream" + assert "python-magic detection failed for octet-stream payload" in caplog.messages def test_extract_webhook_data_invalid_json(self): """Test webhook data extraction with invalid JSON.""" diff --git a/api/uv.lock b/api/uv.lock index 6e998f6f6ed..f1b1e2af01e 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -51,6 +51,7 @@ members = [ "dify-vdb-weaviate", ] overrides = [ + { name = "cryptography", specifier = ">=49.0.0,<50.0.0" }, { name = "litellm", specifier = ">=1.83.10,<2.0.0" }, { name = "pyarrow", specifier = ">=23.0.1,<24.0.0" }, ] @@ -87,7 +88,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.13.4" +version = "3.14.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, @@ -96,27 +97,29 @@ dependencies = [ { name = "frozenlist" }, { name = "multidict" }, { name = "propcache" }, + { name = "typing-extensions" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748, upload-time = "2026-03-28T17:19:40.6Z" } +sdist = { url = "https://files.pythonhosted.org/packages/82/78/8ea7308cac6934de8c74a14f3d5f65d1c89287426688be79538d0e5c013d/aiohttp-3.14.1.tar.gz", hash = "sha256:307f2cff90a764d329e77040603fa032db89c5c24fdad50c4c15334cba744035", size = 7955794, upload-time = "2026-06-07T21:09:35.529Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158, upload-time = "2026-03-28T17:16:06.901Z" }, - { url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037, upload-time = "2026-03-28T17:16:08.82Z" }, - { url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556, upload-time = "2026-03-28T17:16:10.63Z" }, - { url = "https://files.pythonhosted.org/packages/d6/10/88ff67cd48a6ec36335b63a640abe86135791544863e0cfe1f065d6cef7a/aiohttp-3.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d738ebab9f71ee652d9dbd0211057690022201b11197f9a7324fd4dba128aa97", size = 1757314, upload-time = "2026-03-28T17:16:12.498Z" }, - { url = "https://files.pythonhosted.org/packages/8b/15/fdb90a5cf5a1f52845c276e76298c75fbbcc0ac2b4a86551906d54529965/aiohttp-3.13.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0ce692c3468fa831af7dceed52edf51ac348cebfc8d3feb935927b63bd3e8576", size = 1731819, upload-time = "2026-03-28T17:16:14.558Z" }, - { url = "https://files.pythonhosted.org/packages/ec/df/28146785a007f7820416be05d4f28cc207493efd1e8c6c1068e9bdc29198/aiohttp-3.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8e08abcfe752a454d2cb89ff0c08f2d1ecd057ae3e8cc6d84638de853530ebab", size = 1793279, upload-time = "2026-03-28T17:16:16.594Z" }, - { url = "https://files.pythonhosted.org/packages/10/47/689c743abf62ea7a77774d5722f220e2c912a77d65d368b884d9779ef41b/aiohttp-3.13.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5977f701b3fff36367a11087f30ea73c212e686d41cd363c50c022d48b011d8d", size = 1891082, upload-time = "2026-03-28T17:16:18.71Z" }, - { url = "https://files.pythonhosted.org/packages/b0/b6/f7f4f318c7e58c23b761c9b13b9a3c9b394e0f9d5d76fbc6622fa98509f6/aiohttp-3.13.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54203e10405c06f8b6020bd1e076ae0fe6c194adcee12a5a78af3ffa3c57025e", size = 1773938, upload-time = "2026-03-28T17:16:21.125Z" }, - { url = "https://files.pythonhosted.org/packages/aa/06/f207cb3121852c989586a6fc16ff854c4fcc8651b86c5d3bd1fc83057650/aiohttp-3.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:358a6af0145bc4dda037f13167bef3cce54b132087acc4c295c739d05d16b1c3", size = 1579548, upload-time = "2026-03-28T17:16:23.588Z" }, - { url = "https://files.pythonhosted.org/packages/6c/58/e1289661a32161e24c1fe479711d783067210d266842523752869cc1d9c2/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:898ea1850656d7d61832ef06aa9846ab3ddb1621b74f46de78fbc5e1a586ba83", size = 1714669, upload-time = "2026-03-28T17:16:25.713Z" }, - { url = "https://files.pythonhosted.org/packages/96/0a/3e86d039438a74a86e6a948a9119b22540bae037d6ba317a042ae3c22711/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7bc30cceb710cf6a44e9617e43eebb6e3e43ad855a34da7b4b6a73537d8a6763", size = 1754175, upload-time = "2026-03-28T17:16:28.18Z" }, - { url = "https://files.pythonhosted.org/packages/f4/30/e717fc5df83133ba467a560b6d8ef20197037b4bb5d7075b90037de1018e/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4a31c0c587a8a038f19a4c7e60654a6c899c9de9174593a13e7cc6e15ff271f9", size = 1762049, upload-time = "2026-03-28T17:16:30.941Z" }, - { url = "https://files.pythonhosted.org/packages/e4/28/8f7a2d4492e336e40005151bdd94baf344880a4707573378579f833a64c1/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:2062f675f3fe6e06d6113eb74a157fb9df58953ffed0cdb4182554b116545758", size = 1570861, upload-time = "2026-03-28T17:16:32.953Z" }, - { url = "https://files.pythonhosted.org/packages/78/45/12e1a3d0645968b1c38de4b23fdf270b8637735ea057d4f84482ff918ad9/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d1ba8afb847ff80626d5e408c1fdc99f942acc877d0702fe137015903a220a9", size = 1790003, upload-time = "2026-03-28T17:16:35.468Z" }, - { url = "https://files.pythonhosted.org/packages/eb/0f/60374e18d590de16dcb39d6ff62f39c096c1b958e6f37727b5870026ea30/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b08149419994cdd4d5eecf7fd4bc5986b5a9380285bcd01ab4c0d6bfca47b79d", size = 1737289, upload-time = "2026-03-28T17:16:38.187Z" }, - { url = "https://files.pythonhosted.org/packages/02/bf/535e58d886cfbc40a8b0013c974afad24ef7632d645bca0b678b70033a60/aiohttp-3.13.4-cp312-cp312-win32.whl", hash = "sha256:fc432f6a2c4f720180959bc19aa37259651c1a4ed8af8afc84dd41c60f15f791", size = 434185, upload-time = "2026-03-28T17:16:40.735Z" }, - { url = "https://files.pythonhosted.org/packages/1e/1a/d92e3325134ebfff6f4069f270d3aac770d63320bd1fcd0eca023e74d9a8/aiohttp-3.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:6148c9ae97a3e8bff9a1fc9c757fa164116f86c100468339730e717590a3fb77", size = 461285, upload-time = "2026-03-28T17:16:42.713Z" }, + { url = "https://files.pythonhosted.org/packages/1d/21/151624b51cd92553d95424daf4bf19f19ce9be9002d19253e7e7ce67197b/aiohttp-3.14.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d35143e27778b4bb0fb189562d7f275bff79c62ab8e98459717c0ea617ff2480", size = 757402, upload-time = "2026-06-07T21:06:40.311Z" }, + { url = "https://files.pythonhosted.org/packages/c2/82/280619e0bd7bf2454987e19282616e84762255dd9c8468f62382e8c191f1/aiohttp-3.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bcfb80a2cc36fba2534e5e5b5264dc7ae6fcd9bf15256da3e53d2f499e6fa29d", size = 512310, upload-time = "2026-06-07T21:06:42.207Z" }, + { url = "https://files.pythonhosted.org/packages/55/b2/2aac325583aaa1353045f96dffa586d8a34e8322e14a7ba49cffeb103ab4/aiohttp-3.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27fd7c91e51729b4f7e1577865fa6d34c9adccbc39aabe9000285b48af9f0ec2", size = 512448, upload-time = "2026-06-07T21:06:43.813Z" }, + { url = "https://files.pythonhosted.org/packages/8a/72/a60607cb849faa8af8a356c9329ea2eb6f395d49e82cc82ccba1fd8deb8f/aiohttp-3.14.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:64c567bf9eaf664280116a8688f63016e6b32db2505908e2bdaca1b6438142f2", size = 1766854, upload-time = "2026-06-07T21:06:45.391Z" }, + { url = "https://files.pythonhosted.org/packages/b5/d3/d9fe1c9ec7557ab4d0d82bebaa728c6418f0b93295ec2f4ab015f7710cc7/aiohttp-3.14.1-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:f5e6ff2bdbb8f4cd3fbe41f99e25bbcd58e3bf9f13d3dd31a11e7917251cc77a", size = 1740884, upload-time = "2026-06-07T21:06:47.413Z" }, + { url = "https://files.pythonhosted.org/packages/c1/dc/f2cecfaf9337ba3e63f181500814ff502aa3d00d9c7ec93a9d23d10a27b2/aiohttp-3.14.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2f73e01dc37122325caf079982621262f96d74823c179038a82fddfc50359264", size = 1810034, upload-time = "2026-06-07T21:06:50.165Z" }, + { url = "https://files.pythonhosted.org/packages/66/d7/2ff65c5e65c0d7476daf7e15c032e0805e36811185b9623e3238ad6c763e/aiohttp-3.14.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:bb2c0c80d431c0d03f2c7dbf125150fedd4f0de17366a7ca33f7ccb822391842", size = 1904054, upload-time = "2026-06-07T21:06:52.035Z" }, + { url = "https://files.pythonhosted.org/packages/20/9c/d445818389df371f56d141d881153ba23183c4735a03f7356ffb43f7757d/aiohttp-3.14.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3e6fc1a85fa7194a1a7d19f44e8609180f4a8eb5fa4c7ed8b4355f080fad235c", size = 1790278, upload-time = "2026-06-07T21:06:54.049Z" }, + { url = "https://files.pythonhosted.org/packages/4d/aa/bf04cb4d865fc6101c2229a294ad744973b72e513fdc5a6b791e6983d72a/aiohttp-3.14.1-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:686b6c0d3911ec387b444ddf5dc62fb7f7c0a7d5186a7861626496a5ab4aff95", size = 1591795, upload-time = "2026-06-07T21:06:55.911Z" }, + { url = "https://files.pythonhosted.org/packages/dc/b4/4dac0038960427ba832f6609dfb4ea5437d7fd80c72001b9e48f834f428b/aiohttp-3.14.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c6fa4dc7ad6f8109c70bb1499e589f76b0b792baf39f9b017eb92c8a81d0a199", size = 1728397, upload-time = "2026-06-07T21:06:57.777Z" }, + { url = "https://files.pythonhosted.org/packages/2b/f9/7cd4e8ad7aa3b75f17d56bb5498dd604a93d4e6eece822ba0568c413fff0/aiohttp-3.14.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:87a5eea1b2a5e21e1ebdbb33ad4165359189327e63fc4e4894693e7f821ac817", size = 1766504, upload-time = "2026-06-07T21:07:00.009Z" }, + { url = "https://files.pythonhosted.org/packages/f9/df/fc01d9fcad0f73fed3f3d361f1f94f975947b50dff82919f6dc2bf4316cc/aiohttp-3.14.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1c1421eb01d4fd608d88cc8290211d177a58532b55ad94076fb349c5bf467f0a", size = 1777806, upload-time = "2026-06-07T21:07:02.064Z" }, + { url = "https://files.pythonhosted.org/packages/41/09/47e2d090bddcc8fb4ccb4c314aadc32d7c5d9bb55f50f6ad1c92fc15d501/aiohttp-3.14.1-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:34b257ec41345c1e8f2df68fa908a7952f5de932723871eb633ecbbff396c9a4", size = 1580707, upload-time = "2026-06-07T21:07:03.942Z" }, + { url = "https://files.pythonhosted.org/packages/3d/36/f1a4ce904ae0b6930cfe9afc96d0896f7ec1a620c400405d63783bb95a9c/aiohttp-3.14.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:de538791a80e5d862addbc183f70f0158ac9b9bb872bb147f1fd2a683691e087", size = 1798121, upload-time = "2026-06-07T21:07:05.987Z" }, + { url = "https://files.pythonhosted.org/packages/70/0a/e0075ce9ca0279ee1d4f0c0b85f54fea02ebc83c3007651a72bece658fec/aiohttp-3.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6f71173be42d3241d428f760122febb748de0623f44308a6f120d0dd9ec572e3", size = 1767580, upload-time = "2026-06-07T21:07:07.873Z" }, + { url = "https://files.pythonhosted.org/packages/3e/61/a0c0a8f327a9c52095cdd8e312391b00d3ed64ab6c72bb5c33d8ec251cf7/aiohttp-3.14.1-cp312-cp312-win32.whl", hash = "sha256:ec8dc383ee57ea3e883477dcca3f11b65d58199f1080acaf4cd6ad9a99698be4", size = 452771, upload-time = "2026-06-07T21:07:09.669Z" }, + { url = "https://files.pythonhosted.org/packages/df/d9/ea367c75f16ac9c6cdc8febb25e8318fa21a2b1bc8d6514d4b2d890bface/aiohttp-3.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2aa92c87868cd13674989f9ee83e5f9f7ea4237589b728048e1f0c8f6caa3271", size = 479873, upload-time = "2026-06-07T21:07:11.538Z" }, + { url = "https://files.pythonhosted.org/packages/03/64/8d96784a7851156db8a4c6c3f6f91042fdf39fb15a4cc38c8b3c14833c45/aiohttp-3.14.1-cp312-cp312-win_arm64.whl", hash = "sha256:2c840c90759922cb5e6dda94596e079a30fb5a5ba548e7e0dc00574703940847", size = 448073, upload-time = "2026-06-07T21:07:13.637Z" }, ] [[package]] @@ -556,14 +559,14 @@ wheels = [ [[package]] name = "bleach" -version = "6.3.0" +version = "6.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "webencodings" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/07/18/3c8523962314be6bf4c8989c79ad9531c825210dd13a8669f6b84336e8bd/bleach-6.3.0.tar.gz", hash = "sha256:6f3b91b1c0a02bb9a78b5a454c92506aa0fdf197e1d5e114d2e00c6f64306d22", size = 203533, upload-time = "2025-10-27T17:57:39.211Z" } +sdist = { url = "https://files.pythonhosted.org/packages/48/3c/e12ac860709702bd5ebeb9b56a4fe334f1001246ee1b8f2b7ee28912df7d/bleach-6.4.0.tar.gz", hash = "sha256:4202482733d85cedd04e59fcb2f89f4e4c7c385a78d3c3c23c30446843a37452", size = 204857, upload-time = "2026-06-05T13:01:13.734Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cd/3a/577b549de0cc09d95f11087ee63c739bba856cd3952697eec4c4bb91350a/bleach-6.3.0-py3-none-any.whl", hash = "sha256:fe10ec77c93ddf3d13a73b035abaac7a9f5e436513864ccdad516693213c65d6", size = 164437, upload-time = "2025-10-27T17:57:37.538Z" }, + { url = "https://files.pythonhosted.org/packages/58/9d/40b6267367182187139a4000b82a3b287d84d745bccd808e75d916920e9d/bleach-6.4.0-py3-none-any.whl", hash = "sha256:4b6b6a54fff2e69a3dde9d21cc6301220bee3c3cb792187d11403fd795031081", size = 165109, upload-time = "2026-06-05T13:01:12.504Z" }, ] [[package]] @@ -1141,41 +1144,39 @@ wheels = [ [[package]] name = "cryptography" -version = "46.0.7" +version = "49.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/47/93/ac8f3d5ff04d54bc814e961a43ae5b0b146154c89c61b47bb07557679b18/cryptography-46.0.7.tar.gz", hash = "sha256:e4cfd68c5f3e0bfdad0d38e023239b96a2fe84146481852dffbcca442c245aa5", size = 750652, upload-time = "2026-04-08T01:57:54.692Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/99/d1c90d6041656cc6ee229dc99cd67fd0cd5aec3c5f7d72fffc27cc750054/cryptography-49.0.0.tar.gz", hash = "sha256:f89660a348f4f78a92366240a61404e337586ef7f5909a2fef59ca88ef505493", size = 854345, upload-time = "2026-06-12T20:02:30.512Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/5d/4a8f770695d73be252331e60e526291e3df0c9b27556a90a6b47bccca4c2/cryptography-46.0.7-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:ea42cbe97209df307fdc3b155f1b6fa2577c0defa8f1f7d3be7d31d189108ad4", size = 7179869, upload-time = "2026-04-08T01:56:17.157Z" }, - { url = "https://files.pythonhosted.org/packages/5f/45/6d80dc379b0bbc1f9d1e429f42e4cb9e1d319c7a8201beffd967c516ea01/cryptography-46.0.7-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b36a4695e29fe69215d75960b22577197aca3f7a25b9cf9d165dcfe9d80bc325", size = 4275492, upload-time = "2026-04-08T01:56:19.36Z" }, - { url = "https://files.pythonhosted.org/packages/4a/9a/1765afe9f572e239c3469f2cb429f3ba7b31878c893b246b4b2994ffe2fe/cryptography-46.0.7-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ad9ef796328c5e3c4ceed237a183f5d41d21150f972455a9d926593a1dcb308", size = 4426670, upload-time = "2026-04-08T01:56:21.415Z" }, - { url = "https://files.pythonhosted.org/packages/8f/3e/af9246aaf23cd4ee060699adab1e47ced3f5f7e7a8ffdd339f817b446462/cryptography-46.0.7-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:73510b83623e080a2c35c62c15298096e2a5dc8d51c3b4e1740211839d0dea77", size = 4280275, upload-time = "2026-04-08T01:56:23.539Z" }, - { url = "https://files.pythonhosted.org/packages/0f/54/6bbbfc5efe86f9d71041827b793c24811a017c6ac0fd12883e4caa86b8ed/cryptography-46.0.7-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cbd5fb06b62bd0721e1170273d3f4d5a277044c47ca27ee257025146c34cbdd1", size = 4928402, upload-time = "2026-04-08T01:56:25.624Z" }, - { url = "https://files.pythonhosted.org/packages/2d/cf/054b9d8220f81509939599c8bdbc0c408dbd2bdd41688616a20731371fe0/cryptography-46.0.7-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:420b1e4109cc95f0e5700eed79908cef9268265c773d3a66f7af1eef53d409ef", size = 4459985, upload-time = "2026-04-08T01:56:27.309Z" }, - { url = "https://files.pythonhosted.org/packages/f9/46/4e4e9c6040fb01c7467d47217d2f882daddeb8828f7df800cb806d8a2288/cryptography-46.0.7-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:24402210aa54baae71d99441d15bb5a1919c195398a87b563df84468160a65de", size = 3990652, upload-time = "2026-04-08T01:56:29.095Z" }, - { url = "https://files.pythonhosted.org/packages/36/5f/313586c3be5a2fbe87e4c9a254207b860155a8e1f3cca99f9910008e7d08/cryptography-46.0.7-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8a469028a86f12eb7d2fe97162d0634026d92a21f3ae0ac87ed1c4a447886c83", size = 4279805, upload-time = "2026-04-08T01:56:30.928Z" }, - { url = "https://files.pythonhosted.org/packages/69/33/60dfc4595f334a2082749673386a4d05e4f0cf4df8248e63b2c3437585f2/cryptography-46.0.7-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:9694078c5d44c157ef3162e3bf3946510b857df5a3955458381d1c7cfc143ddb", size = 4892883, upload-time = "2026-04-08T01:56:32.614Z" }, - { url = "https://files.pythonhosted.org/packages/c7/0b/333ddab4270c4f5b972f980adef4faa66951a4aaf646ca067af597f15563/cryptography-46.0.7-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:42a1e5f98abb6391717978baf9f90dc28a743b7d9be7f0751a6f56a75d14065b", size = 4459756, upload-time = "2026-04-08T01:56:34.306Z" }, - { url = "https://files.pythonhosted.org/packages/d2/14/633913398b43b75f1234834170947957c6b623d1701ffc7a9600da907e89/cryptography-46.0.7-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91bbcb08347344f810cbe49065914fe048949648f6bd5c2519f34619142bbe85", size = 4410244, upload-time = "2026-04-08T01:56:35.977Z" }, - { url = "https://files.pythonhosted.org/packages/10/f2/19ceb3b3dc14009373432af0c13f46aa08e3ce334ec6eff13492e1812ccd/cryptography-46.0.7-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5d1c02a14ceb9148cc7816249f64f623fbfee39e8c03b3650d842ad3f34d637e", size = 4674868, upload-time = "2026-04-08T01:56:38.034Z" }, - { url = "https://files.pythonhosted.org/packages/1a/bb/a5c213c19ee94b15dfccc48f363738633a493812687f5567addbcbba9f6f/cryptography-46.0.7-cp311-abi3-win32.whl", hash = "sha256:d23c8ca48e44ee015cd0a54aeccdf9f09004eba9fc96f38c911011d9ff1bd457", size = 3026504, upload-time = "2026-04-08T01:56:39.666Z" }, - { url = "https://files.pythonhosted.org/packages/2b/02/7788f9fefa1d060ca68717c3901ae7fffa21ee087a90b7f23c7a603c32ae/cryptography-46.0.7-cp311-abi3-win_amd64.whl", hash = "sha256:397655da831414d165029da9bc483bed2fe0e75dde6a1523ec2fe63f3c46046b", size = 3488363, upload-time = "2026-04-08T01:56:41.893Z" }, - { url = "https://files.pythonhosted.org/packages/a7/7f/cd42fc3614386bc0c12f0cb3c4ae1fc2bbca5c9662dfed031514911d513d/cryptography-46.0.7-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:462ad5cb1c148a22b2e3bcc5ad52504dff325d17daf5df8d88c17dda1f75f2a4", size = 7165618, upload-time = "2026-04-08T01:57:10.645Z" }, - { url = "https://files.pythonhosted.org/packages/a5/d0/36a49f0262d2319139d2829f773f1b97ef8aef7f97e6e5bd21455e5a8fb5/cryptography-46.0.7-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:84d4cced91f0f159a7ddacad249cc077e63195c36aac40b4150e7a57e84fffe7", size = 4270628, upload-time = "2026-04-08T01:57:12.885Z" }, - { url = "https://files.pythonhosted.org/packages/8a/6c/1a42450f464dda6ffbe578a911f773e54dd48c10f9895a23a7e88b3e7db5/cryptography-46.0.7-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:128c5edfe5e5938b86b03941e94fac9ee793a94452ad1365c9fc3f4f62216832", size = 4415405, upload-time = "2026-04-08T01:57:14.923Z" }, - { url = "https://files.pythonhosted.org/packages/9a/92/4ed714dbe93a066dc1f4b4581a464d2d7dbec9046f7c8b7016f5286329e2/cryptography-46.0.7-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5e51be372b26ef4ba3de3c167cd3d1022934bc838ae9eaad7e644986d2a3d163", size = 4272715, upload-time = "2026-04-08T01:57:16.638Z" }, - { url = "https://files.pythonhosted.org/packages/b7/e6/a26b84096eddd51494bba19111f8fffe976f6a09f132706f8f1bf03f51f7/cryptography-46.0.7-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:cdf1a610ef82abb396451862739e3fc93b071c844399e15b90726ef7470eeaf2", size = 4918400, upload-time = "2026-04-08T01:57:19.021Z" }, - { url = "https://files.pythonhosted.org/packages/c7/08/ffd537b605568a148543ac3c2b239708ae0bd635064bab41359252ef88ed/cryptography-46.0.7-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1d25aee46d0c6f1a501adcddb2d2fee4b979381346a78558ed13e50aa8a59067", size = 4450634, upload-time = "2026-04-08T01:57:21.185Z" }, - { url = "https://files.pythonhosted.org/packages/16/01/0cd51dd86ab5b9befe0d031e276510491976c3a80e9f6e31810cce46c4ad/cryptography-46.0.7-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:cdfbe22376065ffcf8be74dc9a909f032df19bc58a699456a21712d6e5eabfd0", size = 3985233, upload-time = "2026-04-08T01:57:22.862Z" }, - { url = "https://files.pythonhosted.org/packages/92/49/819d6ed3a7d9349c2939f81b500a738cb733ab62fbecdbc1e38e83d45e12/cryptography-46.0.7-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:abad9dac36cbf55de6eb49badd4016806b3165d396f64925bf2999bcb67837ba", size = 4271955, upload-time = "2026-04-08T01:57:24.814Z" }, - { url = "https://files.pythonhosted.org/packages/80/07/ad9b3c56ebb95ed2473d46df0847357e01583f4c52a85754d1a55e29e4d0/cryptography-46.0.7-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:935ce7e3cfdb53e3536119a542b839bb94ec1ad081013e9ab9b7cfd478b05006", size = 4879888, upload-time = "2026-04-08T01:57:26.88Z" }, - { url = "https://files.pythonhosted.org/packages/b8/c7/201d3d58f30c4c2bdbe9b03844c291feb77c20511cc3586daf7edc12a47b/cryptography-46.0.7-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:35719dc79d4730d30f1c2b6474bd6acda36ae2dfae1e3c16f2051f215df33ce0", size = 4449961, upload-time = "2026-04-08T01:57:29.068Z" }, - { url = "https://files.pythonhosted.org/packages/a5/ef/649750cbf96f3033c3c976e112265c33906f8e462291a33d77f90356548c/cryptography-46.0.7-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:7bbc6ccf49d05ac8f7d7b5e2e2c33830d4fe2061def88210a126d130d7f71a85", size = 4401696, upload-time = "2026-04-08T01:57:31.029Z" }, - { url = "https://files.pythonhosted.org/packages/41/52/a8908dcb1a389a459a29008c29966c1d552588d4ae6d43f3a1a4512e0ebe/cryptography-46.0.7-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a1529d614f44b863a7b480c6d000fe93b59acee9c82ffa027cfadc77521a9f5e", size = 4664256, upload-time = "2026-04-08T01:57:33.144Z" }, - { url = "https://files.pythonhosted.org/packages/4b/fa/f0ab06238e899cc3fb332623f337a7364f36f4bb3f2534c2bb95a35b132c/cryptography-46.0.7-cp38-abi3-win32.whl", hash = "sha256:f247c8c1a1fb45e12586afbb436ef21ff1e80670b2861a90353d9b025583d246", size = 3013001, upload-time = "2026-04-08T01:57:34.933Z" }, - { url = "https://files.pythonhosted.org/packages/d2/f1/00ce3bde3ca542d1acd8f8cfa38e446840945aa6363f9b74746394b14127/cryptography-46.0.7-cp38-abi3-win_amd64.whl", hash = "sha256:506c4ff91eff4f82bdac7633318a526b1d1309fc07ca76a3ad182cb5b686d6d3", size = 3472985, upload-time = "2026-04-08T01:57:36.714Z" }, + { url = "https://files.pythonhosted.org/packages/9b/22/adf66990e63584a68dfb50c24f48a125c07b1699899381c8151e63ed458c/cryptography-49.0.0-cp311-abi3-macosx_11_0_arm64.whl", hash = "sha256:966fe0e9c67490071f14c0d2b1cb2dfb3023c5ce39457343931415f08382f2db", size = 4032100, upload-time = "2026-06-12T20:02:32.143Z" }, + { url = "https://files.pythonhosted.org/packages/09/41/3797cfaf69cae04a13ee78ebd83f0678d9c02b4779d21ce24445326f1a69/cryptography-49.0.0-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:36d1709f992593689b45bda411498d62c6e365f2ca00b84657d4dadd24de16db", size = 4692978, upload-time = "2026-06-12T20:01:21.305Z" }, + { url = "https://files.pythonhosted.org/packages/e6/8b/43011f7ebe515a8aa20d61f290a326cd890c2e738e16e59eaff8d9c3a412/cryptography-49.0.0-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0e959b578856a3924bc0cbb710fc12c387b9412a951389f3ca61704a9e25f325", size = 4716422, upload-time = "2026-06-12T20:01:48.566Z" }, + { url = "https://files.pythonhosted.org/packages/4a/91/01ce7303a4579e6d3a6abef01bd322848e9ea7a219adcabc5048b9033571/cryptography-49.0.0-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:53ecee2e23f7169b6117e99fc8a944e5e50f79e69758a83b52a00cb98ab2b2d2", size = 4700503, upload-time = "2026-06-12T20:02:47.091Z" }, + { url = "https://files.pythonhosted.org/packages/62/99/a2c95cf8293f07491e9e27c20cc4dcd18176d944e674679adeb1d0173fd6/cryptography-49.0.0-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:2eda353d8a27bcbcaa4cbed18994a74ab4d19a2ca897db188ea269ab9b71419b", size = 5309779, upload-time = "2026-06-12T20:02:08.987Z" }, + { url = "https://files.pythonhosted.org/packages/20/2c/0622f20ff02b2ef32558733443805dc82fd4c275be01b2d19d14676f3a1b/cryptography-49.0.0-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:2afe9051da7ae7bd5905da5a949280c7d2bb75682e188f650a9d0f2756b834c6", size = 4749683, upload-time = "2026-06-12T20:02:03.335Z" }, + { url = "https://files.pythonhosted.org/packages/a3/5b/c5246635d5fd3b64e0d45ae10e99fd32fe9676a79915ccfe5a61ba9af1a5/cryptography-49.0.0-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:0b82e28ee398a386f0807bba7884d30f25218855690f45115831bcce5d90822c", size = 4337874, upload-time = "2026-06-12T20:02:54.323Z" }, + { url = "https://files.pythonhosted.org/packages/6d/88/05563c7fe2e914e87d1a536d06fe83e66b4e1d95cb593e05aea375531da8/cryptography-49.0.0-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:ccac2bfebc306b862133e3bb71f3f6ee8bb525240089b2d952e4144b3a6d5da7", size = 4700283, upload-time = "2026-06-12T20:01:34.822Z" }, + { url = "https://files.pythonhosted.org/packages/c4/b6/d7696e4e890d6ae1469935164c9e5215c557671cb78d6e3f458ccceaa632/cryptography-49.0.0-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d0527ce944105f257f605a827d6ebead966c752038b6e8656abb9c5edee6fc68", size = 5265844, upload-time = "2026-06-12T20:01:24.09Z" }, + { url = "https://files.pythonhosted.org/packages/a9/3c/f3ad17eecc1a57b0ba236dc01f90e783c51f4a2f35f64777cc4f47a184b2/cryptography-49.0.0-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:cbc77da8c523d5abd028635ba850a6966fcee2c82e2bf65a41d1d8afe0f98be9", size = 4749290, upload-time = "2026-06-12T20:01:30.848Z" }, + { url = "https://files.pythonhosted.org/packages/4f/01/339573cf1023163a400b0b5d16f6d507de413b9f60be6fd1b77feeaf6737/cryptography-49.0.0-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b87e65d263b3e5d3bb92a57e2a6638e2f31110fa7aa890c7b2dbba42248d0a3f", size = 4834612, upload-time = "2026-06-12T20:01:29.246Z" }, + { url = "https://files.pythonhosted.org/packages/71/fd/577302e213a1be9468f92d1afef66fcf1ef83d516819d9992ca547f592bd/cryptography-49.0.0-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:66ec79c3904820572d7e987abdf304281f141d37ad9a489b8e97066e7b9b6459", size = 4980804, upload-time = "2026-06-12T20:01:42.853Z" }, + { url = "https://files.pythonhosted.org/packages/1f/09/f42b1d190c5ba75f72062a387f8030d1d75f6ab035788f1d9c4b01de6525/cryptography-49.0.0-cp311-abi3-win_amd64.whl", hash = "sha256:e5dfc1e64de5677cec922ffa8da89c546d0415bf6efdf081842e5d44c84e1f0e", size = 3810026, upload-time = "2026-06-12T20:02:39.262Z" }, + { url = "https://files.pythonhosted.org/packages/19/2a/5bb823f5bedcf80718cea7fbc95ec5515cca3769633c4b01a32be7f30e7c/cryptography-49.0.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:ec5e529fb80935c94fe7b729f9972b50e351a0e6b50aa294fd5cabb109fcc29a", size = 4025947, upload-time = "2026-06-12T20:01:25.745Z" }, + { url = "https://files.pythonhosted.org/packages/3d/df/40577043ca124e17012f408ddddaeb213b856336ac82ddb3bc915f39e29f/cryptography-49.0.0-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f78ff2c9ed8dc2d036b0f4d640e22522213d047c1b14e61205a7e55c80a494d4", size = 4692429, upload-time = "2026-06-12T20:01:53.628Z" }, + { url = "https://files.pythonhosted.org/packages/2c/99/2d13299eb3dd27b02dcfaafcc91d6b5cb3329f7cbd6d8f51921acd566c1a/cryptography-49.0.0-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:35b151772baff2c74cba7fa290ceaff4c3b11c0c881eb93eb5dbc05a7cfbba18", size = 4700968, upload-time = "2026-06-12T20:02:45.383Z" }, + { url = "https://files.pythonhosted.org/packages/a5/4d/9c0cd02f95e2602dd5e563da149ee0830abef3537be8b34dc56281ebe27a/cryptography-49.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:0f21641cf4b30fca7aee061ced0ec7ad7b073518088b7c9969a297c0ae796c69", size = 4697758, upload-time = "2026-06-12T20:01:41.13Z" }, + { url = "https://files.pythonhosted.org/packages/24/01/186c825898477d77e2324d5360fefe622ff1d8d1963ec0554e2cada8ec77/cryptography-49.0.0-cp39-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:9e82dcc8e56052715fb18b2429e3bca4823b1629136a2084fc45a9a5cecb9b64", size = 5298863, upload-time = "2026-06-12T20:02:24.579Z" }, + { url = "https://files.pythonhosted.org/packages/b8/7b/62cbbab75d0659865bf0273790031544a0b16c8072d258f9428dcd8190dc/cryptography-49.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6f2debedf9ca60cf1d5bd466475638af5130f89965605cd818484d19987d3a21", size = 4735983, upload-time = "2026-06-12T20:01:50.14Z" }, + { url = "https://files.pythonhosted.org/packages/6c/72/3e798c064bc39e471008075d0f9bc9daf77a80879c092e4a8e170c585ed4/cryptography-49.0.0-cp39-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:8c25ceb16df5b9435f3f6a9829204985b0e0cbee3b48aacd432c7d2c850b44d9", size = 4334173, upload-time = "2026-06-12T20:01:44.743Z" }, + { url = "https://files.pythonhosted.org/packages/f0/ee/6fca21d1ac73e06f8bef71940abfd4d2f6472b4bca284d770f32bd4086f6/cryptography-49.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:28d8b15e6275f12c8a207dc309dfa957903c927d08d0cc937ee3f63f200693cc", size = 4697298, upload-time = "2026-06-12T20:02:20.918Z" }, + { url = "https://files.pythonhosted.org/packages/67/d0/a5fcd3515f0bae49a7b6d0413cc1bdccdcc1fc0047037a0d480642cdc5d6/cryptography-49.0.0-cp39-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:6fc361c34fb6aac015ce19435876635e5c6d21db31998b0920f675f131e043b8", size = 5254338, upload-time = "2026-06-12T20:02:22.737Z" }, + { url = "https://files.pythonhosted.org/packages/a0/84/84fe36f19caf857d61cb7fc9c63035a47ffabd84ea12d1d393148efa3615/cryptography-49.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:2400ef9c9e2299a25614eb1dea3db54a69b1349efd043bfac9c67630d136df36", size = 4735650, upload-time = "2026-06-12T20:02:41.389Z" }, + { url = "https://files.pythonhosted.org/packages/6c/a0/db537264e234f7273a73ec020873d6d6b39dfd8a53db78b550ca8320440e/cryptography-49.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:67e1d20ad9ef3a563c59ef22e7a8a0b8210bd26604369ea4a30a7c66aefe504e", size = 4834820, upload-time = "2026-06-12T20:01:51.847Z" }, + { url = "https://files.pythonhosted.org/packages/93/77/8df9eb486495979bccecd1062e2eaf435250e84437040295b57d09048b0b/cryptography-49.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:42b0684e0e40cf26122427802486f6d93aea593612603a94fbf260c7eb1e9c1b", size = 4967968, upload-time = "2026-06-12T20:02:12.524Z" }, + { url = "https://files.pythonhosted.org/packages/c2/e6/f60198ea8d9dfa15fff9ed4ca02ce362f6eadd9ba757dcc50634c4257b63/cryptography-49.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:026ac7423e6fa66872d3bf889be5974507da3944f866f704fa200eadacd00001", size = 3785547, upload-time = "2026-06-12T20:02:26.847Z" }, ] [[package]] @@ -1618,7 +1619,7 @@ vdb-xinference = [ requires-dist = [ { name = "aliyun-log-python-sdk", specifier = "==0.9.44" }, { name = "azure-identity", specifier = ">=1.25.3,<2.0.0" }, - { name = "bleach", specifier = ">=6.3.0,<7.0.0" }, + { name = "bleach", specifier = ">=6.4.0,<7.0.0" }, { name = "boto3", specifier = ">=1.43.24,<2.0.0" }, { name = "celery", specifier = ">=5.6.3,<6.0.0" }, { name = "croniter", specifier = ">=6.2.2,<7.0.0" }, @@ -1636,7 +1637,7 @@ requires-dist = [ { name = "gmpy2", specifier = ">=2.3.0,<3.0.0" }, { name = "google-api-python-client", specifier = ">=2.196.0,<3.0.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.151.0,<2.0.0" }, - { name = "graphon", specifier = "==0.5.2" }, + { name = "graphon", specifier = "==0.5.3" }, { name = "gunicorn", specifier = ">=26.0.0,<27.0.0" }, { name = "httpx", extras = ["socks"], specifier = "==0.28.1" }, { name = "httpx-sse", specifier = "==0.4.3" }, @@ -1869,7 +1870,7 @@ dependencies = [ ] [package.metadata] -requires-dist = [{ name = "langsmith", specifier = "==0.8.5" }] +requires-dist = [{ name = "langsmith", specifier = "==0.8.18" }] [[package]] name = "dify-trace-mlflow" @@ -2987,7 +2988,7 @@ httpx = [ [[package]] name = "graphon" -version = "0.5.2" +version = "0.5.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "charset-normalizer" }, @@ -3008,9 +3009,9 @@ dependencies = [ { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, { name = "webvtt-py" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c2/16/f183da187414c335be67f52f6a1b7c2a33bf0b1d5090eda7e6c92d42d94a/graphon-0.5.2.tar.gz", hash = "sha256:d66a9edcd883766bd50e94f84a691c92ce536ea60e721552089e83ac8e94bf68", size = 269773, upload-time = "2026-06-16T04:06:22.074Z" } +sdist = { url = "https://files.pythonhosted.org/packages/50/02/75c8cc2f946c8b6debe4f71a8a0f41a69cd499073368a8735ca424c6551f/graphon-0.5.3.tar.gz", hash = "sha256:eaa87d5e664acdf14c80e38afce6bc0f14644961de7ce7b059266fe61bc30e0b", size = 271204, upload-time = "2026-06-23T08:13:32.46Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/e6/36a3981cd44e7a40a7cd7d374e26f01e02dd49410c5fbbd7df248750d5fb/graphon-0.5.2-py3-none-any.whl", hash = "sha256:11f89399e67ed1ddd2ce1c336accd9c4ad5b8fe2741f9167e6085af0b325cd14", size = 381908, upload-time = "2026-06-16T04:06:20.453Z" }, + { url = "https://files.pythonhosted.org/packages/84/fb/616f8ecbd184af57dca8380877b149198d944f4a6658cceb353ae02ace92/graphon-0.5.3-py3-none-any.whl", hash = "sha256:a7f070d1e5eef13d25b97cce6d23675b228c1d38f3c656e3dcacaa6be9ccada4", size = 383359, upload-time = "2026-06-23T08:13:31.075Z" }, ] [[package]] @@ -3663,7 +3664,7 @@ wheels = [ [[package]] name = "langsmith" -version = "0.8.5" +version = "0.8.18" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -3673,12 +3674,13 @@ dependencies = [ { name = "requests" }, { name = "requests-toolbelt" }, { name = "uuid-utils" }, + { name = "websockets" }, { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/17/eb/8883d1158c743d0aac350f09df7880714d27283497e8c80bb9fe3480f165/langsmith-0.8.5.tar.gz", hash = "sha256:3615243d99c12f4047f13042bdc05a373dce232d106a6511b3ca7b48c5af1c2c", size = 4462348, upload-time = "2026-05-15T21:31:41.093Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/d9/a6681aa9847bbbc5ec21abe20a5e233b94e5edcfe39624db607ac7e8ccb4/langsmith-0.8.18.tar.gz", hash = "sha256:32dde9c0e67e053e0fb738921fc8ced768af7b8fa83d7a0e3fd63597cf8776dd", size = 4526988, upload-time = "2026-06-19T13:12:17.123Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/23/85/968c88a63e32a59b3e5c68afd2fe114ce0708a125db0be1a85efc25fb2ea/langsmith-0.8.5-py3-none-any.whl", hash = "sha256:efc779f9d450dcaf9d97bc8894f4926276509d6e730e05289af9a64debce06ae", size = 399564, upload-time = "2026-05-15T21:31:39.046Z" }, + { url = "https://files.pythonhosted.org/packages/03/70/0e0cc80a3b064c8d6c8d697c3125ed86e39d5a7393ec6dc8b07cb1cf13c4/langsmith-0.8.18-py3-none-any.whl", hash = "sha256:3940183349993faef48e6c7d08e4822ee9cefd906b362d0e3c2d650314d2f282", size = 508108, upload-time = "2026-06-19T13:12:15.348Z" }, ] [[package]] @@ -3713,7 +3715,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.83.14" +version = "1.89.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3729,9 +3731,9 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/7c/c095649380adc96c8630273c1768c2ad1e74aa2ee1dd8dd05d218a60569f/litellm-1.83.14.tar.gz", hash = "sha256:24aef9b47cdc424c833e32f3727f411741c690832cd1fe4405e0077144fe09c9", size = 14836599, upload-time = "2026-04-26T03:16:10.176Z" } +sdist = { url = "https://files.pythonhosted.org/packages/56/f1/f7cfead063f2ab1877c8fb465d0d7fe300b75f081bcb73525f6d550aeb1c/litellm-1.89.3.tar.gz", hash = "sha256:8fcdb2b7a0ef3381d41adf164443842e31ef9f0cd5bcda6fc3c0bd8bc2959510", size = 14080611, upload-time = "2026-06-20T22:42:26.997Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/5c/1b5691575420135e90578543b2bf219497caa33cfd0af64cb38f30288450/litellm-1.83.14-py3-none-any.whl", hash = "sha256:92b11ba2a32cf80707ddf388d18526696c7999a21b418c5e3b6eda1243d2cfdb", size = 16457054, upload-time = "2026-04-26T03:16:05.72Z" }, + { url = "https://files.pythonhosted.org/packages/d1/f1/34d174ff1d84e459b30f971606ac9cb7078ad24cd7661e9786b25adf7def/litellm-1.89.3-py3-none-any.whl", hash = "sha256:414ef5aee504b2b3eb1b219d39f1c11902db399cbdbc06e5fb550c15d731abeb", size = 15495226, upload-time = "2026-06-20T22:42:23.156Z" }, ] [[package]] @@ -5225,16 +5227,16 @@ wheels = [ [[package]] name = "pydantic-settings" -version = "2.13.1" +version = "2.14.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "python-dotenv" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/52/6d/fffca34caecc4a3f97bda81b2098da5e8ab7efc9a66e819074a11955d87e/pydantic_settings-2.13.1.tar.gz", hash = "sha256:b4c11847b15237fb0171e1462bf540e294affb9b86db4d9aa5c01730bdbe4025", size = 223826, upload-time = "2026-02-19T13:45:08.055Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/b5/8f48e906c3e0205276e8bd8cb7512217a87b2685304d64be27cad5b3019f/pydantic_settings-2.14.2.tar.gz", hash = "sha256:c19dd64b19097f1de80184f0cc7b0272a13ae6e170cbf240a3e27e381ed14a5f", size = 237700, upload-time = "2026-06-19T13:44:56.324Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, + { url = "https://files.pythonhosted.org/packages/77/c1/6e422f34e569cf8e18df68d1939c81c099d2b61e4f7d9621c8a77560799c/pydantic_settings-2.14.2-py3-none-any.whl", hash = "sha256:a20c97b37910b6550d5ea50fbcc2d4187defe58cd57070b73863d069419c9440", size = 61715, upload-time = "2026-06-19T13:44:55.02Z" }, ] [[package]] @@ -5353,11 +5355,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.10.2" +version = "6.14.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7b/3f/9f2167401c2e94833ca3b69535bad89e533b5de75fefe4197a2c224baec2/pypdf-6.10.2.tar.gz", hash = "sha256:7d09ce108eff6bf67465d461b6ef352dcb8d84f7a91befc02f904455c6eea11d", size = 5315679, upload-time = "2026-04-15T16:37:36.978Z" } +sdist = { url = "https://files.pythonhosted.org/packages/03/72/7dfd5ff1c9c37de97a731701f51af091325f123d9d4270361c9c69e4431f/pypdf-6.14.2.tar.gz", hash = "sha256:7873f502fe4385e79539b21d872392dc0c4e3714327c15881cbc7fbfd1f95b25", size = 6491182, upload-time = "2026-06-23T14:18:30.859Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/d6/1d5c60cc17bbdf37c1552d9c03862fc6d32c5836732a0415b2d637edc2d0/pypdf-6.10.2-py3-none-any.whl", hash = "sha256:aa53be9826655b51c96741e5d7983ca224d898ac0a77896e64636810517624aa", size = 336308, upload-time = "2026-04-15T16:37:34.851Z" }, + { url = "https://files.pythonhosted.org/packages/49/e6/136aa8993a2ae7214e0b0ef2edaa0d2e08d1d4e4982635b08a835ff31ec8/pypdf-6.14.2-py3-none-any.whl", hash = "sha256:3f07891af76dc002657e04993ab9b4de81de29f9013b9761d0b7968bff12e946", size = 349514, upload-time = "2026-06-23T14:18:28.867Z" }, ] [[package]] @@ -6237,15 +6239,15 @@ wheels = [ [[package]] name = "starlette" -version = "1.0.1" +version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/08/a3/84e821cc54b4ab50ae6dbc6ac3800a651b65ec35f045cc73785380654057/starlette-1.0.1.tar.gz", hash = "sha256:512399c5f1de7fac99c88572212ded9ddeddef2fb32afa82d724000e88b38f4f", size = 2659596, upload-time = "2026-05-21T21:58:58.433Z" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/e3/7c1dc7381d9f8ab7d854328ebfa884e62cb3f3d8549ddfd37c7814f42afa/starlette-1.3.1.tar.gz", hash = "sha256:05d0213193f2fbaae60e2ecb593b4add4262ad4e46536b54abe36f11a71724e0", size = 2703240, upload-time = "2026-06-12T09:23:11.602Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/e1/b2df4bc09a1e51ff664c1e17018a4274b42e5e9352e4a478ea540512dc88/starlette-1.0.1-py3-none-any.whl", hash = "sha256:7c0e69b2ee1c848bd54669d908500117a3ee13de603a21427e5c6fc1adf98dcd", size = 72802, upload-time = "2026-05-21T21:58:56.551Z" }, + { url = "https://files.pythonhosted.org/packages/ec/bb/2799cc2ede3ed41131f8975621e7213dfc7ef4acbbaadfa440f32500c370/starlette-1.3.1-py3-none-any.whl", hash = "sha256:c7372aae11c3c3f26a42df7bd626cec2f47d03483d261d369516a615a53714c6", size = 73632, upload-time = "2026-06-12T09:23:10.017Z" }, ] [[package]] @@ -7053,24 +7055,29 @@ wheels = [ [[package]] name = "ujson" -version = "5.12.1" +version = "5.13.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bc/78/937198ea8708182dd1edbf0237bf255a96feab3f511691ad08b84da98e5d/ujson-5.12.1.tar.gz", hash = "sha256:5b7e96406c301a1366534479a7352ec40ec68bb327c0c119091635acd5925e35", size = 7164538, upload-time = "2026-05-05T22:05:01.354Z" } +sdist = { url = "https://files.pythonhosted.org/packages/89/7a/c8bb37c8f6f3623d60c33d15d18cd6d6655d0f9c3eb31a9969f76361b199/ujson-5.13.0.tar.gz", hash = "sha256:d62e3d7625384c08082abad81a077af587fdef2761bb14c3822f4234b8d07d75", size = 7166784, upload-time = "2026-06-14T22:36:50.209Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d7/40/dbb8e2fe6ee33769602fba203dacaa3963b6599f0d0aefdf2b8811af5f70/ujson-5.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:10f44bd08ae52ee23ca6e8b472692e5da1768af2d53ff1bad6f40b532e0bc7ee", size = 57951, upload-time = "2026-05-05T22:03:31.606Z" }, - { url = "https://files.pythonhosted.org/packages/8d/db/627472e6b4ac34148ea52e6d3d15f6f366fc21c72fe7d6c7d3729d4b3ac5/ujson-5.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6cc6ea753b7303fa5629fa9ac9257ea4b001c4d72583b2bb36ff1855a07db49f", size = 55562, upload-time = "2026-05-05T22:03:32.853Z" }, - { url = "https://files.pythonhosted.org/packages/be/59/1248c966da197ae7d2673542444a2d9a1ff7c46e3ec2a302c3caf902b922/ujson-5.12.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:289f13095764d03734adfa10107da9b530ceb64dc1b02a5f507588d978d5b7df", size = 59448, upload-time = "2026-05-05T22:03:34.143Z" }, - { url = "https://files.pythonhosted.org/packages/d5/d7/60c1ca71a09c0654c3edca1192a18fc55e6cc06107be86d7d3f2b39fb29b/ujson-5.12.1-cp312-cp312-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:427893168d074e59214b0ee058337c57f5bb80175cdd5b4799a9c931aae22022", size = 61608, upload-time = "2026-05-05T22:03:35.386Z" }, - { url = "https://files.pythonhosted.org/packages/d5/0a/c619525576219bfc50084100117481b1a732a16716a3878355570995de4e/ujson-5.12.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a7a81724d5d90a2da7155d15d8b156ce57eaed7cdd622df813f36a8e612fd4c8", size = 59113, upload-time = "2026-05-05T22:03:37.555Z" }, - { url = "https://files.pythonhosted.org/packages/18/4d/79c1674036085e8dfdb77f8d87c1fd2896e97e6affd117c5e8ecc40f0ae4/ujson-5.12.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3a6efff7dc6515416366819de4a1bc449b77107c5b48508b101fd40f7f8bec08", size = 1038914, upload-time = "2026-05-05T22:03:38.954Z" }, - { url = "https://files.pythonhosted.org/packages/94/b1/9409bba17189ee282b6314cdf0ecdcc72e3d38cd565c870c0227d0494569/ujson-5.12.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:77a71fe53427a0cf49d56eafd801d9f7e203b784b7f99cc717783fd6f6f7b732", size = 1198408, upload-time = "2026-05-05T22:03:40.943Z" }, - { url = "https://files.pythonhosted.org/packages/4b/ad/fafbce7ac59f1a10a83892d0a34add23cc06492308e1330493aab707dc20/ujson-5.12.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ea3bed53d2ea8e5642e814a9e41f3e29420a8067874ba03ace8c0462e160490c", size = 1091451, upload-time = "2026-05-05T22:03:42.739Z" }, - { url = "https://files.pythonhosted.org/packages/5a/1f/76fc9d5b1dcb9eb73ed45fd56e5114391bd30808eb1cea7f8bc5c9a64324/ujson-5.12.1-cp312-cp312-win32.whl", hash = "sha256:758e5c8fbe4e6d483041e03b307b01fb5d2f2dd4452d4d4b927ab902e188939e", size = 41049, upload-time = "2026-05-05T22:03:44.341Z" }, - { url = "https://files.pythonhosted.org/packages/35/2a/7ce3b6fda10d05b79a245db03405734b521ba3da6c377f173b018dce6d4e/ujson-5.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:f6074d3d3267ba1914c624b6e1fa3d8152648ff36b0ab77ddf83b92db488c30d", size = 45330, upload-time = "2026-05-05T22:03:45.828Z" }, - { url = "https://files.pythonhosted.org/packages/d7/66/5a37bba7a2e2ab36ae467521c4511e6593ad74c869f62ec4ba6330f3f71e/ujson-5.12.1-cp312-cp312-win_arm64.whl", hash = "sha256:7642a41520ac1b2bc25ea282b66b8da522cc43424442e6fb5e039be4d4f96530", size = 39828, upload-time = "2026-05-05T22:03:47.123Z" }, - { url = "https://files.pythonhosted.org/packages/6d/26/c9d0479236b3f5690d6a8bb45f708aabc2c91ca80d275eba24b1e9e464ab/ujson-5.12.1-graalpy312-graalpy250_312_native-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b2c419bf42ae40963fc27f70c59e24e9a97f5cf168dbce2c572f3c0ce3595912", size = 56153, upload-time = "2026-05-05T22:04:40.326Z" }, - { url = "https://files.pythonhosted.org/packages/ee/c8/785f4e132500aff2f1fd2bd4a4b86fe396a5519f830a098358c90ebb92ee/ujson-5.12.1-graalpy312-graalpy250_312_native-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0be2b4f2f547b9f0f3d902640e410e5a2fc851576cbe033c88445a23e3e7aef1", size = 57352, upload-time = "2026-05-05T22:04:42.005Z" }, - { url = "https://files.pythonhosted.org/packages/8f/13/b688a905653871b10b4ff0403c2ff562c17a0bd50be0d44324f3c85ca48f/ujson-5.12.1-graalpy312-graalpy250_312_native-win_amd64.whl", hash = "sha256:4ea0c490c702c20495e97345acfcf0c2f3153e658ef537ff111929c48b89e10a", size = 45988, upload-time = "2026-05-05T22:04:43.36Z" }, + { url = "https://files.pythonhosted.org/packages/ee/ae/b66deca15da1f7faf6952d8eddf55978482bcbfd294ed2afe2c526ea325f/ujson-5.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bf81570ac056cb058f9117b52ca5dd800bfe9381d0076d0bb30a08a54591d654", size = 56743, upload-time = "2026-06-14T22:35:28.863Z" }, + { url = "https://files.pythonhosted.org/packages/88/4f/b03bcc9eaf4621ac9008dec90918d8fb4839d611666cb99eb255696c67fe/ujson-5.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7edf16359c52ed53406e216565d83e6b98c23c3cb9a0a01673f2493f8fb15edf", size = 54390, upload-time = "2026-06-14T22:35:29.857Z" }, + { url = "https://files.pythonhosted.org/packages/77/79/f98c6c1a4ed9d92d39d5d2d133f2b6fce5da11ea50c341117aedde8011c4/ujson-5.13.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:24539618fb3243cfdf27dab9a850acab80798a01501e9586b61fb9ecd016a891", size = 60047, upload-time = "2026-06-14T22:35:30.857Z" }, + { url = "https://files.pythonhosted.org/packages/bc/1d/f68e14cf476d149945211142f4c20782c1f232c489e8edcc4f4b58ce4997/ujson-5.13.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fdde6341d213b29f413b5fa9fad1392d5408074c75f0900ed949e97e546fa5df", size = 53437, upload-time = "2026-06-14T22:35:31.835Z" }, + { url = "https://files.pythonhosted.org/packages/7d/1a/5718237cf4141e5be46ff371387e90b01f27774cb6f0f79ff4803a2430ca/ujson-5.13.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:229faf041ef249ee3fd57bac1cedb123d2718ab63f6ccd50eca95ea902eb0dca", size = 55057, upload-time = "2026-06-14T22:35:32.897Z" }, + { url = "https://files.pythonhosted.org/packages/f6/6f/7f55c1e9e0be87beebaed553fa186ad5f6d5d639cbaa9d49f78f2f91c3a9/ujson-5.13.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1d02f31c2f59cc6a1c2c3633b377701fc2d8e876cc01950735d7a01132ccc233", size = 58186, upload-time = "2026-06-14T22:35:34.055Z" }, + { url = "https://files.pythonhosted.org/packages/bc/c4/9a34ade542426f56a0bc042f774073d1c247ae7575363c27587788cb2b2f/ujson-5.13.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ea7204e9fa7538bfbb1396e1ee8c2bbcd3818b3633ef5bb14d4fdea52994d14d", size = 57935, upload-time = "2026-06-14T22:35:35.05Z" }, + { url = "https://files.pythonhosted.org/packages/36/06/407633f0709e168107f56368bd5a0fa8fe07acd7f1d3000710bc0bb07470/ujson-5.13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7c5a2478a3a1fa4421f7e035b87194eea0cf44c7971a3f32ad1b42a0dfd63c03", size = 1037685, upload-time = "2026-06-14T22:35:36.022Z" }, + { url = "https://files.pythonhosted.org/packages/c3/df/eb5bd92dc1b23254fea5b2022007baff5491a7478bfcf7e9260d3a10f1ac/ujson-5.13.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b535e0970c96957e999cfe5ec89361f0e8d0bb987fb5d5144f6f495cb3ed9e19", size = 1197141, upload-time = "2026-06-14T22:35:37.38Z" }, + { url = "https://files.pythonhosted.org/packages/cb/1c/65f2ce1a0411ec9a87339db01f0d5d554a49c4248ec68ab52a1b7e14e9c4/ujson-5.13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3d0ad1207694988498fca7e0bb28eba7564fa33261d2f9fdf66a3aaab376b803", size = 1090225, upload-time = "2026-06-14T22:35:38.95Z" }, + { url = "https://files.pythonhosted.org/packages/73/53/310aabff0704f9c7ef0d3f431ce8b8e3147c3cca25334a205615c511f65e/ujson-5.13.0-cp312-cp312-win32.whl", hash = "sha256:d6bc9fa43a49e403c68c7eb164eef0feee9dd29474a7c6e0d3b6267025371990", size = 40075, upload-time = "2026-06-14T22:35:40.44Z" }, + { url = "https://files.pythonhosted.org/packages/b5/23/d3536d8945d1bb00248d998c8dcbe678a884681ad181072daecfafe4eea6/ujson-5.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:6692d49ff970aaa5008f4a6fe06974bc91fd957bf13173f765e46d8ba44906ea", size = 41097, upload-time = "2026-06-14T22:35:41.39Z" }, + { url = "https://files.pythonhosted.org/packages/72/a1/4b147c06ee5bb14bec6e26786358c8510c4d75e28b88146a6ac7620f1f71/ujson-5.13.0-cp312-cp312-win_arm64.whl", hash = "sha256:5737ffe0740a788b0e6255f0ffb281db49305fd6e6a587be44c73d9e92b554c4", size = 38875, upload-time = "2026-06-14T22:35:42.357Z" }, + { url = "https://files.pythonhosted.org/packages/30/70/dbdd277d64bd3a149532567ceb082fe26f4ead58c39e0a97566ccbdf14a3/ujson-5.13.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:3e074a1f7778d58aa3b3056bab7b6251aabb3f381808018ca2b7fb8dbdeef7ab", size = 58393, upload-time = "2026-06-14T22:36:28.702Z" }, + { url = "https://files.pythonhosted.org/packages/71/48/592c70af94a67cafacd9c840ae2980f27d511dde2732a4c0dfac8f176ae8/ujson-5.13.0-graalpy312-graalpy250_312_native-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bb53ef95d35875262b8d0aa28506ca612ddd07058bee2a90f609938e69dc801", size = 54447, upload-time = "2026-06-14T22:36:29.802Z" }, + { url = "https://files.pythonhosted.org/packages/f2/9d/2bb91e1e25a8584cb3b63544b9bd26f621173535c77ac6cae13bad8e7904/ujson-5.13.0-graalpy312-graalpy250_312_native-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb296a0aa480ab88d895ddaa90372604c08ccc72323f02590612c775426ab413", size = 56066, upload-time = "2026-06-14T22:36:30.806Z" }, + { url = "https://files.pythonhosted.org/packages/76/ec/8e3802fc4a4e31e817b972bbb0e704a484d8c75ec349b3feb45fa9fb54c4/ujson-5.13.0-graalpy312-graalpy250_312_native-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2862f81af44b3a7e74c5d80caa118d736be1991ce6f1d5c723716fa403060cc6", size = 54938, upload-time = "2026-06-14T22:36:32.051Z" }, + { url = "https://files.pythonhosted.org/packages/4a/48/d0e3e511039b86fd1ecfe2bf761c800552d273ef8f19e71de93bf38a909e/ujson-5.13.0-graalpy312-graalpy250_312_native-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c16e07581172f08585b409246f4535dab13ee85af0e3d3cfa8684b653ca44fa8", size = 56115, upload-time = "2026-06-14T22:36:33.349Z" }, + { url = "https://files.pythonhosted.org/packages/81/b5/689613037fe691d18eae075cd141089f3a3156146be14512df92d8a9ae8f/ujson-5.13.0-graalpy312-graalpy250_312_native-win_amd64.whl", hash = "sha256:9bd0f2dd05937c3b089af316884de18c6f6182ddb8ffce597d2e7c7a9ba9f447", size = 41802, upload-time = "2026-06-14T22:36:34.523Z" }, ] [[package]] diff --git a/cli/src/api/app-meta.test.ts b/cli/src/api/app-meta.test.ts index ab7ab8a6ae1..a8ef958ce2f 100644 --- a/cli/src/api/app-meta.test.ts +++ b/cli/src/api/app-meta.test.ts @@ -1,13 +1,14 @@ import type { DifyMock } from '@test/fixtures/dify-mock/server' -import { mkdtemp, rm } from 'node:fs/promises' +import { mkdtemp, readFile, rm, writeFile } from 'node:fs/promises' import { tmpdir } from 'node:os' import { join } from 'node:path' import { startMock } from '@test/fixtures/dify-mock/server' import { testHttpClient } from '@test/fixtures/http-client' +import yaml from 'js-yaml' import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' import { loadAppInfoCache } from '@/cache/app-info' import { ENV_CACHE_DIR } from '@/store/dir' -import { CACHE_APP_INFO, getCache } from '@/store/manager' +import { CACHE_APP_INFO, cachePath, getCache } from '@/store/manager' import { FieldInfo, FieldParameters } from '@/types/app-meta' import { AppMetaClient } from './app-meta.js' import { AppsClient } from './apps.js' @@ -97,4 +98,40 @@ describe('AppMetaClient', () => { await client.get('app-1', [FieldInfo]) expect(spy).toHaveBeenCalledTimes(2) }) + + it('corrupt cache entry refetches; valid sibling stays cached; no throw', async () => { + const path = cachePath(dir, CACHE_APP_INFO) + const apps = new AppsClient(testHttpClient(mock.url, 'dfoa_test')) + + // Seed a real, production-serialized entry by fetching app-1 once (this + // calls cache.set → serialize, so we never hand-write the on-disk shape). + const seed = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + await new AppMetaClient({ apps, host: mock.url, cache: seed }).get('app-1', [FieldInfo]) + + // Reuse that serialized entry as a valid sibling; corrupt the app-1 slot. + const file = yaml.load(await readFile(path, 'utf8')) as { entries: Record } + const validEntry = file.entries[`${mock.url}::app-1`] + await writeFile( + path, + yaml.dump({ entries: { + [`${mock.url}::app-1`]: 'corrupted-string', + [`${mock.url}::sibling`]: validEntry, + } }), + 'utf8', + ) + + // Reload: app-1 dropped, sibling kept. + const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + const spy = vi.spyOn(apps, 'describe') + const client = new AppMetaClient({ apps, host: mock.url, cache }) + + // app-1 corrupt → dropped → miss → refetched from the mock + const a = await client.get('app-1', [FieldInfo]) + expect(a.info?.id).toBe('app-1') + expect(spy).toHaveBeenCalledTimes(1) + + // sibling is the real serialized entry → served from cache, no network + await client.get('sibling', [FieldInfo]) + expect(spy).toHaveBeenCalledTimes(1) + }) }) diff --git a/cli/src/cache/app-info.test.ts b/cli/src/cache/app-info.test.ts index ae6f15f249d..59d8e762020 100644 --- a/cli/src/cache/app-info.test.ts +++ b/cli/src/cache/app-info.test.ts @@ -1,5 +1,5 @@ import type { AppMeta } from '@/types/app-meta' -import { mkdtemp, readFile, rm } from 'node:fs/promises' +import { mkdtemp, readFile, rm, writeFile } from 'node:fs/promises' import { tmpdir } from 'node:os' import { join } from 'node:path' import yaml from 'js-yaml' @@ -14,10 +14,10 @@ function appInfoPath(dir: string): string { return cachePath(dir, CACHE_APP_INFO) } -function metaInfoOnly(): AppMeta { +function metaInfoOnly(id = 'app-1'): AppMeta { return { info: { - id: 'app-1', + id, name: 'Greeter', description: '', mode: 'chat', @@ -101,12 +101,32 @@ describe('app-info disk cache', () => { }) it('corrupt cache file is treated as empty', async () => { - const { writeFile } = await import('node:fs/promises') await writeFile(appInfoPath(dir), ': : not valid yaml', 'utf8') const c = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) expect(c.get('h', 'app-1')).toBeUndefined() }) + it('drops a corrupt single entry but keeps valid siblings', async () => { + // Seed a real serialized entry via set() — no hand-authored on-disk shape. + const seed = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + await seed.set('h', 'app-2', metaInfoOnly('app-2')) + + // Inject a corrupt sibling alongside the real one. + const file = yaml.load(await readFile(appInfoPath(dir), 'utf8')) as { entries: Record } + file.entries['h::app-1'] = 'corrupted-string-not-object' + await writeFile(appInfoPath(dir), yaml.dump(file), 'utf8') + + const c = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + expect(c.get('h', 'app-1')).toBeUndefined() + expect(c.get('h', 'app-2')?.meta.info?.id).toBe('app-2') + }) + + it('treats a non-object entries map as empty', async () => { + await writeFile(appInfoPath(dir), yaml.dump({ entries: 'not-an-object' }), 'utf8') + const c = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + expect(c.get('h', 'app-1')).toBeUndefined() + }) + it('updates same key in place (no growth)', async () => { const c = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) await c.set('h', 'app-1', metaInfoOnly()) diff --git a/cli/src/cache/app-info.ts b/cli/src/cache/app-info.ts index 052de019fee..266a880efce 100644 --- a/cli/src/cache/app-info.ts +++ b/cli/src/cache/app-info.ts @@ -74,8 +74,18 @@ async function readEntries(store: Store): Promise...
blocks silently by default; with --think, thinking is printed to stderr.', default: false }), + 'think': Flags.boolean({ description: 'Show model thinking/reasoning when available — both inline ... blocks and separated reasoning streams. Hidden by default; with --think, thinking is printed to stderr.', default: false }), 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.TEXT], default: '' }), 'http-retry': httpRetryFlag, } diff --git a/cli/src/commands/run/app/_strategies/streaming-structured.ts b/cli/src/commands/run/app/_strategies/streaming-structured.ts index 3ed602a3410..de39bfda2e5 100644 --- a/cli/src/commands/run/app/_strategies/streaming-structured.ts +++ b/cli/src/commands/run/app/_strategies/streaming-structured.ts @@ -7,6 +7,7 @@ import { collect, HitlPauseError } from '@/commands/run/app/sse-collector' import { formatted, stringifyOutput } from '@/framework/output' import { handle, unhandle } from '@/sys/index' import { colorEnabled, colorScheme } from '@/sys/io/color' +import { reasoningBlocksFromMetadata } from '@/sys/io/reasoning' import { startSpinner } from '@/sys/io/spinner' import { extractThinkBlocks, filterThinkInOutputs, stripThinkBlocks } from '@/sys/io/think-filter' @@ -99,6 +100,13 @@ export class StreamingStructuredStrategy implements RunStrategy { } } + // Surface separated-mode reasoning (carried in message_end metadata) to stderr under --think. + if (ctx.think) { + const reasoningBlocks = reasoningBlocksFromMetadata(processedResp.metadata) + if (reasoningBlocks !== '') + deps.io.err.write(`${reasoningBlocks}\n`) + } + const respMode = typeof processedResp.mode === 'string' && processedResp.mode !== '' ? processedResp.mode : mode deps.io.out.write(stringifyOutput(formatted({ format, data: newAppRunObject(respMode, processedResp) }))) if (isText && CHAT_MODES.has(respMode)) { diff --git a/cli/src/commands/run/app/index.ts b/cli/src/commands/run/app/index.ts index 815d708d9d8..6b63e17e81e 100644 --- a/cli/src/commands/run/app/index.ts +++ b/cli/src/commands/run/app/index.ts @@ -35,7 +35,7 @@ export default class RunApp extends DifyCommand { 'workflow-id': Flags.string({ description: 'Pin to a specific published workflow version' }), 'workspace': Flags.string({ description: 'Workspace id (overrides DIFY_WORKSPACE_ID and stored default)' }), 'stream': Flags.boolean({ description: 'Print output live as tokens/events arrive (default: collect and print at end)', default: false }), - 'think': Flags.boolean({ description: 'Show model thinking/reasoning when available. Strips ... blocks silently by default; with --think, thinking is printed to stderr.', default: false }), + 'think': Flags.boolean({ description: 'Show model thinking/reasoning when available — both inline ... blocks and separated reasoning streams. Hidden by default; with --think, thinking is printed to stderr.', default: false }), 'retry-on-limit': Flags.boolean({ description: 'On a 429 rate limit, wait and retry this POST (bounded) instead of failing immediately. Off by default since running an app is not idempotent.', default: false }), 'http-retry': httpRetryFlag, 'output': Flags.outputFormat({ options: [OutputFormat.JSON, OutputFormat.YAML, OutputFormat.TEXT], default: '' }), diff --git a/cli/src/commands/run/app/run.test.ts b/cli/src/commands/run/app/run.test.ts index 57b02aeb47d..a4e201c8ee1 100644 --- a/cli/src/commands/run/app/run.test.ts +++ b/cli/src/commands/run/app/run.test.ts @@ -203,6 +203,85 @@ describe('runApp', () => { expect(io.errBuf()).toContain('secret reasoning') }) + it('--stream chat --think: routes separated reasoning to stderr, clean answer to stdout', async () => { + mock.setScenario('chat-reasoning') + const io = bufferStreams() + const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + await runApp( + { appId: 'app-1', message: 'hi', stream: true, think: true }, + { active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache }, + ) + expect(io.outBuf()).toContain('final answer') + expect(io.outBuf()).not.toContain('secret reasoning') + expect(io.errBuf()).toContain('') + expect(io.errBuf()).toContain('secret reasoning') + }) + + it('--stream chat without --think: separated reasoning stays hidden', async () => { + mock.setScenario('chat-reasoning') + const io = bufferStreams() + const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + await runApp( + { appId: 'app-1', message: 'hi', stream: true }, + { active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache }, + ) + expect(io.outBuf()).toContain('final answer') + expect(io.errBuf()).not.toContain('secret reasoning') + }) + + it('chat -o json --think: echoes separated reasoning to stderr, persists it in metadata', async () => { + mock.setScenario('chat-reasoning') + const io = bufferStreams() + const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + await runApp( + { appId: 'app-1', message: 'hi', format: 'json', think: true }, + { active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache }, + ) + expect(io.errBuf()).toContain('secret reasoning') + const parsed = JSON.parse(io.outBuf()) as { answer: string, metadata: { reasoning: Record } } + expect(parsed.answer).toBe('final answer') + expect(parsed.metadata.reasoning).toEqual({ 'llm-1': 'secret reasoning' }) + }) + + it('--stream workflow --think: routes separated reasoning to stderr, clean outputs to stdout', async () => { + mock.setScenario('workflow-reasoning') + const io = bufferStreams() + const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + await runApp( + { appId: 'app-2', inputs: { x: '1' }, stream: true, think: true }, + { active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache }, + ) + expect(io.errBuf()).toContain('') + expect(io.errBuf()).toContain('secret reasoning') + expect(io.outBuf()).toContain('final answer') + expect(io.outBuf()).not.toContain('secret reasoning') + }) + + it('workflow -o json --think: echoes reasoning to stderr, accumulates into metadata.reasoning', async () => { + mock.setScenario('workflow-reasoning') + const io = bufferStreams() + const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + await runApp( + { appId: 'app-2', inputs: { x: '1' }, format: 'json', think: true }, + { active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache }, + ) + expect(io.errBuf()).toContain('secret reasoning') + const parsed = JSON.parse(io.outBuf()) as { metadata: { reasoning: Record } } + expect(parsed.metadata.reasoning).toEqual({ 'llm-1': 'secret reasoning' }) + }) + + it('--stream workflow without --think: reasoning stays hidden', async () => { + mock.setScenario('workflow-reasoning') + const io = bufferStreams() + const cache = await loadAppInfoCache({ store: getCache(CACHE_APP_INFO) }) + await runApp( + { appId: 'app-2', inputs: { x: '1' }, stream: true }, + { active: active(), http: testHttpClient(mock.url, 'dfoa_test'), host: mock.url, io, cache }, + ) + expect(io.outBuf()).toContain('final answer') + expect(io.errBuf()).not.toContain('secret reasoning') + }) + it('stream-error scenario: error event surfaces typed BaseError', async () => { mock.setScenario('stream-error') const io = bufferStreams() diff --git a/cli/src/commands/run/app/sse-collector.test.ts b/cli/src/commands/run/app/sse-collector.test.ts index e92d3694d3f..d26abb5608b 100644 --- a/cli/src/commands/run/app/sse-collector.test.ts +++ b/cli/src/commands/run/app/sse-collector.test.ts @@ -59,6 +59,41 @@ describe('collect — chat', () => { }) }) +describe('collect — chat separated reasoning', () => { + function reasoningEvent(reasoning: string, isFinal: boolean) { + return ev('reasoning_chunk', { data: { message_id: 'm1', reasoning, node_id: 'llm-1', is_final: isFinal } }) + } + + it('backfills metadata.reasoning from live deltas when the server omits it', async () => { + const got = await collect(iterOf( + reasoningEvent('pon', false), + reasoningEvent('dering', true), + ev('message', { message_id: 'm1', answer: 'answer' }), + ev('message_end', { metadata: { usage: { tokens: 3 } } }), + ), 'advanced-chat') + expect(got.answer).toBe('answer') + expect((got.metadata as { reasoning?: unknown }).reasoning).toEqual({ 'llm-1': 'pondering' }) + expect((got.metadata as { usage?: unknown }).usage).toEqual({ tokens: 3 }) + }) + + it('keeps the server-persisted reasoning over live deltas', async () => { + const got = await collect(iterOf( + reasoningEvent('live', true), + ev('message', { answer: 'a' }), + ev('message_end', { metadata: { reasoning: { 'llm-1': 'persisted' } } }), + ), 'advanced-chat') + expect((got.metadata as { reasoning?: unknown }).reasoning).toEqual({ 'llm-1': 'persisted' }) + }) + + it('leaves metadata untouched when there is no reasoning at all', async () => { + const got = await collect(iterOf( + ev('message', { answer: 'a' }), + ev('message_end', { metadata: { usage: { tokens: 1 } } }), + ), 'advanced-chat') + expect((got.metadata as { reasoning?: unknown }).reasoning).toBeUndefined() + }) +}) + describe('collect — agent-chat', () => { it('captures agent_thoughts', async () => { const got = await collect(iterOf( @@ -97,6 +132,39 @@ describe('collect — workflow', () => { }) }) +describe('collect — workflow separated reasoning', () => { + function wfReasoning(reasoning: string, nodeId: string, isFinal: boolean) { + return ev('reasoning_chunk', { data: { reasoning, node_id: nodeId, is_final: isFinal } }) + } + + it('accumulates reasoning_chunk per node into metadata.reasoning', async () => { + const got = await collect(iterOf( + ev('node_started', { id: 'llm-1' }), + wfReasoning('pon', 'llm-1', false), + wfReasoning('dering', 'llm-1', true), + ev('workflow_finished', { data: { status: 'succeeded', outputs: { result: 'clean' } } }), + ), 'workflow') + expect((got.data as { outputs: { result: string } }).outputs.result).toBe('clean') + expect((got.metadata as { reasoning?: unknown }).reasoning).toEqual({ 'llm-1': 'pondering' }) + }) + + it('keys reasoning by node, leaves metadata absent when there is none', async () => { + const got = await collect(iterOf( + ev('workflow_finished', { data: { status: 'succeeded', outputs: { result: 'clean' } } }), + ), 'workflow') + expect((got.metadata as { reasoning?: unknown } | undefined)?.reasoning).toBeUndefined() + }) + + it('merges reasoning into metadata already carried by workflow_finished', async () => { + const got = await collect(iterOf( + wfReasoning('think', 'llm-1', true), + ev('workflow_finished', { data: { status: 'succeeded' }, metadata: { usage: { tokens: 7 } } }), + ), 'workflow') + expect((got.metadata as { reasoning?: unknown }).reasoning).toEqual({ 'llm-1': 'think' }) + expect((got.metadata as { usage?: unknown }).usage).toEqual({ tokens: 7 }) + }) +}) + describe('collect — error event', () => { it('throws BaseError when error event arrives', async () => { await expect(collect(iterOf( diff --git a/cli/src/commands/run/app/sse-collector.ts b/cli/src/commands/run/app/sse-collector.ts index 043a9690d5f..b3b3a33b06c 100644 --- a/cli/src/commands/run/app/sse-collector.ts +++ b/cli/src/commands/run/app/sse-collector.ts @@ -2,6 +2,7 @@ import type { BaseError } from '@/errors/base' import type { SseEvent } from '@/http/sse' import { HttpClientError, newError } from '@/errors/base' import { ErrorCode } from '@/errors/codes' +import { accumulateReasoning, parseReasoningChunk } from '@/sys/io/reasoning' import { RUN_MODES } from './handlers' export type HitlPauseData = { @@ -67,6 +68,7 @@ class ChatCollector implements Collector { private base: Record = {} private metadata: Record | undefined private thoughts: unknown[] = [] + private readonly reasoning: Record = {} private readonly mode: string private readonly isAgent: boolean constructor(mode: string, isAgent: boolean) { @@ -84,6 +86,13 @@ class ChatCollector implements Collector { copyScalar(this.base, c, ['id', 'conversation_id', 'message_id', 'task_id', 'created_at']) return } + // Accumulate separated-mode reasoning deltas per LLM node. + case 'reasoning_chunk': { + const chunk = parseReasoningChunk(c) + if (chunk !== undefined) + accumulateReasoning(this.reasoning, chunk) + return + } case 'agent_thought': this.thoughts.push(c) return @@ -98,12 +107,23 @@ class ChatCollector implements Collector { const out: Record = { mode: this.mode, answer: this.answer, ...this.base } if (this.metadata !== undefined) out.metadata = this.metadata + // Fall back to live deltas only when the server didn't persist reasoning in metadata. + if (Object.keys(this.reasoning).length > 0 && !hasReasoning(this.metadata)) + out.metadata = { ...(this.metadata ?? {}), reasoning: this.reasoning } if (this.isAgent || this.thoughts.length > 0) out.agent_thoughts = this.thoughts return out } } +function hasReasoning(metadata: Record | undefined): boolean { + const reasoning = metadata?.reasoning + return reasoning !== null + && typeof reasoning === 'object' + && !Array.isArray(reasoning) + && Object.keys(reasoning as object).length > 0 +} + class CompletionCollector implements Collector { private answer = '' private base: Record = {} @@ -133,14 +153,29 @@ class CompletionCollector implements Collector { class WorkflowCollector implements Collector { private final: Record | undefined + private readonly reasoning: Record = {} consume(ev: SseEvent): void { + if (ev.name === 'reasoning_chunk') { + const chunk = parseReasoningChunk(parseJson(ev.data)) + if (chunk !== undefined) + accumulateReasoning(this.reasoning, chunk) + return + } if (ev.name !== 'workflow_finished') return this.final = parseJson(ev.data) } finalize(): Record { - return { mode: RUN_MODES.Workflow, ...(this.final ?? {}) } + const out: Record = { mode: RUN_MODES.Workflow, ...(this.final ?? {}) } + // Workflow runs don't persist reasoning; surface live deltas under metadata.reasoning. + if (Object.keys(this.reasoning).length > 0) { + const existing = (out.metadata !== null && typeof out.metadata === 'object' && !Array.isArray(out.metadata)) + ? out.metadata as Record + : undefined + out.metadata = { ...(existing ?? {}), reasoning: this.reasoning } + } + return out } } diff --git a/cli/src/commands/run/app/stream-handlers.test.ts b/cli/src/commands/run/app/stream-handlers.test.ts index bab45144a01..ec9cd957bb5 100644 --- a/cli/src/commands/run/app/stream-handlers.test.ts +++ b/cli/src/commands/run/app/stream-handlers.test.ts @@ -37,6 +37,42 @@ describe('streamPrinterFor — chat', () => { }) }) +function reasoningEvent(reasoning: string, isFinal: boolean) { + return ev('reasoning_chunk', { data: { message_id: 'm1', reasoning, node_id: 'llm-1', is_final: isFinal } }) +} + +describe('streamPrinterFor — chat separated reasoning', () => { + it('think: true frames reasoning_chunk deltas to stderr, answer stays clean on stdout', () => { + const sp = streamPrinterFor('advanced-chat', true) + const cap = captures() + sp.onEvent(cap.out, cap.err, reasoningEvent('pon', false)) + sp.onEvent(cap.out, cap.err, reasoningEvent('dering', false)) + sp.onEvent(cap.out, cap.err, reasoningEvent('', true)) + sp.onEvent(cap.out, cap.err, ev('message', { conversation_id: 'c1', answer: 'final answer' })) + sp.onEnd(cap.out, cap.err) + expect(cap.errBuf()).toContain(' [llm-1]\npondering') + expect(cap.outBuf()).toBe('final answer\n') + }) + + it('think: false ignores reasoning_chunk entirely', () => { + const sp = streamPrinterFor('advanced-chat', false) + const cap = captures() + sp.onEvent(cap.out, cap.err, reasoningEvent('secret', true)) + sp.onEvent(cap.out, cap.err, ev('message', { answer: 'hi' })) + sp.onEnd(cap.out, cap.err) + expect(cap.errBuf()).not.toContain('secret') + expect(cap.outBuf()).toBe('hi\n') + }) + + it('closes an unterminated reasoning block on stream end', () => { + const sp = streamPrinterFor('advanced-chat', true) + const cap = captures() + sp.onEvent(cap.out, cap.err, reasoningEvent('thinking', false)) + sp.onEnd(cap.out, cap.err) + expect(cap.errBuf()).toContain(' [llm-1]\nthinking') + }) +}) + describe('streamPrinterFor — agent-chat', () => { it('writes agent_thought to stderr', () => { const sp = streamPrinterFor('agent-chat') @@ -105,6 +141,62 @@ describe('streamPrinterFor — workflow think filtering', () => { }) }) +// Workflow reasoning_chunk events carry no message_id (workflow runs have no message). +function wfReasoning(reasoning: string, nodeId: string, isFinal: boolean) { + return ev('reasoning_chunk', { data: { reasoning, node_id: nodeId, is_final: isFinal } }) +} + +describe('streamPrinterFor — workflow separated reasoning', () => { + it('think: true frames reasoning_chunk to stderr, outputs stay clean on stdout', () => { + const sp = streamPrinterFor('workflow', true) + const cap = captures() + sp.onEvent(cap.out, cap.err, ev('node_started', { id: 'llm-1', title: 'LLM' })) + sp.onEvent(cap.out, cap.err, wfReasoning('pon', 'llm-1', false)) + sp.onEvent(cap.out, cap.err, wfReasoning('dering', 'llm-1', false)) + sp.onEvent(cap.out, cap.err, wfReasoning('', 'llm-1', true)) + sp.onEvent(cap.out, cap.err, ev('workflow_finished', { data: { outputs: { result: 'clean answer' } } })) + sp.onEnd(cap.out, cap.err) + // node title precedes the reasoning block for attribution + expect(cap.errBuf()).toContain('→ LLM') + expect(cap.errBuf()).toContain(' [llm-1]\npondering') + const parsed = JSON.parse(cap.outBuf().trim()) as { result: string } + expect(parsed.result).toBe('clean answer') + }) + + it('think: false drops reasoning_chunk entirely', () => { + const sp = streamPrinterFor('workflow', false) + const cap = captures() + sp.onEvent(cap.out, cap.err, wfReasoning('secret', 'llm-1', true)) + sp.onEvent(cap.out, cap.err, ev('workflow_finished', { data: { outputs: { result: 'ok' } } })) + sp.onEnd(cap.out, cap.err) + expect(cap.errBuf()).not.toContain('secret') + const parsed = JSON.parse(cap.outBuf().trim()) as { result: string } + expect(parsed.result).toBe('ok') + }) + + it('closes an unterminated reasoning block on stream end', () => { + const sp = streamPrinterFor('workflow', true) + const cap = captures() + sp.onEvent(cap.out, cap.err, wfReasoning('thinking', 'llm-1', false)) + sp.onEnd(cap.out, cap.err) + expect(cap.errBuf()).toContain(' [llm-1]\nthinking') + }) + + it('keeps interleaved parallel-node reasoning in separate node-tagged blocks', () => { + const sp = streamPrinterFor('workflow', true) + const cap = captures() + sp.onEvent(cap.out, cap.err, wfReasoning('a1', 'llm-1', false)) + sp.onEvent(cap.out, cap.err, wfReasoning('b1', 'llm-2', false)) + sp.onEvent(cap.out, cap.err, wfReasoning('a2', 'llm-1', true)) + sp.onEvent(cap.out, cap.err, wfReasoning('b2', 'llm-2', true)) + sp.onEvent(cap.out, cap.err, ev('workflow_finished', { data: { outputs: { result: 'ok' } } })) + sp.onEnd(cap.out, cap.err) + expect(cap.errBuf()).toBe( + ' [llm-1]\na1\n [llm-2]\nb1\n [llm-1]\na2\n [llm-2]\nb2\n', + ) + }) +}) + describe('streamPrinterFor — unknown mode', () => { it('throws', () => { expect(() => streamPrinterFor('whatever')).toThrow() diff --git a/cli/src/commands/run/app/stream-handlers.ts b/cli/src/commands/run/app/stream-handlers.ts index 1955d5996bf..7ce96cf53c6 100644 --- a/cli/src/commands/run/app/stream-handlers.ts +++ b/cli/src/commands/run/app/stream-handlers.ts @@ -4,6 +4,7 @@ import type { SseEvent } from '@/http/sse' import { newError } from '@/errors/base' import { ErrorCode } from '@/errors/codes' import { colorEnabled, colorScheme } from '@/sys/io/color' +import { parseReasoningChunk, ReasoningChunkRenderer } from '@/sys/io/reasoning' import { filterThinkInOutputs, ThinkChunkFilter } from '@/sys/io/think-filter' import { RUN_MODES } from './handlers' import { HitlPauseError } from './sse-collector' @@ -43,9 +44,12 @@ function handleCommonEvents(ev: SseEvent): boolean { class ChatStreamPrinter implements StreamPrinter { private convoId = '' private readonly filter: ThinkChunkFilter + private readonly reasoning = new ReasoningChunkRenderer() + private readonly think: boolean private readonly isTTY: boolean constructor(think: boolean, isTTY = false) { this.filter = new ThinkChunkFilter(think) + this.think = think this.isTTY = isTTY } @@ -62,6 +66,15 @@ class ChatStreamPrinter implements StreamPrinter { this.convoId = c.conversation_id return } + // Stream separated-mode reasoning to stderr under --think. + case 'reasoning_chunk': { + if (!this.think) + return + const chunk = parseReasoningChunk(c) + if (chunk !== undefined) + this.reasoning.push(chunk, errOut) + return + } case 'agent_thought': if (typeof c.thought === 'string' && c.thought !== '') errOut.write(`thought: ${c.thought}\n`) @@ -73,6 +86,7 @@ class ChatStreamPrinter implements StreamPrinter { } onEnd(out: NodeJS.WritableStream, errOut: NodeJS.WritableStream): void { + this.reasoning.flush(errOut) this.filter.flush(out, errOut) out.write('\n') if (this.convoId !== '') { @@ -106,6 +120,7 @@ class CompletionStreamPrinter implements StreamPrinter { class WorkflowStreamPrinter implements StreamPrinter { private final: Record | undefined + private readonly reasoning = new ReasoningChunkRenderer() private readonly think: boolean constructor(think: boolean) { this.think = think @@ -124,6 +139,15 @@ class WorkflowStreamPrinter implements StreamPrinter { errOut.write(`→ ${title}\n`) return } + // Stream separated-mode reasoning to stderr under --think; the prior → title attributes the node. + case 'reasoning_chunk': { + if (!this.think) + return + const chunk = parseReasoningChunk(c) + if (chunk !== undefined) + this.reasoning.push(chunk, errOut) + return + } case 'node_finished': { const status = typeof c.status === 'string' ? c.status : '' if (status !== '' && status !== 'succeeded') { @@ -138,6 +162,7 @@ class WorkflowStreamPrinter implements StreamPrinter { } onEnd(out: NodeJS.WritableStream, errOut: NodeJS.WritableStream): void { + this.reasoning.flush(errOut) if (this.final === undefined) return const data = this.final.data diff --git a/cli/src/framework/run.test.ts b/cli/src/framework/run.test.ts index c41913961f5..185072b9cbe 100644 --- a/cli/src/framework/run.test.ts +++ b/cli/src/framework/run.test.ts @@ -212,18 +212,30 @@ describe('run() catch routing', () => { expect(result.exit).toBe(ExitCode.Generic) }) - it('falls through to generic Error branch and exits 1', async () => { + it('routes non-BaseError to JSON envelope with -o json (exit 1)', async () => { + class Throwing extends Command { + async run(_argv: string[]) { + throw new Error('boom') + } + } + const result = await captureRun(makeTree(Throwing), ['cmd', '-o', 'json']) + expect(result.stderr).toBe(`${JSON.stringify({ error: { code: 'unknown', message: 'boom' } })}\n`) + expect(result.exit).toBe(ExitCode.Generic) + expect(result.stdout).toBe('') + }) + + it('wraps a generic Error into the human unknown form and exits 1', async () => { class Throwing extends Command { async run(_argv: string[]) { throw new Error('oops') } } const result = await captureRun(makeTree(Throwing), ['cmd']) - expect(result.stderr).toBe('oops\n') - expect(result.exit).toBe(1) + expect(result.stderr).toBe('unknown: oops\n') + expect(result.exit).toBe(ExitCode.Generic) }) - it('handles non-Error throw via String() coercion', async () => { + it('wraps a non-Error throw via String() coercion into unknown form', async () => { class Throwing extends Command { async run(_argv: string[]) { // eslint-disable-next-line no-throw-literal @@ -231,8 +243,52 @@ describe('run() catch routing', () => { } } const result = await captureRun(makeTree(Throwing), ['cmd']) - expect(result.stderr).toBe('plain string\n') - expect(result.exit).toBe(1) + expect(result.stderr).toBe('unknown: plain string\n') + expect(result.exit).toBe(ExitCode.Generic) + }) + + it('exits 0 on EPIPE without writing an error envelope', async () => { + class Throwing extends Command { + async run(_argv: string[]) { + throw Object.assign(new Error('broken pipe'), { code: 'EPIPE' }) + } + } + // process.exit is typed `never`; stub it to halt (throw) like the real call, + // so the EPIPE early-exit doesn't fall through to the envelope path. + let exitCode: number | undefined + let stderr = '' + const origExit = process.exit.bind(process) + const origStderr = process.stderr.write.bind(process.stderr) + process.exit = ((code?: number) => { + exitCode = code + throw new Error('__exit__') + }) as typeof process.exit + process.stderr.write = ((chunk: string | Uint8Array) => { + stderr += typeof chunk === 'string' ? chunk : new TextDecoder().decode(chunk) + return true + }) as typeof process.stderr.write + try { + await run(makeTree(Throwing), ['cmd', '-o', 'json']) + } + catch (e) { + expect((e as Error).message).toBe('__exit__') + } + finally { + process.exit = origExit + process.stderr.write = origStderr + } + expect(exitCode).toBe(0) + expect(stderr).toBe('') + }) + + it('preserves RateLimited semantic exit code through the collapsed catch', async () => { + class Throwing extends Command { + async run(_argv: string[]) { + throw newError(ErrorCode.RateLimited, 'slow down') + } + } + const result = await captureRun(makeTree(Throwing), ['cmd']) + expect(result.exit).toBe(ExitCode.RateLimited) }) it('does not call process.exit when command runs successfully', async () => { diff --git a/cli/src/framework/run.ts b/cli/src/framework/run.ts index 5682a9e1887..9bbd86aa5e2 100644 --- a/cli/src/framework/run.ts +++ b/cli/src/framework/run.ts @@ -1,5 +1,5 @@ import type { CommandTree } from './registry' -import { BaseError } from '@/errors/base' +import { BaseError, unknownError } from '@/errors/base' import { formatErrorForCli } from '@/errors/format' import { findTopic } from '@/help/topics' import { formatCommandList, formatHelp, formatTopic, formatTopLevelHelp } from './help' @@ -106,19 +106,12 @@ export async function run(tree: CommandTree, argv: string[]): Promise { catch (err) { if ((err as NodeJS.ErrnoException).code === 'EPIPE') process.exit(0) - if (err instanceof BaseError) { - const format = sniffOutputFormat(argv) - process.stderr.write(`${formatErrorForCli(err, { format, isErrTTY: process.stderr.isTTY })}\n`) - process.exit(err.exit()) - return - } - if (err instanceof Error) { - process.stderr.write(`${err.message}\n`) - process.exit(1) - return - } - process.stderr.write(`${String(err)}\n`) - process.exit(1) + const e = err instanceof BaseError + ? err + : unknownError(err instanceof Error ? err.message : String(err), err) + const format = sniffOutputFormat(argv) + process.stderr.write(`${formatErrorForCli(e, { format, isErrTTY: process.stderr.isTTY })}\n`) + process.exit(e.exit()) } } diff --git a/cli/src/sys/io/reasoning.test.ts b/cli/src/sys/io/reasoning.test.ts new file mode 100644 index 00000000000..1f6887a6108 --- /dev/null +++ b/cli/src/sys/io/reasoning.test.ts @@ -0,0 +1,128 @@ +import { Buffer } from 'node:buffer' +import { PassThrough } from 'node:stream' +import { describe, expect, it } from 'vitest' +import { + accumulateReasoning, + formatReasoningBlocks, + parseReasoningChunk, + reasoningBlocksFromMetadata, + ReasoningChunkRenderer, +} from './reasoning' + +function capture(): { err: PassThrough, errBuf: () => string } { + const err = new PassThrough() + const ec: Buffer[] = [] + err.on('data', d => ec.push(d as Buffer)) + return { err, errBuf: () => Buffer.concat(ec).toString('utf-8') } +} + +describe('parseReasoningChunk', () => { + it('reads the payload nested under data', () => { + expect(parseReasoningChunk({ data: { reasoning: 'hi', node_id: 'llm-1', is_final: true } })) + .toEqual({ reasoning: 'hi', nodeId: 'llm-1', isFinal: true }) + }) + + it('defaults missing/wrong-typed fields', () => { + expect(parseReasoningChunk({ data: {} })).toEqual({ reasoning: '', nodeId: '', isFinal: false }) + }) + + it('returns undefined when data is absent or not an object', () => { + expect(parseReasoningChunk({})).toBeUndefined() + expect(parseReasoningChunk({ data: null })).toBeUndefined() + expect(parseReasoningChunk({ data: ['x'] })).toBeUndefined() + }) +}) + +describe('ReasoningChunkRenderer', () => { + it('frames streamed deltas with a node-tagged open/close on the terminal marker', () => { + const cap = capture() + const r = new ReasoningChunkRenderer() + r.push({ reasoning: 'pon', nodeId: 'llm-1', isFinal: false }, cap.err) + r.push({ reasoning: 'dering', nodeId: 'llm-1', isFinal: false }, cap.err) + r.push({ reasoning: '', nodeId: 'llm-1', isFinal: true }, cap.err) + expect(cap.errBuf()).toBe(' [llm-1]\npondering\n') + }) + + it('emits separate node-tagged blocks per node', () => { + const cap = capture() + const r = new ReasoningChunkRenderer() + r.push({ reasoning: 'a', nodeId: 'n1', isFinal: true }, cap.err) + r.push({ reasoning: 'b', nodeId: 'n2', isFinal: true }, cap.err) + expect(cap.errBuf()).toBe(' [n1]\na\n [n2]\nb\n') + }) + + it('tags each block with its node id so interleaved fragments stay distinguishable', () => { + const cap = capture() + const r = new ReasoningChunkRenderer() + r.push({ reasoning: 'a1', nodeId: 'n1', isFinal: false }, cap.err) + r.push({ reasoning: 'b1', nodeId: 'n2', isFinal: false }, cap.err) + r.push({ reasoning: 'a2', nodeId: 'n1', isFinal: true }, cap.err) + r.push({ reasoning: 'b2', nodeId: 'n2', isFinal: true }, cap.err) + expect(cap.errBuf()).toBe( + ' [n1]\na1\n [n2]\nb1\n [n1]\na2\n [n2]\nb2\n', + ) + }) + + it('omits the tag when the chunk carries no node id', () => { + const cap = capture() + const r = new ReasoningChunkRenderer() + r.push({ reasoning: 'plain', nodeId: '', isFinal: true }, cap.err) + expect(cap.errBuf()).toBe('\nplain\n') + }) + + it('flush closes a block left open by a truncated stream', () => { + const cap = capture() + const r = new ReasoningChunkRenderer() + r.push({ reasoning: 'half', nodeId: 'n1', isFinal: false }, cap.err) + r.flush(cap.err) + expect(cap.errBuf()).toBe(' [n1]\nhalf\n') + }) + + it('a lone terminal marker with no reasoning emits nothing', () => { + const cap = capture() + const r = new ReasoningChunkRenderer() + r.push({ reasoning: '', nodeId: 'n1', isFinal: true }, cap.err) + expect(cap.errBuf()).toBe('') + }) +}) + +describe('accumulateReasoning', () => { + it('appends deltas per node, falling back to "_" for a missing nodeId', () => { + const acc: Record = {} + accumulateReasoning(acc, { reasoning: 'a', nodeId: 'n1', isFinal: false }) + accumulateReasoning(acc, { reasoning: 'b', nodeId: 'n1', isFinal: false }) + accumulateReasoning(acc, { reasoning: 'x', nodeId: '', isFinal: false }) + expect(acc).toEqual({ n1: 'ab', _: 'x' }) + }) + + it('ignores empty reasoning', () => { + const acc: Record = {} + accumulateReasoning(acc, { reasoning: '', nodeId: 'n1', isFinal: true }) + expect(acc).toEqual({}) + }) +}) + +describe('formatReasoningBlocks', () => { + it('frames and trims each node, joined by a separator', () => { + expect(formatReasoningBlocks({ n1: ' one ', n2: 'two' })) + .toBe('\none\n\n---\n\ntwo\n') + }) + + it('skips empty entries and returns empty for no reasoning', () => { + expect(formatReasoningBlocks({ n1: ' ' })).toBe('') + expect(formatReasoningBlocks({})).toBe('') + }) +}) + +describe('reasoningBlocksFromMetadata', () => { + it('extracts reasoning from a metadata object', () => { + expect(reasoningBlocksFromMetadata({ reasoning: { n1: 'why' } })) + .toBe('\nwhy\n') + }) + + it('returns empty for tagged mode (empty reasoning) and malformed input', () => { + expect(reasoningBlocksFromMetadata({ reasoning: {} })).toBe('') + expect(reasoningBlocksFromMetadata(undefined)).toBe('') + expect(reasoningBlocksFromMetadata({ usage: { tokens: 1 } })).toBe('') + }) +}) diff --git a/cli/src/sys/io/reasoning.ts b/cli/src/sys/io/reasoning.ts new file mode 100644 index 00000000000..2a88e3ae32b --- /dev/null +++ b/cli/src/sys/io/reasoning.ts @@ -0,0 +1,99 @@ +// Renders "separated"-mode reasoning (streamed on its own `reasoning_chunk` SSE +// channel) to stderr, so --think matches inline (see think-filter.ts). + +const THINK_OPEN = '' +const THINK_CLOSE = '' + +export type ReasoningChunk = { + reasoning: string + nodeId: string + isFinal: boolean +} + +// reasoning_chunk nests its payload under `data` (not top-level like `message`). +export function parseReasoningChunk(parsed: Record): ReasoningChunk | undefined { + const data = parsed.data + if (data === null || typeof data !== 'object' || Array.isArray(data)) + return undefined + const rec = data as Record + return { + reasoning: typeof rec.reasoning === 'string' ? rec.reasoning : '', + nodeId: typeof rec.node_id === 'string' ? rec.node_id : '', + isFinal: rec.is_final === true, + } +} + +// Bucket key for a chunk; falls back to a single bucket so live rendering and +// buffered collection key reasoning the same way. +export function reasoningKey(chunk: ReasoningChunk): string { + return chunk.nodeId !== '' ? chunk.nodeId : '_' +} + +// Appends a reasoning delta to a per-node accumulator. +export function accumulateReasoning(acc: Record, chunk: ReasoningChunk): void { + if (chunk.reasoning === '') + return + const key = reasoningKey(chunk) + acc[key] = (acc[key] ?? '') + chunk.reasoning +} + +// Frames a live reasoning stream into stderr: on the first delta, +// raw deltas thereafter, on is_final. Parallel branches can interleave +// chunks from different nodes on one stream, so it keeps at most one block open, +// switches blocks on node change, and tags each block with its node id so the +// interleaved fragments stay distinguishable. +export class ReasoningChunkRenderer { + private openNode: string | undefined + + push(chunk: ReasoningChunk, errOut: NodeJS.WritableStream): void { + const key = reasoningKey(chunk) + if (chunk.reasoning !== '') { + if (this.openNode !== key) { + this.closeActive(errOut) + errOut.write(chunk.nodeId !== '' ? `${THINK_OPEN} [${chunk.nodeId}]\n` : `${THINK_OPEN}\n`) + this.openNode = key + } + errOut.write(chunk.reasoning) + } + if (chunk.isFinal && this.openNode === key) + this.closeActive(errOut) + } + + // Close a block left open by a truncated stream. + flush(errOut: NodeJS.WritableStream): void { + this.closeActive(errOut) + } + + private closeActive(errOut: NodeJS.WritableStream): void { + if (this.openNode === undefined) + return + errOut.write(`${THINK_CLOSE}\n`) + this.openNode = undefined + } +} + +// Frames fully-buffered reasoning (one entry per LLM node id) into blocks. +export function formatReasoningBlocks(reasoning: Record): string { + const blocks: string[] = [] + for (const text of Object.values(reasoning)) { + const trimmed = text.trim() + if (trimmed !== '') + blocks.push(`${THINK_OPEN}\n${trimmed}\n${THINK_CLOSE}`) + } + return blocks.join('\n---\n') +} + +// Frames per-node reasoning from a message_end `metadata` object; '' when absent. +export function reasoningBlocksFromMetadata(metadata: unknown): string { + if (metadata === null || typeof metadata !== 'object' || Array.isArray(metadata)) + return '' + const reasoning = (metadata as Record).reasoning + if (reasoning === null || typeof reasoning !== 'object' || Array.isArray(reasoning)) + return '' + const map: Record = {} + for (const [key, value] of Object.entries(reasoning as Record)) { + if (typeof value === 'string') + map[key] = value + } + return formatReasoningBlocks(map) +} diff --git a/cli/test/e2e/.env.e2e.example b/cli/test/e2e/.env.e2e.example index 53fbedadf31..85f2219daeb 100644 --- a/cli/test/e2e/.env.e2e.example +++ b/cli/test/e2e/.env.e2e.example @@ -67,3 +67,15 @@ DIFY_E2E_PASSWORD= # DIFY_E2E_HITL_SINGLE_ACTION_APP_ID= # DIFY_E2E_HITL_MULTI_NODE_APP_ID= # DIFY_E2E_WS2_APP_ID= + +# ── Separated-mode reasoning suite (opt-in) ───────────────────────────────── +# run-app-reasoning.e2e.ts is skipped unless DIFY_E2E_REASONING_APP_ID resolves. +# It needs a chatflow whose LLM node uses reasoning_format=separated AND a +# workspace with a default chat model configured. +# +# Either point at an existing app: +# DIFY_E2E_REASONING_APP_ID= +# +# …or auto-provision reasoning-chat.yml (→ app name "reasoning-bot"). Off by +# default so the shared bootstrap stays free of any model dependency. +# DIFY_E2E_REASONING_PROVISION=1 diff --git a/cli/test/e2e/README.md b/cli/test/e2e/README.md index fd3e9182507..21635b0f6f9 100644 --- a/cli/test/e2e/README.md +++ b/cli/test/e2e/README.md @@ -39,11 +39,12 @@ test/e2e/ │ ├── describe-app.e2e.ts — describe app │ └── get-app-all-workspaces.e2e.ts — get app -A ([EE] multi-workspace cases) └── run/ - ├── run-app-basic.e2e.ts — basic run, -o json, --inputs, streaming, - │ conversation, CI mode - ├── run-app-streaming.e2e.ts — Ctrl+C / error-event / chunk timing - ├── run-app-file.e2e.ts — --file upload (local + remote URL) - └── run-app-hitl.e2e.ts — HITL pause + resume + ├── run-app-basic.e2e.ts — basic run, -o json, --inputs, streaming, + │ conversation, CI mode + ├── run-app-streaming.e2e.ts — Ctrl+C / error-event / chunk timing + ├── run-app-file.e2e.ts — --file upload (local + remote URL) + ├── run-app-reasoning.e2e.ts — separated-mode reasoning (--think); opt-in + └── run-app-hitl.e2e.ts — HITL pause + resume ``` ## Edition support @@ -137,6 +138,24 @@ global-setup will: | `DIFY_E2E_HITL_SINGLE_ACTION_APP_ID` | | | `DIFY_E2E_HITL_MULTI_NODE_APP_ID` | | | `DIFY_E2E_WS2_APP_ID` | Override secondary workspace app ID (EE) | +| `DIFY_E2E_REASONING_APP_ID` | separated-reasoning chatflow app ID (opt-in) | +| `DIFY_E2E_REASONING_PROVISION` | `1` → auto-provision `reasoning-chat.yml` | + +### Separated-mode reasoning suite (opt-in) + +`run-app-reasoning.e2e.ts` verifies the out-of-band `reasoning_chunk` channel +(PR #37460): `--think` surfaces the chain-of-thought to stderr framed as +``, the answer stays clean, and `-o json` persists it under +`metadata.reasoning`. It is **skipped** unless `DIFY_E2E_REASONING_APP_ID` +resolves, because it runs a real LLM node and needs: + +1. a chatflow whose LLM node uses `reasoning_format: separated`, and +1. a workspace with a default chat model configured. + +Point `DIFY_E2E_REASONING_APP_ID` at such an app, or set +`DIFY_E2E_REASONING_PROVISION=1` to import the `reasoning-chat.yml` fixture +(its system prompt forces a `` block, so any chat model triggers the +separated path — no dedicated reasoning model required). ## Running tests diff --git a/cli/test/e2e/fixtures/apps/reasoning-chat.yml b/cli/test/e2e/fixtures/apps/reasoning-chat.yml new file mode 100644 index 00000000000..ba24be57ef5 --- /dev/null +++ b/cli/test/e2e/fixtures/apps/reasoning-chat.yml @@ -0,0 +1,120 @@ +# Chatflow that exercises separated-mode reasoning (PR #37460): the LLM node sets +# reasoning_format=separated, so the server strips ... from the +# answer and streams the chain-of-thought on the out-of-band `reasoning_chunk` +# channel instead. The system prompt forces a block, so the separated +# path triggers with any chat model — no dedicated reasoning model required. +# +# NOTE: the LLM node leaves model.provider/name empty and relies on the target +# workspace's configured default chat model. The run-app-reasoning E2E suite is +# gated on DIFY_E2E_REASONING_APP_ID, so it is skipped unless a server with a +# working model is wired up. +app: + description: e2e-test reasoning (separated mode) + icon: 🧠 + icon_background: '#FFEAD5' + icon_type: emoji + mode: advanced-chat + name: reasoning-bot + use_icon_as_answer_icon: false +dependencies: [] +kind: app +version: 0.6.0 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: {} + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - id: start-llm + source: '1755189262236' + sourceHandle: source + target: llm + targetHandle: target + - id: llm-answer + source: llm + sourceHandle: source + target: answer + targetHandle: target + nodes: + - data: + desc: '' + title: Start + type: start + variables: [] + id: '1755189262236' + position: + x: 80 + y: 282 + sourcePosition: right + targetPosition: left + type: custom + - data: + context: + enabled: false + variable_selector: [] + desc: '' + memory: + query_prompt_template: '{{#sys.query#}}' + window: + enabled: false + size: 10 + model: + completion_params: + temperature: 0.7 + mode: chat + name: '' + provider: '' + prompt_template: + - role: system + text: >- + You are a helpful assistant. Always reason step by step INSIDE a + single ... block first, then write the final + answer AFTER the closing tag. The final answer must not + contain any tags. + reasoning_format: separated + selected: false + title: LLM + type: llm + variables: [] + vision: + enabled: false + id: llm + position: + x: 380 + y: 282 + sourcePosition: right + targetPosition: left + type: custom + - data: + answer: '{{#llm.text#}}' + desc: '' + title: Answer + type: answer + variables: [] + id: answer + position: + x: 680 + y: 282 + sourcePosition: right + targetPosition: left + type: custom + viewport: + x: 0 + y: 0 + zoom: 1 + rag_pipeline_variables: [] diff --git a/cli/test/e2e/setup/env.ts b/cli/test/e2e/setup/env.ts index 3f71b514426..33f8361d46b 100644 --- a/cli/test/e2e/setup/env.ts +++ b/cli/test/e2e/setup/env.ts @@ -37,6 +37,9 @@ * DIFY_E2E_HITL_EXTERNAL_APP_ID * DIFY_E2E_HITL_SINGLE_ACTION_APP_ID * DIFY_E2E_HITL_MULTI_NODE_APP_ID + * DIFY_E2E_REASONING_APP_ID Override separated-reasoning chatflow app ID + * DIFY_E2E_REASONING_PROVISION=1 Opt in to auto-provisioning reasoning-chat.yml + * (needs a workspace default chat model) */ /** Supported edition values. */ @@ -74,6 +77,12 @@ export type E2EEnv = { fileAppId: string /** Chat app (advanced-chat) with a file input variable */ fileChatAppId: string + /** + * Chatflow whose LLM node uses reasoning_format=separated. Empty unless + * DIFY_E2E_REASONING_APP_ID is set or the fixture is auto-provisioned; the + * run-app-reasoning suite is skipped when empty. + */ + reasoningAppId: string /** * Secondary workspace ID — EE only ("auto_test1"). * Empty in CE mode (CE has a single workspace). @@ -118,6 +127,7 @@ export type E2ECapabilities = { workflowAppId: string fileAppId: string fileChatAppId: string + reasoningAppId: string hitlAppId: string hitlExternalAppId: string hitlSingleActionAppId: string @@ -171,6 +181,7 @@ export function loadE2EEnv(): E2EEnv { hitlMultiNodeAppId: process.env.DIFY_E2E_HITL_MULTI_NODE_APP_ID ?? '', fileAppId: process.env.DIFY_E2E_FILE_APP_ID ?? '', fileChatAppId: process.env.DIFY_E2E_FILE_CHAT_APP_ID ?? '', + reasoningAppId: process.env.DIFY_E2E_REASONING_APP_ID ?? '', ws2Id: process.env.DIFY_E2E_WS2_ID ?? '', ws2AppId: process.env.DIFY_E2E_WS2_APP_ID ?? '', email: process.env.DIFY_E2E_EMAIL!, @@ -206,6 +217,7 @@ export function resolveEnv(caps: E2ECapabilities | undefined): E2EEnv { workflowAppId: caps.workflowAppId || env.workflowAppId, fileAppId: caps.fileAppId || env.fileAppId, fileChatAppId: caps.fileChatAppId || env.fileChatAppId, + reasoningAppId: caps.reasoningAppId || env.reasoningAppId, hitlAppId: caps.hitlAppId || env.hitlAppId, hitlExternalAppId: caps.hitlExternalAppId || env.hitlExternalAppId, hitlSingleActionAppId: caps.hitlSingleActionAppId || env.hitlSingleActionAppId, diff --git a/cli/test/e2e/setup/global-setup.ts b/cli/test/e2e/setup/global-setup.ts index 20b23295acb..92cc29fc1f2 100644 --- a/cli/test/e2e/setup/global-setup.ts +++ b/cli/test/e2e/setup/global-setup.ts @@ -182,6 +182,7 @@ export async function setup(project: TestProject): Promise { workflowAppId: '', fileAppId: '', fileChatAppId: '', + reasoningAppId: '', hitlAppId: '', hitlExternalAppId: '', hitlSingleActionAppId: '', @@ -288,6 +289,7 @@ export async function setup(project: TestProject): Promise { workflowAppId: provisionedIds.DIFY_E2E_WORKFLOW_APP_ID || E.workflowAppId, fileAppId: provisionedIds.DIFY_E2E_FILE_APP_ID || E.fileAppId, fileChatAppId: provisionedIds.DIFY_E2E_FILE_CHAT_APP_ID || E.fileChatAppId, + reasoningAppId: provisionedIds.DIFY_E2E_REASONING_APP_ID || E.reasoningAppId, hitlAppId: provisionedIds.DIFY_E2E_HITL_APP_ID || E.hitlAppId, hitlExternalAppId: provisionedIds.DIFY_E2E_HITL_EXTERNAL_APP_ID || E.hitlExternalAppId, hitlSingleActionAppId: provisionedIds.DIFY_E2E_HITL_SINGLE_ACTION_APP_ID || E.hitlSingleActionAppId, @@ -503,6 +505,12 @@ async function provisionApps( ['hitl-single-action.yml', 'DIFY_E2E_HITL_SINGLE_ACTION_APP_ID', primaryWsId], ['hitl-multi-node.yml', 'DIFY_E2E_HITL_MULTI_NODE_APP_ID', primaryWsId], ['file-chat.yml', 'DIFY_E2E_FILE_CHAT_APP_ID', primaryWsId], + // reasoning-chat.yml runs a real LLM node, so it is opt-in: provisioning it + // requires the workspace to have a default chat model configured. Off by + // default to keep the shared bootstrap free of any model dependency. + ...(process.env.DIFY_E2E_REASONING_PROVISION === '1' + ? [['reasoning-chat.yml', 'DIFY_E2E_REASONING_APP_ID', primaryWsId] as [string, string, string]] + : []), ...(edition === 'ee' ? [['ws2-workflow.yml', 'DIFY_E2E_WS2_APP_ID', secondaryWsId] as [string, string, string]] : []), diff --git a/cli/test/e2e/suites/run/run-app-reasoning.e2e.ts b/cli/test/e2e/suites/run/run-app-reasoning.e2e.ts new file mode 100644 index 00000000000..fd5f12e0b57 --- /dev/null +++ b/cli/test/e2e/suites/run/run-app-reasoning.e2e.ts @@ -0,0 +1,91 @@ +/** + * E2E: difyctl run app — separated-mode reasoning (PR #37460) + * + * Exercises the out-of-band `reasoning_chunk` SSE channel against a real server. + * Requires a chatflow whose LLM node uses reasoning_format=separated AND a + * workspace with a configured chat model. The whole suite is skipped unless + * DIFY_E2E_REASONING_APP_ID resolves (set it directly, or provision the + * reasoning-chat.yml fixture with DIFY_E2E_REASONING_PROVISION=1). + * + * Verifies the client adaptation: + * - --think surfaces the separated reasoning to stderr, framed as + * - the answer (stdout) stays free of + * - -o json persists the reasoning under metadata.reasoning + * - without --think, reasoning stays hidden + */ + +import type { AuthFixture } from '../../helpers/cli.js' +import { afterEach, beforeEach, describe, expect, inject } from 'vitest' +import { assertExitCode, assertJson, assertStderrContains } from '../../helpers/assert.js' +import { registerConversation } from '../../helpers/cleanup-registry.js' +import { withAuthFixture } from '../../helpers/cli.js' +import { withRetry } from '../../helpers/retry.js' +import { optionalIt } from '../../helpers/skip.js' +import { resolveEnv } from '../../setup/env.js' + +// @ts-expect-error — see test/e2e/helpers/vitest-context.ts for explanation +const caps = inject('e2eCapabilities') as import('../../setup/env.js').E2ECapabilities +const E = resolveEnv(caps) + +// Skipped unless a separated-reasoning chatflow is wired up (needs a real model). +const reasoningIt = optionalIt(Boolean(E.reasoningAppId)) + +const QUERY = 'In one short sentence, why is the sky blue?' + +describe('E2E / difyctl run app — separated reasoning', () => { + let fx: AuthFixture + + beforeEach(async () => { + fx = await withAuthFixture(E) + }) + afterEach(async () => { + await fx.cleanup() + }) + + reasoningIt('[P1] --think --stream surfaces reasoning on stderr, clean answer on stdout', async () => { + const result = await withRetry( + () => fx.r(['run', 'app', E.reasoningAppId, QUERY, '--think', '--stream']), + { attempts: 3, delayMs: 1000 }, + ) + + assertExitCode(result, 0) + expect(result.stdout.trim().length).toBeGreaterThan(0) + // Separated mode keeps the answer free of ; reasoning is framed on stderr. + expect(result.stdout).not.toContain('') + assertStderrContains(result, '') + }) + + reasoningIt('[P1] --think -o json persists reasoning under metadata.reasoning', async () => { + const result = await withRetry( + () => fx.r(['run', 'app', E.reasoningAppId, QUERY, '--think', '-o', 'json']), + { attempts: 3, delayMs: 1000 }, + ) + + assertExitCode(result, 0) + const parsed = assertJson<{ + conversation_id?: string + answer: string + metadata?: { reasoning?: Record } + }>(result) + + if (parsed.conversation_id) + registerConversation(E.host, E.token, E.reasoningAppId, parsed.conversation_id) + + const reasoning = parsed.metadata?.reasoning ?? {} + expect(Object.keys(reasoning).length).toBeGreaterThan(0) + expect(Object.values(reasoning).join('').length).toBeGreaterThan(0) + // --think also echoes the separated reasoning to stderr. + assertStderrContains(result, '') + }) + + reasoningIt('[P1] without --think, reasoning stays hidden', async () => { + const result = await withRetry( + () => fx.r(['run', 'app', E.reasoningAppId, QUERY, '--stream']), + { attempts: 3, delayMs: 1000 }, + ) + + assertExitCode(result, 0) + expect(result.stdout.trim().length).toBeGreaterThan(0) + expect(result.stderr).not.toContain('') + }) +}) diff --git a/cli/test/fixtures/dify-mock/scenarios.ts b/cli/test/fixtures/dify-mock/scenarios.ts index 221ccbb6b81..2e0b74d3d54 100644 --- a/cli/test/fixtures/dify-mock/scenarios.ts +++ b/cli/test/fixtures/dify-mock/scenarios.ts @@ -15,6 +15,8 @@ export type Scenario | 'server-version-unsupported' | 'run-422-stale' | 'workflow-think' + | 'chat-reasoning' + | 'workflow-reasoning' | 'import-pending' | 'import-failed' diff --git a/cli/test/fixtures/dify-mock/server.ts b/cli/test/fixtures/dify-mock/server.ts index 4b119286dbc..766963cd0d5 100644 --- a/cli/test/fixtures/dify-mock/server.ts +++ b/cli/test/fixtures/dify-mock/server.ts @@ -370,6 +370,32 @@ export function buildApp(getScenario: () => Scenario, state?: MockState): Hono { ]) return new Response(thinkSse, { status: 200, headers: { 'content-type': 'text/event-stream' } }) } + if (scenario === 'chat-reasoning') { + // Separated mode: reasoning streams out-of-band on `reasoning_chunk` (nested + // under `data`), the answer stays free of , and the terminal reasoning + // is persisted into message_end metadata. + const reasoningSse = sseChunks([ + { event: 'reasoning_chunk', data: { data: { message_id: 'msg-1', reasoning: 'secret reasoning', node_id: 'llm-1', is_final: false } } }, + { event: 'reasoning_chunk', data: { data: { message_id: 'msg-1', reasoning: '', node_id: 'llm-1', is_final: true } } }, + { event: 'message', data: { message_id: 'msg-1', conversation_id: 'conv-1', mode: app.mode, answer: 'final answer' } }, + { event: 'message_end', data: { message_id: 'msg-1', conversation_id: 'conv-1', task_id: 'task-1', metadata: { reasoning: { 'llm-1': 'secret reasoning' } } } }, + ]) + return new Response(reasoningSse, { status: 200, headers: { 'content-type': 'text/event-stream' } }) + } + if (scenario === 'workflow-reasoning') { + // Separated mode in a workflow: reasoning streams out-of-band on + // `reasoning_chunk` (no message_id), outputs stay clean, and there is NO + // persisted metadata — the live deltas are the only source. + const wfReasoningSse = sseChunks([ + { event: 'workflow_started', data: { id: 'wf-run-1', workflow_id: 'wf-1' } }, + { event: 'node_started', data: { id: 'llm-1', title: 'LLM' } }, + { event: 'reasoning_chunk', data: { data: { reasoning: 'secret reasoning', node_id: 'llm-1', is_final: false } } }, + { event: 'reasoning_chunk', data: { data: { reasoning: '', node_id: 'llm-1', is_final: true } } }, + { event: 'node_finished', data: { id: 'llm-1', status: 'succeeded' } }, + { event: 'workflow_finished', data: { id: 'wf-run-1', workflow_id: 'wf-1', data: { id: 'wf-run-1', status: 'succeeded', outputs: { result: 'final answer' } } } }, + ]) + return new Response(wfReasoningSse, { status: 200, headers: { 'content-type': 'text/event-stream' } }) + } const sse = streamingRunResponse(app.mode, query, isAgent) return new Response(sse, { status: 200, headers: { 'content-type': 'text/event-stream' } }) }) diff --git a/dify-agent/src/dify_agent/layers/shell/layer.py b/dify-agent/src/dify_agent/layers/shell/layer.py index 5db17d68499..a8f46d628a6 100644 --- a/dify-agent/src/dify_agent/layers/shell/layer.py +++ b/dify-agent/src/dify_agent/layers/shell/layer.py @@ -18,9 +18,11 @@ side-effecting ``on_context_resume`` attempt fails after issuing shellctl jobs, Agenton still exits ``resource_context()`` but never transitions the layer to ``ACTIVE``. In that failed-enter path, normal suspend/delete hooks do not run, so the enter hook itself must perform best-effort business compensation before -re-raising the failure. Agent Stub env injection uses shellctl's native per-run -``env`` argument for user-visible ``shell.run`` and for trusted server-owned -fixed scripts executed through ``run_remote_script()``. +re-raising the failure. Agent Soul shell env is injected into user-visible +commands and CLI bootstrap commands without persisting a workspace env file. +Agent Stub env injection uses shellctl's native per-run ``env`` argument for +user-visible ``shell.run`` and for trusted server-owned fixed scripts executed +through ``run_remote_script()``. """ from __future__ import annotations @@ -475,7 +477,7 @@ class DifyShellLayer(PydanticAILayer[DifyShellLayerDeps, object, DifyShellLayerC try: client = self._require_client() result = await client.run( - _wrap_user_script(script), + _wrap_user_script(script, self.config), cwd=self._require_workspace_cwd(), env=self._build_user_shell_run_env(), timeout=timeout, @@ -536,9 +538,9 @@ class DifyShellLayer(PydanticAILayer[DifyShellLayerDeps, object, DifyShellLayerC and optional Agent Stub env injection. Unlike model-visible ``shell.run``, this server-owned boundary does not - source ``.dify/env.sh``. That file is user-controlled shell config, so - sourcing it here would let sandbox code clobber trusted Agent Stub env - values before ``dify-agent file upload`` executes. + inject Agent Soul shell env. Keeping the user-controlled shell env out + of this path prevents sandbox code from clobbering trusted Agent Stub + env values before ``dify-agent file upload`` executes. """ env = None if inject_agent_stub_env: @@ -833,16 +835,18 @@ def _workspace_cwd(session_id: str) -> str: def _workspace_bootstrap_script(config: DifyShellLayerConfig) -> str: - """Return the workspace bootstrap script for env + CLI tool declarations.""" - has_bootstrap = bool(config.env or config.secret_refs or config.cli_tools or config.sandbox is not None) - if not has_bootstrap: + """Return the workspace bootstrap script for CLI tool declarations.""" + install_commands = [command for tool in config.cli_tools for command in tool.install_commands] + if not install_commands: return "" - lines: list[str] = [ - "set -eu", - 'mkdir -p ".dify"', - "cat > \".dify/env.sh\" <<'DIFY_ENV_EOF'", - ] + lines: list[str] = ["set -eu", *_shell_config_export_lines(config), *install_commands] + return "\n".join(lines) + + +def _shell_config_export_lines(config: DifyShellLayerConfig) -> list[str]: + """Return ephemeral Agent Soul shell exports for one shellctl command.""" + lines: list[str] = [] for env_var in config.env: lines.append(f"export {env_var.name}={_shquote(env_var.value)}") for secret_ref in config.secret_refs: @@ -860,32 +864,15 @@ def _workspace_bootstrap_script(config: DifyShellLayerConfig) -> str: if config.sandbox.config: sandbox_config = json.dumps(config.sandbox.config, ensure_ascii=True, sort_keys=True) lines.append(f"export DIFY_SANDBOX_CONFIG_JSON={_shquote(sandbox_config)}") - lines.extend( - [ - "DIFY_ENV_EOF", - 'chmod 600 ".dify/env.sh"', - '. ".dify/env.sh"', - ] - ) - for tool in config.cli_tools: - for command in tool.install_commands: - lines.append(command) - return "\n".join(lines) + return lines -def _wrap_user_script(script: str) -> str: - """Source Agent Soul env before executing a model-requested shell command.""" - # TODO: refactor - return "\n".join( - [ - 'if [ -f ".dify/env.sh" ]; then', - " set -a", - ' . ".dify/env.sh"', - " set +a", - "fi", - script, - ] - ) +def _wrap_user_script(script: str, config: DifyShellLayerConfig) -> str: + """Inject Agent Soul env before executing a model-requested shell command.""" + lines = _shell_config_export_lines(config) + if not lines: + return script + return "\n".join([*lines, script]) def _workspace_mkdir_script(*, session_id: str) -> str: diff --git a/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py b/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py index 30352d87c5e..c7d2599b63c 100644 --- a/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py +++ b/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py @@ -3,6 +3,7 @@ from collections.abc import Callable, Mapping import secrets import time from dataclasses import dataclass +from typing import cast import pytest @@ -454,7 +455,6 @@ def test_shell_layer_create_bootstraps_agent_soul_shell_config(monkeypatch: pyte assert 'export GITHUB_TOKEN="${GITHUB_TOKEN:-}"' in script assert "export DIFY_SANDBOX_PROVIDER='independent'" in script assert "export DIFY_SANDBOX_CONFIG_JSON='{\"cpu\": 2}'" in script - assert '. ".dify/env.sh"' in script assert "apt-get install -y ripgrep" in script return _job_result("bootstrap-job", status=JobStatusName.EXITED, done=True, exit_code=0) @@ -489,10 +489,60 @@ def test_shell_layer_create_bootstraps_agent_soul_shell_config(monkeypatch: pyte assert layer.runtime_state.job_ids == ["mkdir-job", "bootstrap-job"] +def test_shell_layer_injects_agent_soul_env_without_workspace_env_file(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(time, "time", lambda: 0xABC12) + + def token_hex(_nbytes: int) -> str: + return "ff" + + monkeypatch.setattr(secrets, "token_hex", token_hex) + + def run_handler(script: str, cwd: str | None, env: Mapping[str, str] | None, timeout: float) -> JobResult: + del timeout + assert env is None + if cwd is None: + return _job_result("mkdir-job", status=JobStatusName.EXITED, done=True, exit_code=0) + + assert cwd == "~/workspace/abc12ff" + assert "export PROJECT_NAME='demo project'" in script + assert 'export OPENAI_API_KEY="${OPENAI_API_KEY:-}"' in script + assert "export DIFY_SANDBOX_PROVIDER='independent'" in script + assert "export DIFY_SANDBOX_CONFIG_JSON='{\"cpu\": 2}'" in script + assert script.endswith("\npwd") + return _job_result("user-job", status=JobStatusName.EXITED, done=True, exit_code=0) + + client = FakeShellctlClient(run_handler=run_handler) + layer = _shell_layer( + client_factory=lambda _entrypoint: client, + config=DifyShellLayerConfig( + env=[DifyShellEnvVarConfig(name="PROJECT_NAME", value="demo project")], + secret_refs=[DifyShellSecretRefConfig(name="OPENAI_API_KEY", ref="secret-1")], + sandbox=DifyShellSandboxConfig(provider="independent", config={"cpu": 2}), + ), + ) + tools = {tool.name: tool for tool in layer.tools} + + async def scenario() -> None: + async with layer.resource_context(): + await layer.on_context_create() + run_result = cast( + Mapping[str, object], + await tools["shell_run"].function_schema.call( + {"script": "pwd"}, + None, # pyright: ignore[reportArgumentType] + ), + ) + assert run_result["job_id"] == "user-job" + + asyncio.run(scenario()) + + assert [call.cwd for call in client.run_calls] == [None, "~/workspace/abc12ff"] + assert layer.runtime_state.job_ids == ["mkdir-job", "user-job"] + + def test_shell_layer_tools_map_inputs_to_shellctl_calls_and_maintain_offsets() -> None: def run_handler(script: str, cwd: str | None, env: Mapping[str, str] | None, timeout: float) -> JobResult: - assert script.endswith("\npwd") - assert '. ".dify/env.sh"' in script + assert script == "pwd" assert cwd == "~/workspace/abc12ff" assert env is None assert timeout == 2.5 @@ -608,8 +658,7 @@ def test_shell_layer_tools_map_inputs_to_shellctl_calls_and_maintain_offsets() - def test_shell_layer_injects_agent_stub_env_only_for_user_visible_shell_run() -> None: def run_handler(script: str, cwd: str | None, env: Mapping[str, str] | None, timeout: float) -> JobResult: del cwd, timeout - if script.endswith("\npwd"): - assert '. ".dify/env.sh"' in script + if script == "pwd": assert env is not None return _job_result("user-job", status=JobStatusName.EXITED, done=True, exit_code=0) assert env is None @@ -639,8 +688,8 @@ def test_shell_layer_injects_agent_stub_env_only_for_user_visible_shell_run() -> asyncio.run(scenario()) - user_run_call = next(call for call in client.run_calls if call.script.endswith("\npwd")) - internal_run_calls = [call for call in client.run_calls if not call.script.endswith("\npwd")] + user_run_call = next(call for call in client.run_calls if call.script == "pwd") + internal_run_calls = [call for call in client.run_calls if call.script != "pwd"] assert user_run_call.env == { AGENT_STUB_API_BASE_URL_ENV_VAR: "https://agent.example.com/agent-stub", diff --git a/docker/.env.example b/docker/.env.example index 78ebc3e4df1..9646eeeb735 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -41,6 +41,9 @@ FILES_ACCESS_TIMEOUT=300 # Remove `collaboration` from COMPOSE_PROFILES to stop the dedicated websocket service. ENABLE_COLLABORATION_MODE=true +# Learn app feature toggle +ENABLE_LEARN_APP=true + # Logging and server workers LOG_LEVEL=INFO LOG_OUTPUT_FORMAT=text diff --git a/docker/envs/core-services/shared.env.example b/docker/envs/core-services/shared.env.example index 26274fe87d2..391dba2e21a 100644 --- a/docker/envs/core-services/shared.env.example +++ b/docker/envs/core-services/shared.env.example @@ -18,6 +18,9 @@ MIGRATION_ENABLED=true FILES_ACCESS_TIMEOUT=300 # Remove `collaboration` from COMPOSE_PROFILES to stop the dedicated websocket service. ENABLE_COLLABORATION_MODE=true + +# Learn app feature toggle +ENABLE_LEARN_APP=true CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1 CELERY_TASK_ANNOTATIONS=null AZURE_BLOB_ACCOUNT_URL=https://.blob.core.windows.net diff --git a/eslint-suppressions.json b/eslint-suppressions.json index a9975b4476e..e00a1c2a1a3 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -121,11 +121,6 @@ "count": 3 } }, - "web/__tests__/navigation-utils.test.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, "web/__tests__/plugin-tool-workflow-error.test.tsx": { "ts/no-explicit-any": { "count": 2 @@ -343,14 +338,6 @@ "count": 4 } }, - "web/app/components/app-sidebar/app-sidebar-dropdown.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -487,11 +474,6 @@ "count": 1 } }, - "web/app/components/app/configuration/base/var-highlight/index.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "web/app/components/app/configuration/config-prompt/__tests__/index.spec.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -559,14 +541,6 @@ "count": 1 } }, - "web/app/components/app/configuration/config-var/select-type-item/index.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/app/configuration/config-var/select-var-type.tsx": { "ts/no-explicit-any": { "count": 1 @@ -616,17 +590,6 @@ "count": 4 } }, - "web/app/components/app/configuration/config/assistant-type-picker/index.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/app/configuration/config/automatic/get-automatic-res.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -808,11 +771,6 @@ "count": 1 } }, - "web/app/components/app/configuration/prompt-value-panel/utils.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/app/create-app-dialog/app-list/index.tsx": { "no-restricted-imports": { "count": 1 @@ -1001,11 +959,6 @@ "count": 1 } }, - "web/app/components/apps/new-app-card.tsx": { - "react/set-state-in-effect": { - "count": 3 - } - }, "web/app/components/apps/starred-app-card.tsx": { "jsx-a11y/no-noninteractive-element-to-interactive-role": { "count": 1 @@ -1110,34 +1063,6 @@ "count": 1 } }, - "web/app/components/base/block-input/index.stories.tsx": { - "no-console": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/base/block-input/index.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - }, - "react-refresh/only-export-components": { - "count": 1 - }, - "react/no-nested-component-definitions": { - "count": 1 - }, - "react/set-state-in-effect": { - "count": 1 - }, - "react/static-components": { - "count": 2 - } - }, "web/app/components/base/carousel/index.tsx": { "react-hooks-extra/no-direct-set-state-in-use-effect": { "count": 1 @@ -1429,7 +1354,7 @@ }, "web/app/components/base/error-boundary/index.tsx": { "react-refresh/only-export-components": { - "count": 3 + "count": 1 }, "react/jsx-no-key-after-spread": { "count": 1 @@ -1622,17 +1547,6 @@ "count": 2 } }, - "web/app/components/base/form/components/field/mixed-variable-text-input/placeholder.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 2 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/base/form/components/field/variable-selector.tsx": { "no-console": { "count": 1 @@ -1648,22 +1562,12 @@ "count": 1 } }, - "web/app/components/base/form/form-scenarios/base/index.tsx": { - "react/static-components": { - "count": 2 - } - }, "web/app/components/base/form/form-scenarios/base/types.ts": { "erasable-syntax-only/enums": { "count": 1 }, "ts/no-explicit-any": { - "count": 3 - } - }, - "web/app/components/base/form/form-scenarios/demo/index.tsx": { - "no-console": { - "count": 2 + "count": 1 } }, "web/app/components/base/form/form-scenarios/input-field/__tests__/field.spec.tsx": { @@ -1684,24 +1588,6 @@ "count": 2 } }, - "web/app/components/base/form/form-scenarios/node-panel/__tests__/field.spec.tsx": { - "react/static-components": { - "count": 2 - } - }, - "web/app/components/base/form/form-scenarios/node-panel/field.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/base/form/form-scenarios/node-panel/types.ts": { - "erasable-syntax-only/enums": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 2 - } - }, "web/app/components/base/form/hooks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 3 @@ -1807,7 +1693,7 @@ }, "web/app/components/base/icons/src/vender/line/arrows/index.ts": { "no-barrel-files/no-barrel-files": { - "count": 6 + "count": 5 } }, "web/app/components/base/icons/src/vender/line/communication/index.ts": { @@ -1832,7 +1718,7 @@ }, "web/app/components/base/icons/src/vender/line/files/index.ts": { "no-barrel-files/no-barrel-files": { - "count": 6 + "count": 4 } }, "web/app/components/base/icons/src/vender/line/financeAndECommerce/index.ts": { @@ -1842,7 +1728,7 @@ }, "web/app/components/base/icons/src/vender/line/general/index.ts": { "no-barrel-files/no-barrel-files": { - "count": 11 + "count": 10 } }, "web/app/components/base/icons/src/vender/line/images/index.ts": { @@ -1850,11 +1736,6 @@ "count": 1 } }, - "web/app/components/base/icons/src/vender/line/layout/index.ts": { - "no-barrel-files/no-barrel-files": { - "count": 1 - } - }, "web/app/components/base/icons/src/vender/line/mediaAndDevices/index.ts": { "no-barrel-files/no-barrel-files": { "count": 2 @@ -1917,7 +1798,7 @@ }, "web/app/components/base/icons/src/vender/solid/education/index.ts": { "no-barrel-files/no-barrel-files": { - "count": 3 + "count": 2 } }, "web/app/components/base/icons/src/vender/solid/files/index.ts": { @@ -1986,17 +1867,6 @@ "count": 3 } }, - "web/app/components/base/image-uploader/audio-preview.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/media-has-caption": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/base/image-uploader/hooks.ts": { "ts/no-explicit-any": { "count": 4 @@ -2034,17 +1904,6 @@ "count": 2 } }, - "web/app/components/base/image-uploader/video-preview.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/media-has-caption": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/base/inline-delete-confirm/index.stories.tsx": { "no-console": { "count": 2 @@ -2133,9 +1992,6 @@ } }, "web/app/components/base/markdown-blocks/think-block.tsx": { - "react/set-state-in-effect": { - "count": 1 - }, "ts/no-explicit-any": { "count": 4 } @@ -2597,27 +2453,6 @@ "count": 4 } }, - "web/app/components/base/with-input-validation/index.stories.tsx": { - "jsx-a11y/aria-role": { - "count": 7 - }, - "no-console": { - "count": 1 - } - }, - "web/app/components/base/with-input-validation/index.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/billing/header-billing-btn/index.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/billing/plan/assets/index.tsx": { "no-barrel-files/no-barrel-files": { "count": 4 @@ -2742,11 +2577,6 @@ "count": 3 } }, - "web/app/components/datasets/common/retrieval-method-info/index.tsx": { - "react-refresh/only-export-components": { - "count": 1 - } - }, "web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/header.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -2895,14 +2725,6 @@ "count": 3 } }, - "web/app/components/datasets/create/step-two/preview-item/index.tsx": { - "erasable-syntax-only/enums": { - "count": 1 - }, - "react-refresh/only-export-components": { - "count": 1 - } - }, "web/app/components/datasets/create/website/base/crawled-result-item.tsx": { "jsx-a11y/label-has-associated-control": { "count": 1 @@ -3229,19 +3051,6 @@ "count": 1 } }, - "web/app/components/datasets/documents/detail/metadata/components/metadata-field-list.tsx": { - "ts/no-non-null-asserted-optional-chain": { - "count": 1 - } - }, - "web/app/components/datasets/documents/detail/metadata/hooks/use-metadata-state.ts": { - "react-hooks-extra/no-direct-set-state-in-use-effect": { - "count": 6 - }, - "react/set-state-in-effect": { - "count": 6 - } - }, "web/app/components/datasets/documents/detail/metadata/index.tsx": { "no-barrel-files/no-barrel-files": { "count": 1 @@ -3367,14 +3176,6 @@ "count": 1 } }, - "web/app/components/datasets/list/dataset-card/operation-item.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/datasets/metadata/edit-metadata-batch/add-row.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -3473,7 +3274,7 @@ }, "web/app/components/develop/code.tsx": { "ts/no-explicit-any": { - "count": 7 + "count": 6 } }, "web/app/components/develop/doc.tsx": { @@ -3485,9 +3286,6 @@ "jsx-a11y/no-redundant-roles": { "count": 1 }, - "ts/no-empty-object-type": { - "count": 1 - }, "ts/no-explicit-any": { "count": 2 } @@ -3658,24 +3456,6 @@ "count": 2 } }, - "web/app/components/header/account-setting/key-validator/Operate.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 4 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 4 - } - }, - "web/app/components/header/account-setting/key-validator/declarations.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/header/account-setting/language-page/__tests__/index.spec.tsx": { - "jsx-a11y/role-has-required-aria-props": { - "count": 1 - } - }, "web/app/components/header/account-setting/members-page/edit-workspace-modal/index.tsx": { "jsx-a11y/no-autofocus": { "count": 1 @@ -3908,14 +3688,6 @@ "count": 1 } }, - "web/app/components/header/nav/index.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/main-nav/components/workspace-switcher.tsx": { "jsx-a11y/no-autofocus": { "count": 1 @@ -4052,11 +3824,6 @@ "count": 2 } }, - "web/app/components/plugins/plugin-auth/utils.ts": { - "ts/no-explicit-any": { - "count": 2 - } - }, "web/app/components/plugins/plugin-detail-panel/__tests__/operation-dropdown.spec.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -4220,17 +3987,9 @@ "count": 1 } }, - "web/app/components/plugins/plugin-detail-panel/tool-selector/__tests__/index.spec.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 5 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 5 - } - }, "web/app/components/plugins/plugin-detail-panel/tool-selector/components/index.ts": { "no-barrel-files/no-barrel-files": { - "count": 7 + "count": 6 } }, "web/app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx": { @@ -4260,11 +4019,6 @@ "count": 3 } }, - "web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/index.ts": { - "no-barrel-files/no-barrel-files": { - "count": 2 - } - }, "web/app/components/plugins/plugin-detail-panel/tool-selector/index.tsx": { "no-restricted-imports": { "count": 1 @@ -4305,14 +4059,6 @@ "count": 1 } }, - "web/app/components/plugins/plugin-page/filter-management/__tests__/index.spec.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/plugins/plugin-page/filter-management/category-filter.tsx": { "no-restricted-imports": { "count": 1 @@ -4637,11 +4383,6 @@ "count": 1 } }, - "web/app/components/share/utils.ts": { - "ts/no-explicit-any": { - "count": 2 - } - }, "web/app/components/snippet-list/components/snippet-card.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -4732,11 +4473,6 @@ "count": 1 } }, - "web/app/components/tools/provider/detail.tsx": { - "jsx-a11y/anchor-has-content": { - "count": 1 - } - }, "web/app/components/tools/provider/tool-item.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -4979,11 +4715,6 @@ "count": 4 } }, - "web/app/components/workflow/block-selector/use-check-vertical-scrollbar.ts": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/workflow/block-selector/use-sticky-scroll.ts": { "erasable-syntax-only/enums": { "count": 1 @@ -5331,10 +5062,10 @@ }, "web/app/components/workflow/nodes/_base/components/layout/index.tsx": { "no-barrel-files/no-barrel-files": { - "count": 7 + "count": 6 }, "react-refresh/only-export-components": { - "count": 7 + "count": 6 } }, "web/app/components/workflow/nodes/_base/components/mcp-tool-availability.tsx": { @@ -5350,30 +5081,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/_base/components/mixed-variable-text-input/__tests__/placeholder.spec.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/_base/components/mixed-variable-text-input/index.tsx": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/_base/components/mixed-variable-text-input/placeholder.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 2 - }, - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/workflow/nodes/_base/components/next-step/operator.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 2 @@ -5429,14 +5136,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/_base/components/support-var-input/index.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/workflow/nodes/_base/components/switch-plugin-version.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -5635,11 +5334,6 @@ "count": 7 } }, - "web/app/components/workflow/nodes/agent/use-single-run-form-params.ts": { - "ts/no-explicit-any": { - "count": 3 - } - }, "web/app/components/workflow/nodes/answer/default.ts": { "ts/no-explicit-any": { "count": 1 @@ -5683,20 +5377,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/code/dependency-picker.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-autofocus": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - }, - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/code/types.ts": { "erasable-syntax-only/enums": { "count": 1 @@ -5747,9 +5427,6 @@ "erasable-syntax-only/enums": { "count": 1 }, - "no-barrel-files/no-barrel-files": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -5792,14 +5469,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/http/components/key-value/bulk-edit/index.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/workflow/nodes/http/components/key-value/key-value-edit/index.tsx": { "ts/no-explicit-any": { "count": 2 @@ -6393,11 +6062,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/tool/components/input-var-list.tsx": { - "ts/no-explicit-any": { - "count": 7 - } - }, "web/app/components/workflow/nodes/tool/components/mixed-variable-text-input/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -6495,11 +6159,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/trigger-plugin/utils/form-helpers.ts": { - "ts/no-explicit-any": { - "count": 7 - } - }, "web/app/components/workflow/nodes/trigger-schedule/default.ts": { "regexp/no-unused-capturing-group": { "count": 2 @@ -6851,14 +6510,6 @@ "count": 1 } }, - "web/app/components/workflow/run/__tests__/loop-result-panel.spec.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/workflow/run/__tests__/special-result-panel.spec.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -6946,14 +6597,6 @@ "count": 2 } }, - "web/app/components/workflow/run/loop-result-panel.tsx": { - "jsx-a11y/click-events-have-key-events": { - "count": 1 - }, - "jsx-a11y/no-static-element-interactions": { - "count": 1 - } - }, "web/app/components/workflow/run/node.tsx": { "jsx-a11y/click-events-have-key-events": { "count": 1 @@ -7030,26 +6673,11 @@ "count": 11 } }, - "web/app/components/workflow/run/utils/format-log/graph-to-log-struct.ts": { - "ts/no-explicit-any": { - "count": 7 - } - }, "web/app/components/workflow/run/utils/format-log/index.ts": { "ts/no-explicit-any": { "count": 2 } }, - "web/app/components/workflow/run/utils/format-log/iteration/index.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, - "web/app/components/workflow/run/utils/format-log/loop/index.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/workflow/run/utils/format-log/parallel/index.ts": { "no-console": { "count": 4 @@ -7099,11 +6727,6 @@ "count": 1 } }, - "web/app/components/workflow/utils/debug.ts": { - "ts/no-explicit-any": { - "count": 1 - } - }, "web/app/components/workflow/utils/index.ts": { "no-barrel-files/no-barrel-files": { "count": 10 @@ -7492,11 +7115,6 @@ "count": 1 } }, - "web/service/__tests__/use-snippet-workflows.spec.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/service/__tests__/use-tools.spec.tsx": { "no-restricted-imports": { "count": 1 @@ -7601,7 +7219,7 @@ "count": 1 }, "ts/no-explicit-any": { - "count": 26 + "count": 9 } }, "web/service/datasets.ts": { @@ -7609,7 +7227,7 @@ "count": 1 }, "ts/no-explicit-any": { - "count": 6 + "count": 5 } }, "web/service/debug.ts": { @@ -7932,11 +7550,6 @@ "count": 6 } }, - "web/utils/navigation.spec.ts": { - "ts/no-explicit-any": { - "count": 4 - } - }, "web/utils/tool-call.spec.ts": { "ts/no-explicit-any": { "count": 1 diff --git a/packages/contracts/generated/api/console/agent/types.gen.ts b/packages/contracts/generated/api/console/agent/types.gen.ts index 6d647d6add4..3a5dd864c0d 100644 --- a/packages/contracts/generated/api/console/agent/types.gen.ts +++ b/packages/contracts/generated/api/console/agent/types.gen.ts @@ -135,9 +135,14 @@ export type ComposerSavePayload = { agent_soul?: AgentSoulConfig | null binding?: ComposerBindingPayload | null client_revision_id?: string | null + description?: string | null + icon?: string | null + icon_background?: string | null + icon_type?: AgentIconType | null idempotency_key?: string | null new_agent_name?: string | null node_job?: WorkflowNodeJobConfig | null + role?: string | null save_strategy: ComposerSaveStrategy soul_lock?: ComposerSoulLockPayload variant: ComposerVariant @@ -537,6 +542,8 @@ export type ComposerBindingPayload = { current_snapshot_id?: string | null } +export type AgentIconType = 'emoji' | 'image' | 'link' + export type WorkflowNodeJobConfig = { declared_outputs?: Array human_contacts?: Array @@ -877,8 +884,6 @@ export type LlmMode = 'chat' | 'completion' export type AgentKind = 'dify_agent' -export type AgentIconType = 'emoji' | 'image' | 'link' - export type AgentPublishedReferenceResponse = { app_icon?: string | null app_icon_background?: string | null diff --git a/packages/contracts/generated/api/console/agent/zod.gen.ts b/packages/contracts/generated/api/console/agent/zod.gen.ts index 237be1a319d..ce674c132ce 100644 --- a/packages/contracts/generated/api/console/agent/zod.gen.ts +++ b/packages/contracts/generated/api/console/agent/zod.gen.ts @@ -283,6 +283,13 @@ export const zComposerBindingPayload = z.object({ current_snapshot_id: z.string().nullish(), }) +/** + * AgentIconType + * + * Supported icon storage formats for Agent roster entries. + */ +export const zAgentIconType = z.enum(['emoji', 'image', 'link']) + /** * ComposerSoulLockPayload */ @@ -831,13 +838,6 @@ export const zAgentAppDetailWithSite = z.object({ */ export const zAgentKind = z.enum(['dify_agent']) -/** - * AgentIconType - * - * Supported icon storage formats for Agent roster entries. - */ -export const zAgentIconType = z.enum(['emoji', 'image', 'link']) - /** * AgentPublishedReferenceResponse */ @@ -1877,9 +1877,14 @@ export const zComposerSavePayload = z.object({ agent_soul: zAgentSoulConfig.nullish(), binding: zComposerBindingPayload.nullish(), client_revision_id: z.string().nullish(), + description: z.string().nullish(), + icon: z.string().max(255).nullish(), + icon_background: z.string().max(255).nullish(), + icon_type: zAgentIconType.nullish(), idempotency_key: z.string().nullish(), new_agent_name: z.string().min(1).max(255).nullish(), node_job: zWorkflowNodeJobConfig.nullish(), + role: z.string().max(255).nullish(), save_strategy: zComposerSaveStrategy, soul_lock: zComposerSoulLockPayload.optional(), variant: zComposerVariant, diff --git a/packages/contracts/generated/api/console/apps/orpc.gen.ts b/packages/contracts/generated/api/console/apps/orpc.gen.ts index 7a93572885e..ea72df28458 100644 --- a/packages/contracts/generated/api/console/apps/orpc.gen.ts +++ b/packages/contracts/generated/api/console/apps/orpc.gen.ts @@ -392,6 +392,9 @@ import { zPostAppsByAppIdWorkflowsDraftLoopNodesByNodeIdRunBody, zPostAppsByAppIdWorkflowsDraftLoopNodesByNodeIdRunPath, zPostAppsByAppIdWorkflowsDraftLoopNodesByNodeIdRunResponse, + zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterBody, + zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterPath, + zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponse, zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactBody, zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactPath, zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactResponse, @@ -3479,6 +3482,26 @@ export const candidates = { } export const post51 = oc + .route({ + inputStructure: 'detailed', + method: 'POST', + operationId: 'postAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRoster', + path: '/apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/copy-from-roster', + tags: ['console'], + }) + .input( + z.object({ + body: zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterBody, + params: zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterPath, + }), + ) + .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponse) + +export const copyFromRoster = { + post: post51, +} + +export const post52 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -3495,10 +3518,10 @@ export const post51 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactResponse) export const impact = { - post: post51, + post: post52, } -export const post52 = oc +export const post53 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -3515,10 +3538,10 @@ export const post52 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerSaveToRosterResponse) export const saveToRoster = { - post: post52, + post: post53, } -export const post53 = oc +export const post54 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -3535,7 +3558,7 @@ export const post53 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerValidateResponse) export const validate = { - post: post53, + post: post54, } export const get62 = oc @@ -3569,6 +3592,7 @@ export const agentComposer = { get: get62, put: put4, candidates, + copyFromRoster, impact, saveToRoster, validate, @@ -3598,7 +3622,7 @@ export const lastRun = { * * Run draft workflow node */ -export const post54 = oc +export const post55 = oc .route({ description: 'Run draft workflow node', inputStructure: 'detailed', @@ -3617,7 +3641,7 @@ export const post54 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdRunResponse) export const run8 = { - post: post54, + post: post55, } /** @@ -3625,7 +3649,7 @@ export const run8 = { * * Poll for trigger events and execute single node when event arrives */ -export const post55 = oc +export const post56 = oc .route({ description: 'Poll for trigger events and execute single node when event arrives', inputStructure: 'detailed', @@ -3639,7 +3663,7 @@ export const post55 = oc .output(zPostAppsByAppIdWorkflowsDraftNodesByNodeIdTriggerRunResponse) export const run9 = { - post: post55, + post: post56, } export const trigger = { @@ -3699,7 +3723,7 @@ export const nodes7 = { * * Run draft workflow */ -export const post56 = oc +export const post57 = oc .route({ description: 'Run draft workflow', inputStructure: 'detailed', @@ -3718,7 +3742,7 @@ export const post56 = oc .output(zPostAppsByAppIdWorkflowsDraftRunResponse) export const run10 = { - post: post56, + post: post57, } /** @@ -3840,7 +3864,7 @@ export const systemVariables = { * * Poll for trigger events and execute full workflow when event arrives */ -export const post57 = oc +export const post58 = oc .route({ description: 'Poll for trigger events and execute full workflow when event arrives', inputStructure: 'detailed', @@ -3859,7 +3883,7 @@ export const post57 = oc .output(zPostAppsByAppIdWorkflowsDraftTriggerRunResponse) export const run11 = { - post: post57, + post: post58, } /** @@ -3867,7 +3891,7 @@ export const run11 = { * * Full workflow debug when the start node is a trigger */ -export const post58 = oc +export const post59 = oc .route({ description: 'Full workflow debug when the start node is a trigger', inputStructure: 'detailed', @@ -3886,7 +3910,7 @@ export const post58 = oc .output(zPostAppsByAppIdWorkflowsDraftTriggerRunAllResponse) export const runAll = { - post: post58, + post: post59, } export const trigger2 = { @@ -4039,7 +4063,7 @@ export const get72 = oc * * Sync draft workflow configuration */ -export const post59 = oc +export const post60 = oc .route({ description: 'Sync draft workflow configuration', inputStructure: 'detailed', @@ -4059,7 +4083,7 @@ export const post59 = oc export const draft2 = { get: get72, - post: post59, + post: post60, conversationVariables: conversationVariables2, environmentVariables, features, @@ -4095,7 +4119,7 @@ export const get73 = oc /** * Publish workflow */ -export const post60 = oc +export const post61 = oc .route({ inputStructure: 'detailed', method: 'POST', @@ -4114,7 +4138,7 @@ export const post60 = oc export const publish = { get: get73, - post: post60, + post: post61, } /** @@ -4251,7 +4275,7 @@ export const triggers2 = { /** * Restore a published workflow version into the draft workflow */ -export const post61 = oc +export const post62 = oc .route({ description: 'Restore a published workflow version into the draft workflow', inputStructure: 'detailed', @@ -4264,7 +4288,7 @@ export const post61 = oc .output(zPostAppsByAppIdWorkflowsByWorkflowIdRestoreResponse) export const restore = { - post: post61, + post: post62, } /** @@ -4489,7 +4513,7 @@ export const get81 = oc * * Create a new API key for an app */ -export const post62 = oc +export const post63 = oc .route({ description: 'Create a new API key for an app', inputStructure: 'detailed', @@ -4505,7 +4529,7 @@ export const post62 = oc export const apiKeys = { get: get81, - post: post62, + post: post63, byApiKeyId, } @@ -4563,7 +4587,7 @@ export const get83 = oc * * Create a new application */ -export const post63 = oc +export const post64 = oc .route({ description: 'Create a new application', inputStructure: 'detailed', @@ -4579,7 +4603,7 @@ export const post63 = oc export const apps = { get: get83, - post: post63, + post: post64, imports, starred, workflows, diff --git a/packages/contracts/generated/api/console/apps/types.gen.ts b/packages/contracts/generated/api/console/apps/types.gen.ts index 57dc55e1274..cc8268df3b0 100644 --- a/packages/contracts/generated/api/console/apps/types.gen.ts +++ b/packages/contracts/generated/api/console/apps/types.gen.ts @@ -986,9 +986,14 @@ export type ComposerSavePayload = { agent_soul?: AgentSoulConfig | null binding?: ComposerBindingPayload | null client_revision_id?: string | null + description?: string | null + icon?: string | null + icon_background?: string | null + icon_type?: AgentIconType | null idempotency_key?: string | null new_agent_name?: string | null node_job?: WorkflowNodeJobConfig | null + role?: string | null save_strategy: ComposerSaveStrategy soul_lock?: ComposerSoulLockPayload variant: ComposerVariant @@ -1003,6 +1008,12 @@ export type AgentComposerCandidatesResponse = { variant: ComposerVariant } +export type WorkflowComposerCopyFromRosterPayload = { + idempotency_key?: string | null + source_agent_id: string + source_snapshot_id?: string | null +} + export type AgentComposerImpactResponse = { bindings?: Array current_snapshot_id?: string | null @@ -1874,6 +1885,8 @@ export type ComposerBindingPayload = { current_snapshot_id?: string | null } +export type AgentIconType = 'emoji' | 'image' | 'link' + export type ComposerSoulLockPayload = { locked?: boolean unlocked_from_version_id?: string | null @@ -5416,6 +5429,23 @@ export type GetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResp export type GetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResponse = GetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResponses[keyof GetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResponses] +export type PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterData = { + body: WorkflowComposerCopyFromRosterPayload + path: { + app_id: string + node_id: string + } + query?: never + url: '/apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/copy-from-roster' +} + +export type PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponses = { + 200: WorkflowAgentComposerResponse +} + +export type PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponse + = PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponses[keyof PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponses] + export type PostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactData = { body: ComposerSavePayload path: { diff --git a/packages/contracts/generated/api/console/apps/zod.gen.ts b/packages/contracts/generated/api/console/apps/zod.gen.ts index a7c3a0656e1..ffbfabb0872 100644 --- a/packages/contracts/generated/api/console/apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/apps/zod.gen.ts @@ -642,6 +642,15 @@ export const zHumanInputDeliveryTestPayload = z.object({ */ export const zEmptyObjectResponse = z.record(z.string(), z.unknown()) +/** + * WorkflowComposerCopyFromRosterPayload + */ +export const zWorkflowComposerCopyFromRosterPayload = z.object({ + idempotency_key: z.string().max(255).nullish(), + source_agent_id: z.string().min(1).max(255), + source_snapshot_id: z.string().max(255).nullish(), +}) + /** * DraftWorkflowNodeRunPayload */ @@ -1836,6 +1845,13 @@ export const zComposerBindingPayload = z.object({ current_snapshot_id: z.string().nullish(), }) +/** + * AgentIconType + * + * Supported icon storage formats for Agent roster entries. + */ +export const zAgentIconType = z.enum(['emoji', 'image', 'link']) + /** * ComposerSoulLockPayload */ @@ -3337,9 +3353,14 @@ export const zComposerSavePayload = z.object({ agent_soul: zAgentSoulConfig.nullish(), binding: zComposerBindingPayload.nullish(), client_revision_id: z.string().nullish(), + description: z.string().nullish(), + icon: z.string().max(255).nullish(), + icon_background: z.string().max(255).nullish(), + icon_type: zAgentIconType.nullish(), idempotency_key: z.string().nullish(), new_agent_name: z.string().min(1).max(255).nullish(), node_job: zWorkflowNodeJobConfig.nullish(), + role: z.string().max(255).nullish(), save_strategy: zComposerSaveStrategy, soul_lock: zComposerSoulLockPayload.optional(), variant: zComposerVariant, @@ -5343,6 +5364,20 @@ export const zGetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesPa export const zGetAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCandidatesResponse = zAgentComposerCandidatesResponse +export const zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterBody + = zWorkflowComposerCopyFromRosterPayload + +export const zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterPath = z.object({ + app_id: z.uuid(), + node_id: z.string(), +}) + +/** + * Workflow roster agent copied to inline agent + */ +export const zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerCopyFromRosterResponse + = zWorkflowAgentComposerResponse + export const zPostAppsByAppIdWorkflowsDraftNodesByNodeIdAgentComposerImpactBody = zComposerSavePayload diff --git a/packages/contracts/generated/api/console/installed-apps/types.gen.ts b/packages/contracts/generated/api/console/installed-apps/types.gen.ts index 75cc4efa2c4..f9a5eb01edc 100644 --- a/packages/contracts/generated/api/console/installed-apps/types.gen.ts +++ b/packages/contracts/generated/api/console/installed-apps/types.gen.ts @@ -98,8 +98,8 @@ export type ResultResponse = { result: string } -export type MessageInfiniteScrollPagination = { - data: Array +export type ExploreMessageInfiniteScrollPagination = { + data: Array has_more: boolean limit: number } @@ -187,7 +187,7 @@ export type JsonValue | Array | null -export type MessageListItem = { +export type ExploreMessageListItem = { agent_thoughts: Array answer: string conversation_id: string @@ -200,6 +200,7 @@ export type MessageListItem = { [key: string]: JsonValueType } message_files: Array + metadata?: JsonValueType | null parent_message_id?: string | null query: string retriever_resources: Array @@ -644,7 +645,7 @@ export type GetInstalledAppsByInstalledAppIdMessagesData = { } export type GetInstalledAppsByInstalledAppIdMessagesResponses = { - 200: MessageInfiniteScrollPagination + 200: ExploreMessageInfiniteScrollPagination } export type GetInstalledAppsByInstalledAppIdMessagesResponse diff --git a/packages/contracts/generated/api/console/installed-apps/zod.gen.ts b/packages/contracts/generated/api/console/installed-apps/zod.gen.ts index a9a06fe3f7e..a4556058506 100644 --- a/packages/contracts/generated/api/console/installed-apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/installed-apps/zod.gen.ts @@ -503,9 +503,9 @@ export const zHumanInputContent = z.object({ }) /** - * MessageListItem + * ExploreMessageListItem */ -export const zMessageListItem = z.object({ +export const zExploreMessageListItem = z.object({ agent_thoughts: z.array(zAgentThought), answer: z.string(), conversation_id: z.string(), @@ -516,6 +516,7 @@ export const zMessageListItem = z.object({ id: z.string(), inputs: z.record(z.string(), zJsonValueType), message_files: z.array(zMessageFile), + metadata: zJsonValueType.nullish(), parent_message_id: z.string().nullish(), query: z.string(), retriever_resources: z.array(zRetrieverResource), @@ -523,10 +524,10 @@ export const zMessageListItem = z.object({ }) /** - * MessageInfiniteScrollPagination + * ExploreMessageInfiniteScrollPagination */ -export const zMessageInfiniteScrollPagination = z.object({ - data: z.array(zMessageListItem), +export const zExploreMessageInfiniteScrollPagination = z.object({ + data: z.array(zExploreMessageListItem), has_more: z.boolean(), limit: z.int(), }) @@ -693,7 +694,8 @@ export const zGetInstalledAppsByInstalledAppIdMessagesQuery = z.object({ /** * Success */ -export const zGetInstalledAppsByInstalledAppIdMessagesResponse = zMessageInfiniteScrollPagination +export const zGetInstalledAppsByInstalledAppIdMessagesResponse + = zExploreMessageInfiniteScrollPagination export const zPostInstalledAppsByInstalledAppIdMessagesByMessageIdFeedbacksBody = zMessageFeedbackPayload diff --git a/packages/contracts/generated/api/console/system-features/types.gen.ts b/packages/contracts/generated/api/console/system-features/types.gen.ts index f1dcc7fc4b4..a510faa22da 100644 --- a/packages/contracts/generated/api/console/system-features/types.gen.ts +++ b/packages/contracts/generated/api/console/system-features/types.gen.ts @@ -13,6 +13,7 @@ export type SystemFeatureModel = { enable_email_code_login: boolean enable_email_password_login: boolean enable_explore_banner: boolean + enable_learn_app: boolean enable_marketplace: boolean enable_social_oauth_login: boolean enable_trial_app: boolean diff --git a/packages/contracts/generated/api/console/system-features/zod.gen.ts b/packages/contracts/generated/api/console/system-features/zod.gen.ts index 80b27e7843a..e6f2b2fc5a7 100644 --- a/packages/contracts/generated/api/console/system-features/zod.gen.ts +++ b/packages/contracts/generated/api/console/system-features/zod.gen.ts @@ -105,6 +105,7 @@ export const zSystemFeatureModel = z.object({ enable_email_code_login: z.boolean().default(false), enable_email_password_login: z.boolean().default(true), enable_explore_banner: z.boolean().default(false), + enable_learn_app: z.boolean().default(true), enable_marketplace: z.boolean().default(false), enable_social_oauth_login: z.boolean().default(false), enable_trial_app: z.boolean().default(false), diff --git a/packages/contracts/generated/api/web/types.gen.ts b/packages/contracts/generated/api/web/types.gen.ts index 61c9cf103be..722b3042841 100644 --- a/packages/contracts/generated/api/web/types.gen.ts +++ b/packages/contracts/generated/api/web/types.gen.ts @@ -556,6 +556,7 @@ export type SystemFeatureModel = { enable_email_code_login: boolean enable_email_password_login: boolean enable_explore_banner: boolean + enable_learn_app: boolean enable_marketplace: boolean enable_social_oauth_login: boolean enable_trial_app: boolean diff --git a/packages/contracts/generated/api/web/zod.gen.ts b/packages/contracts/generated/api/web/zod.gen.ts index 8045ad341e3..d555ad5f85c 100644 --- a/packages/contracts/generated/api/web/zod.gen.ts +++ b/packages/contracts/generated/api/web/zod.gen.ts @@ -840,6 +840,7 @@ export const zSystemFeatureModel = z.object({ enable_email_code_login: z.boolean().default(false), enable_email_password_login: z.boolean().default(true), enable_explore_banner: z.boolean().default(false), + enable_learn_app: z.boolean().default(true), enable_marketplace: z.boolean().default(false), enable_social_oauth_login: z.boolean().default(false), enable_trial_app: z.boolean().default(false), diff --git a/packages/contracts/generated/enterprise/orpc.gen.ts b/packages/contracts/generated/enterprise/orpc.gen.ts index 61503a7f742..764cea4fdc1 100644 --- a/packages/contracts/generated/enterprise/orpc.gen.ts +++ b/packages/contracts/generated/enterprise/orpc.gen.ts @@ -380,12 +380,8 @@ export const listRollbackTargets = oc ) .output(zDeploymentServiceListRollbackTargetsResponse) -/** - * CancelDeployment cancels the in-flight deployment on the environment. - */ export const cancelDeployment = oc .route({ - description: 'CancelDeployment cancels the in-flight deployment on the environment.', inputStructure: 'detailed', method: 'POST', operationId: 'DeploymentService_CancelDeployment', @@ -607,13 +603,8 @@ export const releaseService = { precheckRelease, } -/** - * ListEnvironments returns only the environments the current user can - * deploy to. - */ export const listEnvironments = oc .route({ - description: 'ListEnvironments returns only the environments the current user can\n deploy to.', inputStructure: 'detailed', method: 'GET', operationId: 'EnvironmentService_ListEnvironments', diff --git a/packages/contracts/generated/enterprise/types.gen.ts b/packages/contracts/generated/enterprise/types.gen.ts index 600f8975678..882a85465b9 100644 --- a/packages/contracts/generated/enterprise/types.gen.ts +++ b/packages/contracts/generated/enterprise/types.gen.ts @@ -13,13 +13,13 @@ export const AccessMode = { export type AccessMode = (typeof AccessMode)[keyof typeof AccessMode] -export const SubjectType = { - SUBJECT_TYPE_UNSPECIFIED: 'SUBJECT_TYPE_UNSPECIFIED', - SUBJECT_TYPE_ACCOUNT: 'SUBJECT_TYPE_ACCOUNT', - SUBJECT_TYPE_GROUP: 'SUBJECT_TYPE_GROUP', +export const AccessSubjectType = { + ACCESS_SUBJECT_TYPE_UNSPECIFIED: 'ACCESS_SUBJECT_TYPE_UNSPECIFIED', + ACCESS_SUBJECT_TYPE_ACCOUNT: 'ACCESS_SUBJECT_TYPE_ACCOUNT', + ACCESS_SUBJECT_TYPE_GROUP: 'ACCESS_SUBJECT_TYPE_GROUP', } as const -export type SubjectType = (typeof SubjectType)[keyof typeof SubjectType] +export type AccessSubjectType = (typeof AccessSubjectType)[keyof typeof AccessSubjectType] export const AppRunnerLogStatus = { APP_RUNNER_LOG_STATUS_UNSPECIFIED: 'APP_RUNNER_LOG_STATUS_UNSPECIFIED', @@ -295,7 +295,7 @@ export type AccessPolicy = { } export type AccessSubject = { - subjectType: SubjectType + subjectType: AccessSubjectType subjectId: string } @@ -598,7 +598,6 @@ export type Environment = { status: EnvironmentStatus statusMessage: string lastError?: Error - apiServer?: string namespace?: string managedBy?: string runtimeEndpoint?: string @@ -741,9 +740,6 @@ export type GetReleaseResponse = { export type K8sEnvironmentConfig = { namespace?: string - apiServer?: string - caBundle?: string - bearerToken?: string } export type ListApiKeysResponse = { @@ -832,6 +828,7 @@ export type PrecheckReleaseResponse = { canCreate: boolean matchedRelease?: ReleaseContentMatch unsupportedNodes: Array + unsupportedToolProviders: Array } export type PromoteRequest = { @@ -998,6 +995,14 @@ export type UnsupportedDslNode = { type: string } +export type UnsupportedToolProvider = { + nodeId: string + providerType: string + providerId?: string + providerName?: string + toolName?: string +} + export type UpdateAccessChannelsRequest = { appInstanceId?: string webAppEnabled?: boolean @@ -1362,7 +1367,6 @@ export type InfoConfigReply = { Branding?: BrandingInfo WebAppAuth?: WebAppAuthInfo PluginInstallationPermission?: PluginInstallationPermissionInfo - EnableAppDeploy?: boolean } export type InnerAdmission = { @@ -1458,6 +1462,19 @@ export type IsUserAllowedToAccessWebAppRes = { result?: boolean } +export type IssueMcpTokenReply = { + token?: string + expiresAt?: string + tokenType?: string +} + +export type IssueMcpTokenReq = { + userId?: string + tenantId?: string + appId?: string + audience?: string +} + export type JoinWorkspaceReply = { message?: string } @@ -1466,6 +1483,7 @@ export type JoinWorkspaceReq = { id?: string email?: string role?: string + rbacRole?: string } export type LicenseInfo = { @@ -1667,12 +1685,9 @@ export type PluginInstallationSettingsReply = { export type RbacRole = { id?: string - type?: string name?: string description?: string - isBuiltin?: boolean - category?: string - permissionKeys?: Array + permissions?: Array } export type ResetMemberPasswordReply = { @@ -1813,7 +1828,7 @@ export type SetDefaultWorkspaceReq = { export type Subject = { subjectId?: string - subjectType?: SubjectType + subjectType?: string accountData?: SubjectAccountData groupData?: SubjectGroupData } diff --git a/packages/contracts/generated/enterprise/zod.gen.ts b/packages/contracts/generated/enterprise/zod.gen.ts index d7a42b35d4c..85f74b22121 100644 --- a/packages/contracts/generated/enterprise/zod.gen.ts +++ b/packages/contracts/generated/enterprise/zod.gen.ts @@ -9,10 +9,10 @@ export const zAccessMode = z.enum([ 'ACCESS_MODE_PRIVATE_ALL', ]) -export const zSubjectType = z.enum([ - 'SUBJECT_TYPE_UNSPECIFIED', - 'SUBJECT_TYPE_ACCOUNT', - 'SUBJECT_TYPE_GROUP', +export const zAccessSubjectType = z.enum([ + 'ACCESS_SUBJECT_TYPE_UNSPECIFIED', + 'ACCESS_SUBJECT_TYPE_ACCOUNT', + 'ACCESS_SUBJECT_TYPE_GROUP', ]) export const zAppRunnerLogStatus = z.enum([ @@ -203,7 +203,7 @@ export const zLimitStatus = z.enum([ ]) export const zAccessSubject = z.object({ - subjectType: zSubjectType, + subjectType: zAccessSubjectType, subjectId: z.string(), }) @@ -254,10 +254,6 @@ export const zAppInstance = z.object({ updatedAt: z.iso.datetime(), }) -/** - * BootstrapAssignment is one runtime_instance assignment in a runner's startup - * baseline. - */ export const zBootstrapAssignment = z.object({ appId: z.string().optional(), environmentId: z.string().optional(), @@ -322,10 +318,6 @@ export const zCreateReleaseRequest = z.object({ sourceAppId: z.string().optional(), }) -/** - * CredentialCandidate is one tenant-visible credential a frontend may - * pick for a credential slot. It carries no secret. - */ export const zCredentialCandidate = z.object({ credentialId: z.string(), providerId: z.string(), @@ -334,20 +326,12 @@ export const zCredentialCandidate = z.object({ fromEnterprise: z.boolean(), }) -/** - * CredentialSelectionInput is one deploy-time plugin-credential - * selection: a shared credential id chosen for a required DSL slot. - */ export const zCredentialSelectionInput = z.object({ providerId: z.string(), category: zPluginCategory.optional(), credentialId: z.string(), }) -/** - * CredentialSlot is one model/tool plugin-credential requirement a - * Release's DSL declares, paired with the candidates selectable for it. - */ export const zCredentialSlot = z.object({ providerId: z.string(), category: zPluginCategory, @@ -406,10 +390,6 @@ export const zEnvironmentDeploymentRecord = z.object({ finalizedAt: z.iso.datetime().optional(), }) -/** - * Error is the package-wide failure shape, carried wherever an operation or - * resource reports an error. - */ export const zError = z.object({ code: z.string().optional(), message: z.string().optional(), @@ -445,7 +425,6 @@ export const zEnvironment = z.object({ status: zEnvironmentStatus, statusMessage: z.string(), lastError: zError.optional(), - apiServer: z.string().optional(), namespace: z.string().optional(), managedBy: z.string().optional(), runtimeEndpoint: z.string().optional(), @@ -523,9 +502,6 @@ export const zGetEnvironmentResponse = z.object({ export const zK8sEnvironmentConfig = z.object({ namespace: z.string().optional(), - apiServer: z.string().optional(), - caBundle: z.string().optional(), - bearerToken: z.string().optional(), }) export const zCreateEnvironmentRequest = z.object({ @@ -571,9 +547,6 @@ export const zDeployRequest = z.object({ expectedDslDigest: z.string().optional(), }) -/** - * Operator is who triggered the run (the "END USER OR ACCOUNT" column). - */ export const zOperator = z.object({ type: zOperatorType, id: z.string(), @@ -620,10 +593,6 @@ export const zPromoteRequest = z.object({ idempotencyKey: z.string(), }) -/** - * ReleaseContentMatch identifies an existing release whose DSL content is - * identical to the checked content. - */ export const zReleaseContentMatch = z.object({ releaseId: z.string(), displayName: z.string(), @@ -638,11 +607,6 @@ export const zReleaseEnvironmentAction = z.object({ currentReleaseId: z.string(), }) -/** - * ReleaseEnvironmentDeployment is an environment where the release is the - * active deployment, paired with that environment's runtime status so the - * version history can show running vs failed vs deploying. - */ export const zReleaseEnvironmentDeployment = z.object({ environment: zEnvironment, status: zRuntimeInstanceStatus, @@ -663,10 +627,6 @@ export const zReportRuntimeAssignmentStatusResponse = z.object({ stale: z.boolean().optional(), }) -/** - * RequiredSlot is an input requirement extracted from a Release's - * DSL. - */ export const zRequiredSlot = z.object({ type: zSlotType, providerId: z.string(), @@ -715,10 +675,6 @@ export const zDeployResponse = z.object({ deployment: zDeployment, }) -/** - * EnvironmentAppInstance is one app instance as seen from a single environment: - * its current release, runtime status, and derived last error in THIS env. - */ export const zEnvironmentAppInstance = z.object({ appInstance: zAppInstance.optional(), currentRelease: zRelease.optional(), @@ -759,10 +715,6 @@ export const zComputeReleaseDeploymentViewResponse = z.object({ options: zDeploymentOptions.optional(), }) -/** - * EnvironmentDeploymentHistoryItem is one deployment row in an environment's - * history, with a thin reference to the owning app instance. - */ export const zEnvironmentDeploymentHistoryItem = z.object({ deployment: zDeployment.optional(), appInstanceId: z.string().optional(), @@ -904,20 +856,25 @@ export const zUndeployResponse = z.object({ deployment: zDeployment, }) -/** - * UnsupportedDslNode identifies a workflow node whose type the app runner - * cannot execute. - */ export const zUnsupportedDslNode = z.object({ id: z.string(), type: z.string(), }) +export const zUnsupportedToolProvider = z.object({ + nodeId: z.string(), + providerType: z.string(), + providerId: z.string().optional(), + providerName: z.string().optional(), + toolName: z.string().optional(), +}) + export const zPrecheckReleaseResponse = z.object({ gateCommitId: z.string(), canCreate: z.boolean(), matchedRelease: zReleaseContentMatch.optional(), unsupportedNodes: z.array(zUnsupportedDslNode), + unsupportedToolProviders: z.array(zUnsupportedToolProvider), }) export const zUpdateAccessChannelsRequest = z.object({ @@ -1302,6 +1259,19 @@ export const zIsUserAllowedToAccessWebAppRes = z.object({ result: z.boolean().optional(), }) +export const zIssueMcpTokenReply = z.object({ + token: z.string().optional(), + expiresAt: z.string().optional(), + tokenType: z.string().optional(), +}) + +export const zIssueMcpTokenReq = z.object({ + userId: z.string().optional(), + tenantId: z.string().optional(), + appId: z.string().optional(), + audience: z.string().optional(), +}) + export const zJoinWorkspaceReply = z.object({ message: z.string().optional(), }) @@ -1313,6 +1283,7 @@ export const zJoinWorkspaceReq = z.object({ id: z.string().optional(), email: z.string().optional(), role: z.string().optional(), + rbacRole: z.string().optional(), }) export const zLimitConfig = z.object({ @@ -1494,12 +1465,9 @@ export const zPluginInstallationSettingsReply = z.object({ export const zRbacRole = z.object({ id: z.string().optional(), - type: z.string().optional(), name: z.string().optional(), description: z.string().optional(), - isBuiltin: z.boolean().optional(), - category: z.string().optional(), - permissionKeys: z.array(z.string()).optional(), + permissions: z.array(z.string()).optional(), }) export const zGetMemberRbacRolesReply = z.object({ @@ -1778,7 +1746,7 @@ export const zGetWebAppWhitelistSubjectsRes = z.object({ */ export const zSubject = z.object({ subjectId: z.string().optional(), - subjectType: zSubjectType.optional(), + subjectType: z.string().optional(), accountData: zSubjectAccountData.optional(), groupData: zSubjectGroupData.optional(), }) @@ -2104,7 +2072,6 @@ export const zInfoConfigReply = z.object({ Branding: zBrandingInfo.optional(), WebAppAuth: zWebAppAuthInfo.optional(), PluginInstallationPermission: zPluginInstallationPermissionInfo.optional(), - EnableAppDeploy: z.boolean().optional(), }) export const zWebOAuth2LoginReply = z.object({ diff --git a/packages/dify-ui/src/switch/__tests__/index.spec.tsx b/packages/dify-ui/src/switch/__tests__/index.spec.tsx index 28aa8a655ce..0e289539581 100644 --- a/packages/dify-ui/src/switch/__tests__/index.spec.tsx +++ b/packages/dify-ui/src/switch/__tests__/index.spec.tsx @@ -123,9 +123,13 @@ describe('Switch', () => { await expect.element(screen.getByRole('switch')).toHaveAttribute('data-checked', '') }) - it('should have focus-visible ring-3 styles', async () => { + it('should replace the native focus outline with the accent focus ring', async () => { const screen = await render() - await expect.element(screen.getByRole('switch')).toHaveClass('focus-visible:ring-2') + await expect.element(screen.getByRole('switch')).toHaveClass( + 'outline-hidden', + 'focus-visible:ring-2', + 'focus-visible:ring-state-accent-solid', + ) }) it('should respect prefers-reduced-motion', async () => { diff --git a/packages/dify-ui/src/switch/index.tsx b/packages/dify-ui/src/switch/index.tsx index d35809ecad8..768e009488e 100644 --- a/packages/dify-ui/src/switch/index.tsx +++ b/packages/dify-ui/src/switch/index.tsx @@ -10,7 +10,7 @@ import { cn } from '../cn' const switchRootStateClassName = 'bg-components-toggle-bg-unchecked hover:bg-components-toggle-bg-unchecked-hover data-checked:bg-components-toggle-bg data-checked:hover:bg-components-toggle-bg-hover data-disabled:cursor-not-allowed data-disabled:bg-components-toggle-bg-unchecked-disabled data-disabled:hover:bg-components-toggle-bg-unchecked-disabled data-disabled:data-checked:bg-components-toggle-bg-disabled data-disabled:data-checked:hover:bg-components-toggle-bg-disabled' const switchRootVariants = cva( - `group relative inline-flex shrink-0 cursor-pointer touch-manipulation items-center transition-colors duration-200 ease-in-out focus-visible:ring-2 focus-visible:ring-state-accent-solid motion-reduce:transition-none ${switchRootStateClassName}`, + `group relative inline-flex shrink-0 cursor-pointer touch-manipulation items-center outline-hidden transition-colors duration-200 ease-in-out focus-visible:ring-2 focus-visible:ring-state-accent-solid motion-reduce:transition-none ${switchRootStateClassName}`, { variants: { size: { diff --git a/packages/dify-ui/src/themes/dark.css b/packages/dify-ui/src/themes/dark.css index 1b24e8fb489..3f4a163a725 100644 --- a/packages/dify-ui/src/themes/dark.css +++ b/packages/dify-ui/src/themes/dark.css @@ -162,6 +162,7 @@ html[data-theme="dark"] { --color-components-main-nav-glass-surface-middle-2: #0033ff1a; --color-components-main-nav-glass-surface-end: #0033ff14; --color-components-main-nav-glass-edge-highlight-first: #fffffffa; + --color-components-main-nav-glass-edge-highlight-middle: #ffffff00; --color-components-main-nav-glass-edge-highlight-end: #ffffff6b; --color-components-main-nav-glass-edge-reflection-first: #0033ff00; --color-components-main-nav-glass-edge-reflection-middle: #0033ff99; diff --git a/packages/dify-ui/src/themes/light.css b/packages/dify-ui/src/themes/light.css index 3feb4afb47f..dd3252f3614 100644 --- a/packages/dify-ui/src/themes/light.css +++ b/packages/dify-ui/src/themes/light.css @@ -162,6 +162,7 @@ html[data-theme="light"] { --color-components-main-nav-glass-surface-middle-2: #0033ff1a; --color-components-main-nav-glass-surface-end: #0033ff14; --color-components-main-nav-glass-edge-highlight-first: #fffffffa; + --color-components-main-nav-glass-edge-highlight-middle: #ffffff00; --color-components-main-nav-glass-edge-highlight-end: #ffffff6b; --color-components-main-nav-glass-edge-reflection-first: #0033ff00; --color-components-main-nav-glass-edge-reflection-middle: #0033ff99; diff --git a/packages/dify-ui/src/themes/theme.css b/packages/dify-ui/src/themes/theme.css index 3e35feb8eb8..c14e54ea549 100644 --- a/packages/dify-ui/src/themes/theme.css +++ b/packages/dify-ui/src/themes/theme.css @@ -169,6 +169,7 @@ --color-components-main-nav-glass-surface-middle-2: var(--color-components-main-nav-glass-surface-middle-2); --color-components-main-nav-glass-surface-end: var(--color-components-main-nav-glass-surface-end); --color-components-main-nav-glass-edge-highlight-first: var(--color-components-main-nav-glass-edge-highlight-first); + --color-components-main-nav-glass-edge-highlight-middle: var(--color-components-main-nav-glass-edge-highlight-middle); --color-components-main-nav-glass-edge-highlight-end: var(--color-components-main-nav-glass-edge-highlight-end); --color-components-main-nav-glass-edge-reflection-first: var(--color-components-main-nav-glass-edge-reflection-first); --color-components-main-nav-glass-edge-reflection-middle: var(--color-components-main-nav-glass-edge-reflection-middle); diff --git a/web/.env.example b/web/.env.example index 112232e529c..7363ce628f3 100644 --- a/web/.env.example +++ b/web/.env.example @@ -116,3 +116,4 @@ NEXT_PUBLIC_ENABLE_CHANGE_EMAIL=true NEXT_PUBLIC_CREATORS_PLATFORM_FEATURES_ENABLED=true NEXT_PUBLIC_ENABLE_TRIAL_APP=true NEXT_PUBLIC_ENABLE_EXPLORE_BANNER=true +NEXT_PUBLIC_RBAC_ENABLED=false diff --git a/web/.storybook/utils/form-story-wrapper.tsx b/web/.storybook/utils/form-story-wrapper.tsx deleted file mode 100644 index 7503e9905d3..00000000000 --- a/web/.storybook/utils/form-story-wrapper.tsx +++ /dev/null @@ -1,88 +0,0 @@ -import type { ReactNode } from 'react' -import { useStore } from '@tanstack/react-form' -import { useState } from 'react' -import { useAppForm } from '@/app/components/base/form' - -type UseAppFormOptions = Parameters[0] -type AppFormInstance = ReturnType - -type FormStoryWrapperProps = { - options?: UseAppFormOptions - children: (form: AppFormInstance) => ReactNode - title?: string - subtitle?: string -} - -export const FormStoryWrapper = ({ - options, - children, - title, - subtitle, -}: FormStoryWrapperProps) => { - const [lastSubmitted, setLastSubmitted] = useState(null) - const [submitCount, setSubmitCount] = useState(0) - - const form = useAppForm({ - ...options, - onSubmit: (context) => { - setSubmitCount(count => count + 1) - setLastSubmitted(context.value) - options?.onSubmit?.(context) - }, - }) - - const values = useStore(form.store, state => state.values) - const isSubmitting = useStore(form.store, state => state.isSubmitting) - const canSubmit = useStore(form.store, state => state.canSubmit) - - return ( -
-
- {(title || subtitle) && ( -
- {title &&

{title}

} - {subtitle &&

{subtitle}

} -
- )} - {children(form)} -
- -
- ) -} - -export type FormStoryRender = (form: AppFormInstance) => ReactNode diff --git a/web/__tests__/app-sidebar/sidebar-shell-flow.test.tsx b/web/__tests__/app-sidebar/sidebar-shell-flow.test.tsx deleted file mode 100644 index c1f3a1c62d4..00000000000 --- a/web/__tests__/app-sidebar/sidebar-shell-flow.test.tsx +++ /dev/null @@ -1,151 +0,0 @@ -import type { SVGProps } from 'react' -import { fireEvent, render, screen } from '@testing-library/react' -import * as React from 'react' -import { beforeEach, describe, expect, it, vi } from 'vitest' -import AppDetailNav from '@/app/components/app-sidebar' - -const mockSetDetailSidebarMode = vi.fn() - -let mockDetailSidebarMode = 'expand' -let mockPathname = '/app/app-1/logs' -let mockSelectedSegment = 'logs' -let mockIsHovering = true -let hotkeyHandler: ((event: { preventDefault: () => void }) => void) | null = null - -vi.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - -vi.mock('@/app/components/main-nav/storage', () => ({ - useDetailSidebarMode: () => [mockDetailSidebarMode, mockSetDetailSidebarMode], -})) - -vi.mock('@/next/navigation', () => ({ - usePathname: () => mockPathname, - useSelectedLayoutSegment: () => mockSelectedSegment, -})) - -vi.mock('@/next/link', () => ({ - default: ({ - href, - children, - className, - title, - }: { - href: string - children?: React.ReactNode - className?: string - title?: string - }) => ( - - {children} - - ), -})) - -vi.mock('ahooks', () => ({ - useHover: () => mockIsHovering, -})) - -vi.mock('@tanstack/react-hotkeys', async (importOriginal) => { - const actual = await importOriginal() - return { - ...actual, - useHotkey: (_hotkey: string, handler: (event: { preventDefault: () => void }) => void) => { - hotkeyHandler = handler - }, - } -}) - -vi.mock('@/hooks/use-breakpoints', () => ({ - default: () => 'desktop', - MediaType: { - mobile: 'mobile', - desktop: 'desktop', - }, -})) - -vi.mock('@/context/event-emitter', () => ({ - useEventEmitterContextContext: () => ({ - eventEmitter: { - useSubscription: vi.fn(), - }, - }), -})) - -vi.mock('@/context/app-context', () => ({ - useAppContext: () => ({ - isCurrentWorkspaceEditor: true, - }), -})) - -vi.mock('@langgenius/dify-ui/dropdown-menu', () => import('@/__mocks__/base-ui-dropdown-menu')) -vi.mock('@langgenius/dify-ui/tooltip', () => import('@/__mocks__/base-ui-tooltip')) - -vi.mock('@/app/components/app-sidebar/app-info', () => ({ - default: ({ - expand, - onlyShowDetail, - openState, - }: { - expand: boolean - onlyShowDetail?: boolean - openState?: boolean - }) => ( -
- ), -})) - -const MockIcon = (props: SVGProps) => - -const navigation = [ - { name: 'Overview', href: '/app/app-1/overview', icon: MockIcon, selectedIcon: MockIcon }, - { name: 'Logs', href: '/app/app-1/logs', icon: MockIcon, selectedIcon: MockIcon }, -] - -describe('App Sidebar Shell Flow', () => { - beforeEach(() => { - vi.clearAllMocks() - localStorage.clear() - mockDetailSidebarMode = 'expand' - mockPathname = '/app/app-1/logs' - mockSelectedSegment = 'logs' - mockIsHovering = true - hotkeyHandler = null - }) - - it('renders the expanded sidebar, marks the active nav item, and toggles collapse by click and shortcut', () => { - render() - - expect(screen.getByTestId('app-info')).toHaveAttribute('data-expand', 'true') - - const logsLink = screen.getByRole('link', { name: /Logs/i }) - expect(logsLink.className).toContain('bg-components-menu-item-bg-active') - - fireEvent.click(screen.getByRole('button')) - expect(mockSetDetailSidebarMode).toHaveBeenCalledWith('collapse') - - const preventDefault = vi.fn() - hotkeyHandler?.({ preventDefault }) - - expect(preventDefault).toHaveBeenCalled() - expect(mockSetDetailSidebarMode).toHaveBeenCalledWith('collapse') - }) - - it('keeps the normal sidebar on workflow routes', () => { - mockPathname = '/app/app-1/workflow' - mockSelectedSegment = 'workflow' - - render() - - expect(screen.getByTestId('app-info')).toBeInTheDocument() - expect(screen.getByRole('link', { name: /Overview/i })).toBeInTheDocument() - expect(screen.getByRole('link', { name: /Logs/i })).toBeInTheDocument() - }) -}) diff --git a/web/__tests__/app/app-access-control-flow.test.tsx b/web/__tests__/app/app-access-control-flow.test.tsx index 3ab3587f258..4ecfde0158c 100644 --- a/web/__tests__/app/app-access-control-flow.test.tsx +++ b/web/__tests__/app/app-access-control-flow.test.tsx @@ -5,7 +5,7 @@ import { AppPublisher } from '@/app/components/app/app-publisher' import { AccessMode } from '@/models/access-control' import { AppModeEnum } from '@/types/app' -const mockFetchAppDetailDirect = vi.fn() +const mockFetchAppDetail = vi.fn() const mockSetAppDetail = vi.fn() const mockRefetch = vi.fn() @@ -69,7 +69,7 @@ vi.mock('@/service/access-control/use-app-access-control', () => ({ })) vi.mock('@/service/apps', () => ({ - fetchAppDetailDirect: (...args: unknown[]) => mockFetchAppDetailDirect(...args), + fetchAppDetail: (...args: unknown[]) => mockFetchAppDetail(...args), })) vi.mock('@/app/components/app/overview/embedded', () => ({ @@ -120,7 +120,7 @@ describe('App Access Control Flow', () => { access_token: 'token-1', }, } - mockFetchAppDetailDirect.mockResolvedValue({ + mockFetchAppDetail.mockResolvedValue({ ...mockAppDetail, access_mode: AccessMode.PUBLIC, }) @@ -128,7 +128,7 @@ describe('App Access Control Flow', () => { it('refreshes app detail after confirming access control updates', async () => { const { queryClient } = renderWithQueryClient() - const invalidateQueriesSpy = vi.spyOn(queryClient, 'invalidateQueries').mockResolvedValue() + const setQueryDataSpy = vi.spyOn(queryClient, 'setQueryData') fireEvent.click(screen.getByRole('button', { name: 'workflow.common.publish' })) fireEvent.click(screen.getByText('app.accessControlDialog.accessItems.specific')) @@ -138,8 +138,14 @@ describe('App Access Control Flow', () => { fireEvent.click(screen.getByRole('button', { name: 'confirm-access-control' })) await waitFor(() => { - expect(invalidateQueriesSpy).toHaveBeenCalledWith({ queryKey: ['apps', 'detail', 'app-1'] }) + expect(mockFetchAppDetail).toHaveBeenCalledWith({ url: '/apps', id: 'app-1' }) }) + expect(setQueryDataSpy).toHaveBeenCalledWith(['apps', 'detail', 'app-1'], expect.objectContaining({ + access_mode: AccessMode.PUBLIC, + })) + expect(mockSetAppDetail).toHaveBeenCalledWith(expect.objectContaining({ + access_mode: AccessMode.PUBLIC, + })) await waitFor(() => { expect(screen.queryByTestId('access-control-modal')).not.toBeInTheDocument() diff --git a/web/__tests__/base/form-demo-flow.test.tsx b/web/__tests__/base/form-demo-flow.test.tsx deleted file mode 100644 index afb36528c0f..00000000000 --- a/web/__tests__/base/form-demo-flow.test.tsx +++ /dev/null @@ -1,65 +0,0 @@ -import { render, screen, waitFor, within } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import { beforeEach, describe, expect, it, vi } from 'vitest' -import DemoForm from '@/app/components/base/form/form-scenarios/demo' - -describe('Base Form Demo Flow', () => { - const consoleLogSpy = vi.spyOn(console, 'log').mockImplementation(() => {}) - - beforeEach(() => { - vi.clearAllMocks() - }) - - it('reveals contact fields and submits the composed form values through the shared form actions', async () => { - const user = userEvent.setup() - render() - - expect(screen.queryByRole('heading', { name: /contacts/i })).not.toBeInTheDocument() - - await user.type(screen.getByRole('textbox', { name: /^name$/i }), 'Alice') - await user.type(screen.getByRole('textbox', { name: /^surname$/i }), 'Smith') - await user.click(screen.getByText(/i accept the terms and conditions/i)) - - expect(await screen.findByRole('heading', { name: /contacts/i })).toBeInTheDocument() - - await user.type(screen.getByRole('textbox', { name: /^email$/i }), 'alice@example.com') - - const preferredMethodLabel = screen.getByText('Preferred Contact Method') - const preferredMethodField = preferredMethodLabel.parentElement?.parentElement - expect(preferredMethodField).toBeTruthy() - - await user.click(within(preferredMethodField as HTMLElement).getByText('Email')) - await user.click(screen.getByText('Whatsapp')) - - const submitButton = screen.getByRole('button', { name: /operation\.submit/i }) - expect(submitButton).toBeEnabled() - await user.click(submitButton) - - await waitFor(() => { - expect(consoleLogSpy).toHaveBeenCalledWith('Form submitted:', expect.objectContaining({ - name: 'Alice', - surname: 'Smith', - isAcceptingTerms: true, - contact: expect.objectContaining({ - email: 'alice@example.com', - preferredContactMethod: 'whatsapp', - }), - })) - }) - }) - - it('removes the nested contact section again when the name field is cleared', async () => { - const user = userEvent.setup() - render() - - const nameInput = screen.getByRole('textbox', { name: /^name$/i }) - await user.type(nameInput, 'Alice') - expect(await screen.findByRole('heading', { name: /contacts/i })).toBeInTheDocument() - - await user.clear(nameInput) - - await waitFor(() => { - expect(screen.queryByRole('heading', { name: /contacts/i })).not.toBeInTheDocument() - }) - }) -}) diff --git a/web/__tests__/billing/billing-integration.test.tsx b/web/__tests__/billing/billing-integration.test.tsx index 2fde9b506f4..e82a6d8eac2 100644 --- a/web/__tests__/billing/billing-integration.test.tsx +++ b/web/__tests__/billing/billing-integration.test.tsx @@ -7,7 +7,6 @@ import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' import AppsFull from '@/app/components/billing/apps-full-in-dialog' import Billing from '@/app/components/billing/billing-page' import { defaultPlan, NUM_INFINITE } from '@/app/components/billing/config' -import HeaderBillingBtn from '@/app/components/billing/header-billing-btn' import PlanComp from '@/app/components/billing/plan' import { PlanUpgradeModal } from '@/app/components/billing/plan-upgrade-modal' import PriorityLabel from '@/app/components/billing/priority-label' @@ -696,83 +695,7 @@ describe('Capacity Full Components Integration', () => { }) // ═══════════════════════════════════════════════════════════════════════════ -// 5. Header Billing Button Integration -// Tests HeaderBillingBtn behavior for different plan states -// ═══════════════════════════════════════════════════════════════════════════ -describe('Header Billing Button Integration', () => { - beforeEach(() => { - vi.clearAllMocks() - setupAppContext() - }) - - it('should render UpgradeBtn (premium badge) for sandbox plan', () => { - setupProviderContext({ type: Plan.sandbox }) - - render() - - expect(screen.getByText(/upgradeBtn\.encourageShort/i)).toBeInTheDocument() - }) - - it('should render "pro" badge for professional plan', () => { - setupProviderContext({ type: Plan.professional }) - - render() - - expect(screen.getByText('pro')).toBeInTheDocument() - expect(screen.queryByText(/upgradeBtn/i)).not.toBeInTheDocument() - }) - - it('should render "team" badge for team plan', () => { - setupProviderContext({ type: Plan.team }) - - render() - - expect(screen.getByText('team')).toBeInTheDocument() - }) - - it('should return null when billing is disabled', () => { - setupProviderContext({ type: Plan.sandbox }, { enableBilling: false }) - - const { container } = render() - - expect(container.innerHTML).toBe('') - }) - - it('should return null when plan is not fetched yet', () => { - setupProviderContext({ type: Plan.sandbox }, { isFetchedPlan: false }) - - const { container } = render() - - expect(container.innerHTML).toBe('') - }) - - it('should call onClick when clicking pro/team badge in non-display-only mode', async () => { - const user = userEvent.setup() - const onClick = vi.fn() - setupProviderContext({ type: Plan.professional }) - - render() - - await user.click(screen.getByText('pro')) - - expect(onClick).toHaveBeenCalledTimes(1) - }) - - it('should not call onClick when isDisplayOnly is true', async () => { - const user = userEvent.setup() - const onClick = vi.fn() - setupProviderContext({ type: Plan.professional }) - - render() - - await user.click(screen.getByText('pro')) - - expect(onClick).not.toHaveBeenCalled() - }) -}) - -// ═══════════════════════════════════════════════════════════════════════════ -// 6. PriorityLabel Integration +// 5. PriorityLabel Integration // Tests priority badge display for different plan types // ═══════════════════════════════════════════════════════════════════════════ describe('PriorityLabel Integration', () => { diff --git a/web/__tests__/env.spec.ts b/web/__tests__/env.spec.ts index 89781d32685..419dcadf030 100644 --- a/web/__tests__/env.spec.ts +++ b/web/__tests__/env.spec.ts @@ -1,5 +1,6 @@ describe('env runtime transport', () => { const originalAgentV2Env = process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 + const originalRbacEnv = process.env.NEXT_PUBLIC_RBAC_ENABLED beforeEach(() => { vi.clearAllMocks() @@ -7,7 +8,9 @@ describe('env runtime transport', () => { vi.doUnmock('../utils/client') document.body.removeAttribute('data-enable-agent-v2') document.body.removeAttribute('data-enable-agent-v-2') + document.body.removeAttribute('data-rbac-enabled') delete process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 + delete process.env.NEXT_PUBLIC_RBAC_ENABLED }) afterAll(() => { @@ -15,6 +18,11 @@ describe('env runtime transport', () => { delete process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 else process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 = originalAgentV2Env + + if (originalRbacEnv === undefined) + delete process.env.NEXT_PUBLIC_RBAC_ENABLED + else + process.env.NEXT_PUBLIC_RBAC_ENABLED = originalRbacEnv }) it('should read NEXT_PUBLIC_ENABLE_AGENT_V2 from the browser runtime dataset key', async () => { @@ -25,6 +33,14 @@ describe('env runtime transport', () => { expect(env.NEXT_PUBLIC_ENABLE_AGENT_V2).toBe(true) }) + it('should read NEXT_PUBLIC_RBAC_ENABLED from the browser runtime dataset key', async () => { + document.body.setAttribute('data-rbac-enabled', 'true') + + const { env } = await import('../env') + + expect(env.NEXT_PUBLIC_RBAC_ENABLED).toBe(true) + }) + it('should emit the Agent v2 runtime dataset attribute from getDatasetMap on the server', async () => { process.env.NEXT_PUBLIC_ENABLE_AGENT_V2 = 'true' @@ -39,4 +55,18 @@ describe('env runtime transport', () => { expect(datasetMap['data-enable-agent-v2']).toBe(true) expect(datasetMap['data-enable-agent-v-2']).toBeUndefined() }) + + it('should emit the RBAC runtime dataset attribute from getDatasetMap on the server', async () => { + process.env.NEXT_PUBLIC_RBAC_ENABLED = 'true' + + vi.doMock('../utils/client', () => ({ + isClient: false, + isServer: true, + })) + + const { getDatasetMap } = await import('../env') + const datasetMap = getDatasetMap() + + expect(datasetMap['data-rbac-enabled']).toBe(true) + }) }) diff --git a/web/__tests__/header/nav-flow.test.tsx b/web/__tests__/header/nav-flow.test.tsx deleted file mode 100644 index 33927f19da4..00000000000 --- a/web/__tests__/header/nav-flow.test.tsx +++ /dev/null @@ -1,168 +0,0 @@ -import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import * as React from 'react' -import { beforeEach, describe, expect, it, vi } from 'vitest' -import Nav from '@/app/components/header/nav' -import { AppModeEnum } from '@/types/app' - -const mockPush = vi.fn() -const mockSetAppDetail = vi.fn() -const mockOnCreate = vi.fn() -const mockOnLoadMore = vi.fn() - -let mockSelectedSegment = 'app' -let mockIsCurrentWorkspaceEditor = true - -vi.mock('react-i18next', () => ({ - useTranslation: () => ({ - t: (key: string) => key, - }), -})) - -vi.mock('@/next/navigation', () => ({ - useSelectedLayoutSegment: () => mockSelectedSegment, - useRouter: () => ({ - push: mockPush, - }), -})) - -vi.mock('@/next/link', () => ({ - default: ({ - href, - children, - }: { - href: string - children?: React.ReactNode - }) => {children}, -})) - -vi.mock('@/app/components/app/store', () => ({ - useStore: () => mockSetAppDetail, -})) - -vi.mock('@/context/app-context', () => ({ - useAppContext: () => ({ - isCurrentWorkspaceEditor: mockIsCurrentWorkspaceEditor, - workspacePermissionKeys: mockIsCurrentWorkspaceEditor ? ['app.create_and_management'] : [], - }), -})) - -const navigationItems = [ - { - id: 'app-1', - name: 'Alpha', - link: '/app/app-1/configuration', - icon_type: 'emoji' as const, - icon: '🤖', - icon_background: '#FFEAD5', - icon_url: null, - mode: AppModeEnum.CHAT, - }, - { - id: 'app-2', - name: 'Bravo', - link: '/app/app-2/workflow', - icon_type: 'emoji' as const, - icon: '⚙️', - icon_background: '#E0F2FE', - icon_url: null, - mode: AppModeEnum.WORKFLOW, - }, -] - -const curNav = { - id: 'app-1', - name: 'Alpha', - icon_type: 'emoji' as const, - icon: '🤖', - icon_background: '#FFEAD5', - icon_url: null, - mode: AppModeEnum.CHAT, -} - -const renderNav = (nav = curNav) => { - return render( -