Merge branch 'main' into feat/refine-snippet-siderbar

This commit is contained in:
JzoNg 2026-06-24 17:20:49 +08:00
commit 13c6119aa0
684 changed files with 4565 additions and 34283 deletions

View File

@ -1,71 +1,59 @@
---
name: how-to-write-component
description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around abstraction choices, props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling.
description: Use when writing, refactoring, or reviewing React/TypeScript components in Dify web, especially decisions about component ownership, props/types, URL/query state, Jotai state, async state, generated API contracts, queries/mutations, overlays, effects, navigation, performance, and empty states.
---
# How To Write A Component
Use this as the decision guide for React/TypeScript component structure. Existing code is reference material, not automatic precedent; when it conflicts with these rules, adapt the approach instead of reproducing the violation.
Use this as the component decision guide for Dify web. Existing code is reference material, not automatic precedent; if touched code violates these rules, adapt it and fix equivalent patterns in the same feature branch.
## First Decisions
| Question | Default | Promote or extract only when |
| --- | --- | --- |
| Where should code live? | Keep it local to the feature workflow, route, or owner. | Multiple verticals need the same stable primitive. |
| Who owns state, data, and handlers? | The lowest component that uses them. | A parent coordinates shared loading, errors, empty UI, selection, submission, navigation, or one consistent snapshot. |
| Should this become Jotai state? | Keep synchronous UI/form state in component or DOM state. | Siblings need one source of truth, the value drives atoms, or scoped workflow state must survive hidden/unmounted steps. |
| Should URL state enter Jotai? | Let Next.js route params and `nuqs` own URL state and updates. | Query atoms or shared derived atoms need a read-only bridge hydrated at the route/surface boundary. |
| Should this query/mutation become an atom? | Use TanStack Query hooks at the lowest owner. | It reads atom state, feeds derived atoms, or participates in shared Jotai workflow orchestration. |
| Should this be a helper/wrapper? | Prefer direct readable code at the use site. | The name captures a stable domain rule or the wrapper owns real behavior, validation, state, error handling, or semantics. |
| Is an Effect needed? | No. Derive during render or handle the user action in the event handler. | It synchronizes with an external system such as browser APIs, subscriptions, timers, analytics, or imperative DOM/non-React widgets. |
## Core Defaults
- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature.
- Prefer local code and purpose-named helpers over catch-all utility modules; do not group workflow-specific defaults, validation, payload shaping, or metadata merging in a generic utils file just because they share a DTO.
- Keep source/default selection, validation, and payload shaping close to the workflow that owns the behavior. Do not extract a shared helper just because two flows read the same DTO when their priority order, fallback behavior, or submit semantics differ.
- Prefer direct, readable conditionals at the use site for small branch-specific decisions, especially form source selection and request payload assembly. Extract only when the helper name captures a stable domain rule and removes repeated complexity without hiding flow-specific behavior.
- When fixing an invalid pattern, scan the touched feature or branch for equivalent patterns and fix them together.
- Follow Dify's CSS-first Tailwind v4 contract from `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md`. Prefer design-system tokens, utilities, and radius mappings over generic Tailwind guidance.
- Search before adding UI, hooks, helpers, query utilities, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
- Follow Dify's CSS-first Tailwind v4 contract from `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md`. Prefer design-system tokens, utilities, and radius mappings over generic Tailwind choices.
- Group feature code by workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
- Keep source/default selection, validation, dirty checks, and payload shaping close to the workflow that owns submit behavior. Do not hide flow-specific priority order, fallback behavior, or submit semantics in generic utilities.
- Prefer direct conditionals for small branch-specific decisions, especially form source selection and request payload assembly.
- Loading states for page sections, cards, lists, tables, forms, and drawers should be skeletons scoped to the content being loaded. Use spinners only for small inline busy indicators.
## Feature Workflow Layout
## Layout And Ownership
- State-heavy wizards, drawers, modals, and secondary workflows work best as a small feature surface with route/entry files, a single feature-local state file, and feature-local UI.
- Keep `ui/` shallow with owner files that map to the workflow's real composition boundaries and major visual regions.
- Owner files contain the section components, field components, skeletons, and one-off helper components that belong to their visual region.
- Folders represent groups of related files with a shared owner and a stable reason to change together.
- The entry file handles route integration, provider wiring, close behavior, and feature surface mounting. The composition owner handles high-level workflow branching, and the closest visual owner handles section branching.
- State-heavy wizards, drawers, modals, and secondary workflows can be a small feature surface: an entry file, one feature-local state file when Jotai is actually needed, and shallow `ui/` owners that match real visual regions.
- The entry file handles route integration, provider wiring, close behavior, and surface mounting. The composition owner handles high-level workflow branching. The closest visual owner handles section branching.
- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data; TanStack Query deduplicates and shares cache.
- Pass stable domain identity across boundaries. Do not forward derived presentation state when the receiver can derive it from its own data source.
- A component that owns a visual surface should also own data access, loading, empty, and error states for content rendered inside it unless a parent truly coordinates that state.
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in Query and API flow.
- Do not replace prop drilling with one large view-model hook threaded through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it.
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child, menu, or row own the action.
## Ownership
## Feature-Scoped Jotai
- Put local state, queries, mutations, handlers, and derived UI data in the lowest component that uses them. Extract a purpose-built owner component only when the logic has no natural home.
- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data. Do not hoist a query only because it is duplicated; TanStack Query handles deduplication and cache sharing.
- Hoist state, queries, or callbacks to a parent only when the parent consumes the data, coordinates shared loading/error/empty UI, needs one consistent snapshot, or owns a workflow spanning children.
- Pass stable domain identity across boundaries; avoid forwarding derived presentation state when the receiver can derive it from its own data source. A component that owns a visual surface should also own the data access, loading, empty, and error states for content rendered inside it unless a parent truly coordinates that state.
- Loading states for visual surfaces should use skeleton placeholders scoped to the content that is actually loading, with shape, density, and dimensions close to the final UI. Avoid generic loading text or centered spinners for page sections, cards, lists, tables, forms, and drawers; reserve spinners for small inline busy indicators such as an in-progress status icon.
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
- Do not replace prop drilling with one top-level hook that returns a large view model and then thread that object through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it, or use feature-scoped Jotai atoms for simple shared form/UI state when siblings need the same source of truth.
- When using feature-scoped Jotai state for a form, drawer, or other secondary surface, scope the store to that surface instance when stale cross-instance state is possible. Initialize stable config at the owning boundary, then let descendants read only the atoms or purpose-named hooks they actually need.
- For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. Do not create a query or mutation atom only because the surrounding feature uses Jotai. If the query or mutation does not read atom state, feed another atom, or participate in shared workflow orchestration, use `useQuery` or `useMutation` directly at the lowest owner.
- For repeated row/menu action surfaces that need reset, hydrate the stable identity at the surface entry and scope only the primitives that truly need per-instance reset, such as open flags, drafts, or selected local options.
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
- Default to uncontrolled form and DOM state. Add controlled props or atom-backed drafts only when live cross-component reactions, multi-step persistence, or external synchronization require them.
## Feature-Scoped Jotai State
- A module's feature-local state lives in one state file for Jotai-backed features: primitive atoms, shared query atoms, derived atoms, write-only action atoms, shared mutation atoms, submission orchestration, provider exports, and optional scope configuration.
- Keep synchronous UI state local when one component owns it, even inside Jotai-backed features. Dialog open flags, menu/popover visibility, confirmation visibility, form/input drafts, and selected local options usually belong in component state.
- Do not put simple form drafts in Jotai atoms. For edit/create forms whose fields are only read at submit time, use uncontrolled `@langgenius/dify-ui/form` and `@langgenius/dify-ui/field` controls with `defaultValue`, browser/form validation, and keyed remounts for query-backed initial values.
- Promote form state to Jotai only when another component must react to in-progress field changes, the draft must survive unmount/remount within the same scoped workflow, or multiple steps/surfaces share the same editable draft before submit.
- Keep submit-time normalization, dirty checks, and payload shaping beside the form submit handler. Do not create form atoms, field atoms, or derived can-save atoms only to mirror uncontrolled form values or disable a submit button.
- In Jotai-backed feature surfaces, never hand-roll async loading, error, or in-flight guards with `useState` or `useRef`. For async work that depends on atom state, feeds derived atoms, or participates in shared submission orchestration, model the work with `atomWithQuery` or `atomWithMutation`; write atoms should only update the inputs that drive those atoms. For component-owned remote work that does not participate in atom state, use TanStack Query hooks directly.
- Row-local async state should belong to the row owner. Use `useQuery` or `useMutation` directly for row actions that do not depend on atom state and are not consumed by other atoms. Use a per-instance query or mutation atom only when the row action participates in a Jotai-backed shared workflow or needs atom-scoped reset semantics.
- Promote UI state to an atom only when siblings need the same source of truth, the value drives a query or mutation atom, a parent workflow coordinates the state, or the state intentionally persists across hidden or unmounted descendants within a scoped surface.
- Reflect atom-backed surface-wide locks or invariants in every affected trigger. If only one row, menu, or dialog should be disabled, keep the pending or lock scope local to that row, menu, or dialog with the lowest-owner query/mutation hook unless it genuinely participates in shared atom state.
- Atom order in the state file follows the dependency graph: types/constants, editable primitives, query atoms, query-data derived atoms, readiness/business derived atoms, write actions, mutation atoms, submission orchestration, provider exports.
- Derived atom names read as business facts. Write atom names read as user or workflow commands.
- UI components read and write the exact atom they use with `useAtomValue` or `useSetAtom`. Repeated workflow semantics live in named derived atoms or write atoms.
- Non-query derived atoms return a narrow value with a clear domain name; avoid pass-through aliases or bundling unrelated UI facts. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract.
- Write-only atoms own synchronous state transitions that update multiple primitives, reset dependent state, or advance the workflow. Async work with loading, error, caching, retry, stale-result, or in-flight concerns should be modeled as query or mutation atoms, with write atoms only changing the inputs that drive them.
- Avoid feature hooks that aggregate form values, query results, derived state, and commands for sibling components. Prefer named derived atoms and write atoms so UI components read the exact shared fact or command they need.
- When a form library owns validation, keep submit orchestration in feature state when post-submit result or error state is shared by the surface. Avoid duplicating validation gates or request shaping in UI hooks.
- `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state only when they need atom inputs, provide data to derived atoms, or coordinate a shared Jotai-backed workflow.
- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query and mutation atoms keep shared cache behavior through the shared QueryClient.
- Do not put `atomWithQuery`, `atomWithInfiniteQuery`, `atomWithMutation`, or broad derived orchestration atoms in a `ScopeProvider` just to reset a surface. Scoped derived atoms implicitly scope their dependencies, which can duplicate query client access and break shared invalidation. Leave query/mutation atoms unscoped; let them read scoped primitive inputs.
- Scope providers should list resettable primitive atoms and explicit hydration tuples. If a derived atom must be scoped, confirm that every dependency it implicitly scopes is meant to be private to that surface.
- For scoped primitives that are always hydrated by a `ScopeProvider` tuple, prefer `atomWithLazy<T>(() => { throw new Error(...) })` when consumers should see a non-null type. This keeps missing provider hydration as a runtime invariant without leaking `T | undefined` or adding pass-through "required" derived atoms only for narrowing.
- Keep independent dialog lifecycles separate. Avoid a single discriminated "current action dialog" atom when edit, delete, and other dialogs have their own open state, loading guard, or reset behavior.
- Route-derived stable identities that do not need instance reset or scoped isolation can be hydrated at the route or layout boundary into a feature route atom. Use scoped atoms only when stale cross-instance state or per-surface reset semantics are needed.
- A Jotai-backed feature has one feature-local state file for shared primitive atoms, query atoms, derived atoms, write-only actions, mutation atoms, submission orchestration, provider exports, and optional scope configuration.
- Keep component-owned synchronous UI state local even inside Jotai features: dialog open flags, menus/popovers, confirmations, field drafts, and selected local options usually belong in component state.
- Use uncontrolled `@langgenius/dify-ui/form` and `@langgenius/dify-ui/field` controls for edit/create forms whose fields are read only at submit time. Initialize query-backed defaults with `defaultValue` and keyed remounts.
- Promote form state to atoms only when another component must react to in-progress values, a draft must survive unmount/remount in the scoped workflow, or multiple steps share the same editable draft before submit.
- Treat `useParams`, route args, and `nuqs` query state as framework-owned state. When atom logic needs those values, hydrate primitive atoms at the route or surface boundary, such as with `useHydrateAtoms(..., { dangerouslyForceHydrate: true })`; keep URL updates in the route/query-state APIs instead of write atoms.
- For async work tied to atom state, use `atomWithQuery` or `atomWithMutation`; write atoms should update only the inputs that drive those atoms. This applies to pure frontend async work as well as network requests, so do not hand-roll loading/error/in-flight state with `useState` or `useRef` for atom-orchestrated async behavior. For component-owned remote work, use `useQuery` or `useMutation` directly.
- Row-local async state belongs to the row owner unless it participates in a shared Jotai workflow or needs atom-scoped reset semantics.
- Leave query and mutation atoms unscoped so they keep shared QueryClient cache and invalidation behavior. Scope resettable primitives and explicit hydration tuples; scope a derived atom only when every dependency should be private to that surface.
- For scoped primitives that are always hydrated by `ScopeProvider`, prefer `atomWithLazy<T>(() => { throw new Error(...) })` when consumers should see a non-null type.
- Order state files by dependency graph: types/constants, primitives, query atoms, query-data derived atoms, business/readiness derived atoms, write actions, mutation atoms, submission orchestration, provider exports.
- Name derived atoms as business facts and write atoms as user or workflow commands. Components should read or write the exact atom they need with `useAtomValue` or `useSetAtom`.
- Menu/dialog `open` state usually stays local, but a scoped atom is acceptable when a composed menu plus secondary surface would otherwise pass confusing `open`/`onClose` props through unrelated layers. Scope that primitive with the surface instance so reset behavior stays local.
- Keep independent dialog lifecycles separate. Avoid one discriminated "current action dialog" atom when dialogs have separate open state, loading guards, or reset behavior.
## Components, Props, And Types
@ -74,86 +62,56 @@ Use this as the decision guide for React/TypeScript component structure. Existin
- Prefer named exports. Use default exports only where the framework requires them, such as Next.js route files.
- Type simple one-off props inline. Use a named `Props` type only when reused, exported, complex, or clearer.
- Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers and one-off UI extensions beside the component that needs them.
- Do not create type aliases that only rename another type. Use an alias only when it encodes a real UI concept, refinement, or reusable local contract.
- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially persistent IDs and route params. Normalize framework or route params at the boundary.
- Keep fallback and invariant checks at the lowest component that already handles that state; avoid defensive fallbacks that mask impossible states.
- Do not extract fallback helpers whose only behavior is hiding missing display data. The component that renders the surface owns the empty, disabled, hidden, or placeholder state.
- Do not create type aliases that only rename another type. Use aliases only for real UI concepts, refinements, or reusable local contracts.
- Name values by their domain role and backend API contract, especially persistent IDs and route params. Normalize framework or route params at the boundary.
- Put fallback and invariant checks in the lowest component that already handles that state. Do not extract helpers whose only behavior is hiding missing display data.
## Generated API Contracts
## Generated API And Nullable Data
- Treat generated contracts as authoritative at API, query, mutation, cache, and service boundaries. For enterprise APIs, use `packages/contracts/generated/enterprise/*`.
- Do not hand-write local request/response/reply/page/cache-data types that mirror generated DTOs. Import or infer the generated type.
- Do not widen generated fields or enums for compatibility. Normalize legacy input at the boundary, then return the generated field type.
- Do not repair generated or API-returned contract fields in components unless the API contract or product requirement says they need normalization. Treat enums, statuses, and presence flags as exact contract values.
- Use generated enum objects and union types directly in props, comparisons, status logic, and i18n keys. Do not add local enum constants or parallel frontend enum/status layers unless they model real product state not represented by the API. Presentation-only tone maps should be keyed by the generated enum.
- Normalize or coerce only at a real boundary, such as user-entered forms, search, URL/query params, file names, DOM IDs, or legacy adapters. Preserve user-entered values when whitespace or formatting can be meaningful.
- Do not coerce nullable or optional API strings to `''` in query, derived model, or payload-building code. Keep `undefined` or `null` until the final boundary that requires a string.
- Do not use `value || undefined` for mutation payload fields where an empty string means "clear this value". Trim or normalize at the form boundary, then preserve `''` when the API contract treats it as an intentional update.
- Local UI models are fine for presentation, form state, select options, or guarded required-field refinements. Name them as UI concepts, not generated DTO mirrors.
- Required-value refinements are allowed only after same-branch filtering or early return. Prefer nullable-tolerant props for render-only data.
- When a component needs a stricter shape than a generated DTO, refine once at the API/query-to-UI boundary into a purpose-named UI type instead of hiding missing fields with generic fallback or coercion helpers.
## Nullable API Data
- Prefer nullable-tolerant call boundaries. Pass API-returned types through for render-only rows, and let the component render fallback, disabled state, or nothing.
- Narrow only where a real value is required, such as mutation params, route hrefs, select values, or query input. Build that target model with `flatMap`, a local loop, or an early return so the required value is captured in the same branch.
- If design says a field is the display value, use that field. Only the final component should decide whether a nullable display value renders a placeholder, hides content, or disables an action.
- Do not wrap required arrays or fields in null-fallback helpers. Use empty collection fallbacks only for not-yet-loaded query data or genuinely nullable collections at the owning render boundary.
- Do not drop rows only to satisfy props or React keys; use a stable fallback key when possible.
- Use conditional spreads or explicit pushes for conditional array items instead of `undefined` placeholders followed by a narrowing filter.
- Avoid truthiness type guards, `filter(Boolean)`, `filter(item => item.id)`, and `!` after those filters.
- Use type guards only for meaningful domain or runtime validation, such as enum membership, object shape, or a reusable business invariant.
- Do not hand-write DTO mirrors, widen generated fields/enums, or add parallel frontend enum/status layers unless they model product state not represented by the API.
- Use generated enum objects and union types directly in props, comparisons, status logic, and i18n keys. Presentation-only tone maps should be keyed by generated enums.
- Normalize or coerce only at real boundaries: user-entered forms, search, URL/query params, file names, DOM IDs, or legacy adapters.
- Do not coerce nullable or optional API strings to `''` in query, derived model, or payload-building code. Keep `null` or `undefined` until the final boundary requiring a string.
- Do not use `value || undefined` for mutation fields where `''` means "clear this value". Trim or normalize at the form boundary, then preserve intentional empty strings.
- Prefer nullable-tolerant render props for API-returned rows. Narrow only where a real value is required, such as mutation params, route hrefs, select values, query input, or required React keys.
- Build required values in the same branch that proves them, using `flatMap`, a local loop, or an early return. Avoid truthiness guards, `filter(Boolean)`, `filter(item => item.id)`, and `!` after filters.
- Use conditional spreads or explicit pushes for conditional array items instead of `undefined` placeholders followed by narrowing filters.
- Empty collection fallbacks are for not-yet-loaded query data or genuinely nullable collections at the owning render boundary, not for hiding required API fields.
## Queries And Mutations
- Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape.
- Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`.
- Do not promote a query or mutation to an atom just because the feature already has a state file. Use `atomWithQuery` or `atomWithMutation` only when the query/mutation reads atom state, is consumed by another atom, or is part of shared workflow orchestration.
- In `atomWithQuery` and `atomWithInfiniteQuery`, return generated `queryOptions()` or `infiniteOptions()` directly. Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into that call instead of spreading generated options into a hand-built object.
- When prefetch and render consume the same server request, extract local query options or a query-options atom so `queryClient.prefetchQuery(...)` and `useQuery`/`atomWithQuery` share the exact generated query options.
- In `atomWithMutation`, return generated `mutationOptions()` directly when using generated clients. Put request shaping and submit orchestration in write atoms; do not rebuild mutation option objects just to pass through the generated mutation function.
- For custom query functions that do not come from generated clients, wrap the options object with TanStack `queryOptions(...)` so query atoms still return a query options contract.
- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it.
- Keep feature hooks for real orchestration, workflow state, or shared domain behavior.
- For TanStack cache data, use generated or query-derived types; do not create local wrappers for `getQueryData` or `getQueriesData`.
- For generated oRPC `queryOptions()` / `infiniteOptions()`, keep returning the generated options directly. When required input is missing, use a whole-input branch such as `input: condition ? validInput : skipToken` together with `enabled: Boolean(condition)` so no request runs and no fake payload is built.
- Do not put `skipToken` inside a nested placeholder payload, such as `{ params: { appInstanceId: skipToken } }`. Do not create hand-written "missing queryOptions" objects or coerce required IDs to `''`.
- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))` when the mutation is owned by one component, menu, dialog, or row and its pending/error state is not consumed by feature atoms. In Jotai-backed workflow orchestration, expose mutations from feature state with `atomWithMutation` so pending/error state stays attached to the mutation atom. For component-owned custom mutation functions, use `useMutation(mutationOptions(...))` at the owner.
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules.
- Component or atom mutation callbacks can handle local UI feedback such as toasts, closing dialogs, or navigation. They should not replace shared invalidation or add local cache patches for shared server state.
- For overlays that may open a heavier secondary surface, prefetch server data from the trigger/menu open event with `queryClient.prefetchQuery(queryOptions)` when the primitive exposes `onOpenChange`. Do not mount a hidden component or subscribe to a query only to warm the cache. Do not make an otherwise uncontrolled menu controlled only for prefetching.
- Keep `web/contract/*` as the API shape source of truth and follow the `{ params, query?, body? }` input shape.
- Consume generated queries with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`.
- Consume owner-local mutations with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))` when pending/error state is not consumed by feature atoms.
- In `atomWithQuery`, `atomWithInfiniteQuery`, and `atomWithMutation`, return generated `queryOptions()`, `infiniteOptions()`, or `mutationOptions()` directly. Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into the generated call instead of spreading options into a hand-built object.
- For generated oRPC options with missing required input, branch the whole input with `input: condition ? validInput : skipToken` and `enabled: Boolean(condition)`. Never place `skipToken` inside a nested placeholder payload or coerce required IDs to `''`.
- When prefetch and render use the same request, extract local query options or a query-options atom so `prefetchQuery` and `useQuery`/`atomWithQuery` share the exact options.
- For custom query or mutation functions, wrap options with TanStack `queryOptions(...)` or `mutationOptions(...)`.
- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename generated options. Keep feature hooks for real orchestration, workflow state, or shared domain behavior.
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`. Component or atom callbacks may handle local toasts, closing dialogs, and navigation, but should not replace shared invalidation or patch shared server state locally.
- For overlays that may open heavier secondary content, prefetch from the trigger/menu open event with `queryClient.prefetchQuery(queryOptions)` when `onOpenChange` is available. Do not mount hidden subscribers just to warm cache.
- Do not use deprecated `useInvalid` or `useReset`.
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`.
## Component Boundaries
## Boundaries And Overlays
- Use the first level below a page or tab to organize independent page sections when it adds real structure. This layer is layout/semantic first, not automatically the data owner.
- Treat component names, semantic roles, and user- or design-marked visual regions as boundary constraints. Do not expand a child component's responsibility just because its data is useful nearby; keep adjacent UI as a sibling owner or introduce a correctly named broader owner.
- Split deeper components by the data and state each layer actually needs. Each component should access only necessary data, and ownership should stay at the lowest consumer.
- Use the first level below a page or tab to organize independent page sections when it adds structure. This layer is layout/semantic first, not automatically the data owner.
- Treat component names, semantic roles, and user- or design-marked visual regions as boundary constraints. Keep adjacent UI as a sibling owner or introduce a correctly named broader owner.
- Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary.
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
- When a dialog, dropdown, or popover component already accepts controlled `open` state, mount the surface unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset for reset behavior instead of `{open && <Surface />}` wrappers.
- When opening a dialog from a menu item, keep the menu and dialog as sibling surfaces. Let the menu item command open the dialog through local state or scoped atoms, and mount the dialog outside the menu popup content. Avoid wrapping menu items with dialog triggers when the menu primitive already owns item activation and dismissal behavior.
- For dialogs and alert dialogs, keep the root component responsible for `open` wiring and put query/mutation hooks inside the content component when the work should only mount after the overlay opens. Do not put closed-surface remote work in the root just because the root owns the open atom.
- Prefer uncontrolled overlay roots when the library can own their open state. Use `onOpenChange` for side effects such as prefetching, and CSS/data selectors for visual open-state styling instead of adding controlled state only for observation.
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, children-as-pass-through composition, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook, forwards props, or passes trigger/content through to one child, move the logic into that child or make the wrapper own a real surface.
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component when hidden content would obscure the parent.
- Preserve composability by separating behavior ownership from placement ownership: an action can own trigger/open/menu content while the caller owns slots, offsets, and alignment.
- When a dialog, dropdown, or popover accepts controlled `open`, mount it unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset instead of `{open && <Surface />}` wrappers.
- When opening a dialog from a menu item, keep the menu and dialog as sibling surfaces. Let the menu command open the dialog, and mount the dialog outside menu popup content.
- For dialogs and alert dialogs, keep the root responsible for `open` wiring and put query/mutation hooks inside the content component when work should mount only after the overlay opens.
- Prefer uncontrolled overlay roots when the library can own open state. Use `onOpenChange` for side effects and CSS/data selectors for open-state styling.
- Avoid wrapper DOM unless it provides layout, semantics, accessibility, state ownership, or library integration. Avoid shallow wrappers, hook-to-props adapters, layout-only render props, children pass-through wrappers, and prop renaming unless they add real behavior or a real boundary.
## You Might Not Need An Effect
- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration.
- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive.
- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known.
- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render.
- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary.
- For forms initialized from query data, prefer keyed remounts or surface-entry hydration of form/field atoms over an Effect that copies query data into form state.
- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components.
- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow.
## Navigation And Performance
## Effects, Navigation, And Performance
- Use Effects only to synchronize with external systems. Do not use Effects to transform props/state for rendering, handle user actions, copy state, reset state from props, or fetch data.
- For forms initialized from query data, prefer keyed remounts or surface-entry atom hydration over Effects that copy query data into form state.
- Prefer framework data APIs or TanStack Query for data fetching.
- Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission.
- Before reaching for `memo`, first try moving changing state down to the smallest component that actually uses it so unrelated sibling trees stay untouched.
- If changing state must wrap other content, lift the unchanged content up and pass it as `children` so the stateful wrapper can update without React visiting that subtree.
- Before using `memo`, move changing state down to the smallest component that uses it. If state must wrap stable content, lift the stable content up and pass it as `children`.
- Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason.

View File

@ -105,6 +105,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: vp run knip
- name: Web dead code check production
if: steps.changed-files.outputs.any_changed == 'true'
run: vp run knip:production
ts-common-style:
name: TS Common
runs-on: depot-ubuntu-24.04

View File

@ -36,6 +36,9 @@ FILES_ACCESS_TIMEOUT=300
# Collaboration mode toggle
ENABLE_COLLABORATION_MODE=true
# Learn app feature toggle
ENABLE_LEARN_APP=true
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60

View File

@ -133,8 +133,9 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
password=new_password,
language=language,
create_workspace_required=False,
session=db.session,
)
TenantService.create_owner_tenant_if_not_exist(account, name)
TenantService.create_owner_tenant_if_not_exist(account, name, session=db.session)
click.echo(
click.style(

View File

@ -1073,6 +1073,12 @@ class MailConfig(BaseSettings):
default=None,
)
class HomepageConfig(BaseSettings):
"""
Configuration for homepage feature toggles exposed through system features.
"""
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
@ -1083,6 +1089,11 @@ class MailConfig(BaseSettings):
default=False,
)
ENABLE_LEARN_APP: bool = Field(
description="Enable Learn App",
default=True,
)
class RagEtlConfig(BaseSettings):
"""
@ -1489,6 +1500,7 @@ class FeatureConfig(
EndpointConfig,
FileAccessConfig,
FileUploadConfig,
HomepageConfig,
HttpConfig,
InnerAPIConfig,
IndexingConfig,

View File

@ -19,6 +19,7 @@ from controllers.console.wraps import (
rbac_permission_required,
setup_required,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
Annotation,
@ -388,7 +389,9 @@ class AnnotationUpdateDeleteApi(Resource):
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id))
annotation = AppAnnotationService.update_app_annotation_directly(
update_args, str(app_id), str(annotation_id), db.session
)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@ -398,7 +401,7 @@ class AnnotationUpdateDeleteApi(Resource):
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT)
@console_ns.response(204, "Annotation deleted successfully")
def delete(self, app_id: UUID, annotation_id: UUID):
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id), db.session)
return "", 204

View File

@ -17,6 +17,7 @@ from controllers.console.wraps import (
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.member_fields import AccountWithRole
from libs.helper import build_avatar_url, dump_response, to_timestamp
@ -489,7 +490,7 @@ class WorkflowCommentMentionUsersApi(Resource):
current_tenant = current_user.current_tenant # need the tenant object here
if current_tenant is None:
raise ValueError("current tenant is required")
members = TenantService.get_tenant_members(current_tenant)
members = TenantService.get_tenant_members(current_tenant, session=db.session)
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = WorkflowCommentMentionUsersPayload(users=users)
return response.model_dump(mode="json"), 200

View File

@ -89,7 +89,9 @@ class ActivateCheckApi(Resource):
workspaceId = args.workspace_id
token = args.token
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
invitation = RegisterService.get_invitation_with_case_fallback(
workspaceId, args.email, token, session=db.session
)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
@ -137,7 +139,9 @@ class ActivateApi(Resource):
args = ActivatePayload.model_validate(console_ns.payload)
normalized_request_email = args.email.lower() if args.email else None
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
invitation = RegisterService.get_invitation_with_case_fallback(
args.workspace_id, args.email, args.token, session=db.session
)
if invitation is None:
raise AlreadyActivateError()
@ -184,6 +188,6 @@ class ActivateApi(Resource):
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
TenantService.switch_tenant(account, tenant.id)
TenantService.switch_tenant(account, tenant.id, session=db.session)
return {"result": "success"}

View File

@ -187,7 +187,7 @@ class EmailRegisterResetApi(Resource):
timezone=args.timezone,
language=args.language,
)
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
@ -206,6 +206,7 @@ class EmailRegisterResetApi(Resource):
password=password,
interface_language=get_valid_language(language),
timezone=timezone,
session=db.session,
)
except AccountRegisterError:
raise AccountInFreezeError()

View File

@ -198,10 +198,10 @@ class ForgotPasswordResetApi(Resource):
# Create workspace if needed
if (
not TenantService.get_join_tenants(account)
not TenantService.get_join_tenants(account, session=db.session)
and FeatureService.get_system_features().is_allow_create_workspace
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)

View File

@ -119,7 +119,9 @@ class LoginApi(Resource):
invite_token = args.invite_token
invitation_data: InvitationDetailDict | None = None
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
invitation_data = RegisterService.get_invitation_with_case_fallback(
None, request_email, invite_token, session=db.session
)
if invitation_data is None:
invite_token = None
@ -145,7 +147,7 @@ class LoginApi(Resource):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
tenants = TenantService.get_join_tenants(account, session=db.session)
if len(tenants) == 0:
system_features = FeatureService.get_system_features()
@ -157,7 +159,7 @@ class LoginApi(Resource):
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
@ -291,7 +293,7 @@ class EmailCodeLoginApi(Resource):
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
tenants = TenantService.get_join_tenants(account, session=db.session)
if not tenants:
workspaces = FeatureService.get_system_features().license.workspaces
if not workspaces.is_available():
@ -299,7 +301,7 @@ class EmailCodeLoginApi(Resource):
if not FeatureService.get_system_features().is_allow_create_workspace:
raise NotAllowedCreateWorkspace()
else:
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
TenantService.create_tenant_member(new_tenant, account, db.session, role="owner")
account.current_tenant = new_tenant
tenant_was_created.send(new_tenant)
@ -311,6 +313,7 @@ class EmailCodeLoginApi(Resource):
name=user_email,
interface_language=get_valid_language(language),
timezone=args.timezone,
session=db.session,
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
@ -319,7 +322,7 @@ class EmailCodeLoginApi(Resource):
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
token_pair = AccountService.login(account, session=db.session, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
@ -343,7 +346,7 @@ class RefreshTokenApi(Resource):
return {"result": "fail", "message": "No refresh token provided"}, 401
try:
new_token_pair = AccountService.refresh_token(refresh_token)
new_token_pair = AccountService.refresh_token(refresh_token, session=db.session)
# Create response with new cookies
response = make_response({"result": "success"})
@ -358,22 +361,22 @@ class RefreshTokenApi(Resource):
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
account = AccountService.get_user_through_email(email, session=db.session)
if account or email == email.lower():
return account
return AccountService.get_user_through_email(email.lower())
return AccountService.get_user_through_email(email.lower(), session=db.session)
def _authenticate_account_with_case_fallback(
original_email: str, normalized_email: str, password: str, invite_token: str | None
):
try:
return AccountService.authenticate(original_email, password, invite_token)
return AccountService.authenticate(original_email, password, invite_token, session=db.session)
except services.errors.account.AccountPasswordError:
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)
return AccountService.authenticate(normalized_email, password, invite_token, session=db.session)
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:

View File

@ -195,7 +195,7 @@ class OAuthCallback(Resource):
db.session.commit()
try:
TenantService.create_owner_tenant_if_not_exist(account)
TenantService.create_owner_tenant_if_not_exist(account, session=db.session)
except Unauthorized:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
@ -206,6 +206,7 @@ class OAuthCallback(Resource):
token_pair = AccountService.login(
account=account,
session=db.session,
ip_address=extract_remote_ip(request),
)
@ -240,12 +241,12 @@ def _generate_account(
oauth_new_user = False
if account:
tenants = TenantService.get_join_tenants(account)
tenants = TenantService.get_join_tenants(account, session=db.session)
if not tenants:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise WorkSpaceNotAllowedCreateError()
else:
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
TenantService.create_tenant_member(new_tenant, account, db.session, role="owner")
account.current_tenant = new_tenant
tenant_was_created.send(new_tenant)
@ -272,9 +273,10 @@ def _generate_account(
provider=provider,
language=interface_language,
timezone=timezone,
session=db.session,
)
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)
AccountService.link_account_integrate(provider, user_info.id, account, session=db.session)
return account, oauth_new_user

View File

@ -175,7 +175,7 @@ class InstalledAppsListApi(Resource):
if current_user.current_tenant is None:
raise ValueError("current_user.current_tenant must not be None")
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant, session=db.session)
installed_app_list: list[dict[str, Any]] = []
for installed_app, app_model in installed_apps:
installed_app_list.append(

View File

@ -50,7 +50,7 @@ def get_init_status() -> InitStatusResponse:
@only_edition_self_hosted
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
"""Validate initialization password."""
tenant_count = TenantService.get_tenant_count()
tenant_count = TenantService.get_tenant_count(session=db.session)
if tenant_count > 0:
raise AlreadySetupError()

View File

@ -79,7 +79,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
if get_setup_status():
raise AlreadySetupError()
tenant_count = TenantService.get_tenant_count()
tenant_count = TenantService.get_tenant_count(session=db.session)
if tenant_count > 0:
raise AlreadySetupError()
@ -94,6 +94,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
password=payload.password,
ip_address=extract_remote_ip(request),
language=payload.language,
session=db.session,
)
return SetupResponse(result="success")

View File

@ -4,6 +4,7 @@ from typing import cast
from flask import Request as FlaskRequest
from extensions.ext_database import db
from extensions.ext_socketio import sio
from libs.passport import PassportService
from libs.token import extract_access_token
@ -43,7 +44,7 @@ def socket_connect(sid, environ, auth):
return False
with sio.app.app_context():
user = AccountService.load_logged_in_account(account_id=user_id)
user = AccountService.load_logged_in_account(account_id=user_id, session=db.session)
if not user:
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
return False

View File

@ -328,7 +328,7 @@ class AccountNameApi(Resource):
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
updated_account = AccountService.update_account(current_user, session=db.session, name=args.name)
return _serialize_account(updated_account)
@ -374,7 +374,7 @@ class AccountAvatarApi(Resource):
payload = console_ns.payload or {}
args = AccountAvatarPayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
updated_account = AccountService.update_account(current_user, session=db.session, avatar=args.avatar)
return _serialize_account(updated_account)
@ -391,7 +391,9 @@ class AccountInterfaceLanguageApi(Resource):
payload = console_ns.payload or {}
args = AccountInterfaceLanguagePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
updated_account = AccountService.update_account(
current_user, session=db.session, interface_language=args.interface_language
)
return _serialize_account(updated_account)
@ -408,7 +410,9 @@ class AccountInterfaceThemeApi(Resource):
payload = console_ns.payload or {}
args = AccountInterfaceThemePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
updated_account = AccountService.update_account(
current_user, session=db.session, interface_theme=args.interface_theme
)
return _serialize_account(updated_account)
@ -425,7 +429,7 @@ class AccountTimezoneApi(Resource):
payload = console_ns.payload or {}
args = AccountTimezonePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
updated_account = AccountService.update_account(current_user, session=db.session, timezone=args.timezone)
return _serialize_account(updated_account)
@ -443,7 +447,8 @@ class AccountPasswordApi(Resource):
args = AccountPasswordPayload.model_validate(payload)
try:
AccountService.update_account_password(current_user, args.password, args.new_password)
assert args.password is not None
AccountService.update_account_password(current_user, args.password, args.new_password, session=db.session)
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
@ -731,7 +736,7 @@ class ChangeEmailResetApi(Resource):
if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(normalized_new_email):
if not AccountService.check_email_unique(normalized_new_email, session=db.session):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@ -755,7 +760,9 @@ class ChangeEmailResetApi(Resource):
# legitimately verified token.
AccountService.revoke_change_email_token(args.token)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
updated_account = AccountService.update_account_email(
current_user, email=normalized_new_email, session=db.session
)
AccountService.send_change_email_completed_notify_email(
email=normalized_new_email,
@ -775,6 +782,6 @@ class CheckEmailUnique(Resource):
normalized_email = args.email.lower()
if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(normalized_email):
if not AccountService.check_email_unique(normalized_email, session=db.session):
raise EmailAlreadyInUseError()
return {"result": "success"}

View File

@ -186,7 +186,7 @@ class MemberListApi(Resource):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
members = TenantService.get_tenant_members(current_user.current_tenant, session=db.session)
if dify_config.RBAC_ENABLED:
member_ids = [member.id for member in members]
member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get(
@ -273,6 +273,7 @@ class MemberInviteEmailApi(Resource):
language=interface_language,
role=invitee_role,
inviter=inviter,
session=db.session,
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
@ -317,7 +318,9 @@ class MemberCancelInviteApi(Resource):
abort(404)
else:
try:
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
TenantService.remove_member_from_tenant(
current_user.current_tenant, member, current_user, session=db.session
)
except services.errors.account.CannotOperateSelfError as e:
return {"code": "cannot-operate-self", "message": str(e)}, 400
except services.errors.account.NoPermissionError as e:
@ -360,7 +363,9 @@ class MemberUpdateRoleApi(Resource):
try:
assert member is not None, "Member not found"
TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user)
TenantService.update_member_role(
current_user.current_tenant, member, new_role, current_user, session=db.session
)
except services.errors.account.CannotOperateSelfError as e:
return {"code": "cannot-operate-self", "message": str(e)}, 400
except services.errors.account.NoPermissionError as e:
@ -387,7 +392,7 @@ class DatasetOperatorMemberListApi(Resource):
def get(self, current_user: Account):
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
members = TenantService.get_dataset_operator_members(current_user.current_tenant, session=db.session)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
@ -413,7 +418,7 @@ class SendOwnerTransferEmailApi(Resource):
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
raise NotOwnerError()
if args.language is not None and args.language == "zh-Hans":
@ -448,7 +453,7 @@ class OwnerTransferCheckApi(Resource):
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
raise NotOwnerError()
user_email = current_user.email
@ -494,7 +499,7 @@ class OwnerTransfer(Resource):
# check if the current user is the owner of the workspace
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_owner(current_user, current_user.current_tenant):
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
raise NotOwnerError()
if current_user.id == str(member_id):
@ -516,12 +521,14 @@ class OwnerTransfer(Resource):
if not current_user.current_tenant:
raise ValueError("No current tenant")
if not TenantService.is_member(member, current_user.current_tenant):
if not TenantService.is_member(member, current_user.current_tenant, session=db.session):
raise MemberNotInTenantError()
try:
assert member is not None, "Member not found"
TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user)
TenantService.update_member_role(
current_user.current_tenant, member, "owner", current_user, session=db.session
)
AccountService.send_new_owner_transfer_notify_email(
account=member,

View File

@ -325,10 +325,10 @@ class TenantApi(Resource):
raise ValueError("No current tenant")
if tenant.status == TenantStatus.ARCHIVE:
tenants = TenantService.get_join_tenants(current_user)
tenants = TenantService.get_join_tenants(current_user, session=db.session)
# if there is any tenant, switch to the first one
if len(tenants) > 0:
TenantService.switch_tenant(current_user, tenants[0].id)
TenantService.switch_tenant(current_user, tenants[0].id, session=db.session)
tenant = tenants[0]
# else, raise Unauthorized
else:
@ -351,7 +351,7 @@ class SwitchWorkspaceApi(Resource):
# check if tenant_id is valid, 403 if not
try:
TenantService.switch_tenant(current_user, args.tenant_id)
TenantService.switch_tenant(current_user, args.tenant_id, session=db.session)
except Exception:
raise AccountNotLinkTenantError("Account not link tenant")

View File

@ -47,7 +47,7 @@ class EnterpriseWorkspace(Resource):
if account is None:
return {"message": "owner account not found."}, 404
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session)
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
tenant_was_created.send(tenant)
@ -84,7 +84,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
def post(self):
args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session)
tenant_was_created.send(tenant)

View File

@ -128,7 +128,7 @@ class WorkspaceSwitchApi(Resource):
account = _load_account(auth_data.account_id)
try:
TenantService.switch_tenant(account, workspace_id)
TenantService.switch_tenant(account, workspace_id, session=db.session)
except AccountNotLinkTenantError:
raise NotFound("workspace not found")
@ -152,7 +152,7 @@ class WorkspaceMembersApi(Resource):
@accepts(query=MemberListQuery)
def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery):
tenant = _load_tenant(workspace_id)
members = TenantService.get_tenant_members(tenant)
members = TenantService.get_tenant_members(tenant, session=db.session)
total = len(members)
start = (query.page - 1) * query.limit
page_items = members[start : start + query.limit]
@ -184,6 +184,7 @@ class WorkspaceMembersApi(Resource):
language=None,
role=body.role,
inviter=inviter,
session=db.session,
)
except AccountAlreadyInTenantError as exc:
raise BadRequest(str(exc))
@ -232,7 +233,7 @@ class WorkspaceMemberApi(Resource):
raise NotFound("member not found")
try:
TenantService.remove_member_from_tenant(tenant, member, operator)
TenantService.remove_member_from_tenant(tenant, member, operator, session=db.session)
except CannotOperateSelfError as exc:
raise BadRequest(str(exc))
except NoPermissionError as exc:
@ -266,7 +267,7 @@ class WorkspaceMemberRoleApi(Resource):
raise NotFound("member not found")
try:
TenantService.update_member_role(tenant, member, body.role, operator)
TenantService.update_member_role(tenant, member, body.role, operator, session=db.session)
except CannotOperateSelfError as exc:
raise BadRequest(str(exc))
except NoPermissionError as exc:

View File

@ -10,6 +10,7 @@ from controllers.common.schema import query_params_from_model, register_response
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.annotation_fields import Annotation, AnnotationList
from fields.base import ResponseModel
@ -281,7 +282,9 @@ class AnnotationUpdateDeleteApi(Resource):
"""Update an existing annotation."""
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
annotation = AppAnnotationService.update_app_annotation_directly(
update_args, app_model.id, str(annotation_id), db.session
)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@ -310,5 +313,5 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def delete(self, app_model: App, annotation_id: UUID):
"""Delete an annotation."""
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id), db.session)
return "", 204

View File

@ -31,6 +31,7 @@ from core.app.entities.queue_entities import (
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueReasoningChunkEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
@ -47,6 +48,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
ReasoningChunkStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
@ -571,6 +573,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
def _handle_reasoning_chunk_event(
self, event: QueueReasoningChunkEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle reasoning chunk events."""
# is_final with empty reasoning is still forwarded as the "thinking finished" signal
if not event.reasoning and not event.is_final:
return
yield ReasoningChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=ReasoningChunkStreamResponse.Data(
reasoning=event.reasoning,
node_id=event.from_node_id,
is_final=event.is_final,
),
)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
@ -600,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueuePingEvent: self._handle_ping_event,
QueueErrorEvent: self._handle_error_event,
QueueTextChunkEvent: self._handle_text_chunk_event,
QueueReasoningChunkEvent: self._handle_reasoning_chunk_event,
# Workflow events
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,

View File

@ -743,7 +743,8 @@ class ReasoningChunkStreamResponse(StreamResponse):
Data entity
"""
message_id: str
# chat apps set this; workflow runs have no message
message_id: str | None = None
reasoning: str
node_id: str | None = None
is_final: bool = False

View File

@ -226,7 +226,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_multimodal_documents.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
account = AccountService.load_user(document.created_by, db.session)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)

View File

@ -291,7 +291,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
attachments.append(file_document)
doc.attachments = attachments
else:
account = AccountService.load_user(document.created_by)
account = AccountService.load_user(document.created_by, db.session)
if not account:
raise ValueError("Invalid account")
doc.attachments = self._get_content_files(doc, current_user=account)

View File

@ -84,7 +84,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
if not user_id:
raise Unauthorized("Invalid Authorization token.")
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, session=db.session)
return logged_in_account
elif request.blueprint == "openapi":
# Account-branch device-flow approval routes (approve / deny /
@ -103,7 +103,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
source = decoded.get("token_source")
if source or not user_id:
return None
return AccountService.load_logged_in_account(account_id=user_id)
return AccountService.load_logged_in_account(account_id=user_id, session=db.session)
elif request.blueprint == "web":
app_code = request.headers.get(HEADER_NAME_APP_CODE)
webapp_token = extract_webapp_passport(app_code, request) if app_code else None

View File

@ -16,7 +16,7 @@ def valid_password(password):
raise ValueError("Password must contain letters and numbers, and the length must be at least 8 characters.")
def hash_password(password_str, salt_byte):
def hash_password(password_str: str, salt_byte: bytes):
dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
return binascii.hexlify(dk)

View File

@ -0,0 +1,26 @@
"""add cloud only flag to recommended apps
Revision ID: d9e8f7a6b5c4
Revises: c8f4a6b2d3e1
Create Date: 2026-06-23 18:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "d9e8f7a6b5c4"
down_revision = "c8f4a6b2d3e1"
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table("recommended_apps", schema=None) as batch_op:
batch_op.add_column(sa.Column("is_cloud_only", sa.Boolean(), server_default=sa.text("false"), nullable=False))
def downgrade():
with op.batch_alter_table("recommended_apps", schema=None) as batch_op:
batch_op.drop_column("is_cloud_only")

View File

@ -925,6 +925,9 @@ class RecommendedApp(TypeBase):
is_learn_dify: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false"), default=False
)
is_cloud_only: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false"), default=False
)
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
language: Mapped[str] = mapped_column(
String(255),

View File

@ -19605,6 +19605,7 @@ Model class for provider system configuration response.
| enable_email_code_login | boolean | | Yes |
| enable_email_password_login | boolean, <br>**Default:** true | | Yes |
| enable_explore_banner | boolean | | Yes |
| enable_learn_app | boolean, <br>**Default:** true | | Yes |
| enable_marketplace | boolean | | Yes |
| enable_social_oauth_login | boolean | | Yes |
| enable_trial_app | boolean | | Yes |

View File

@ -1603,6 +1603,7 @@ Default configuration for form inputs.
| enable_email_code_login | boolean | | Yes |
| enable_email_password_login | boolean, <br>**Default:** true | | Yes |
| enable_explore_banner | boolean | | Yes |
| enable_learn_app | boolean, <br>**Default:** true | | Yes |
| enable_marketplace | boolean | | Yes |
| enable_social_oauth_login | boolean | | Yes |
| enable_trial_app | boolean | | Yes |

View File

@ -2,7 +2,7 @@
name = "dify-trace-langsmith"
version = "0.0.1"
dependencies = [
"langsmith==0.8.5",
"langsmith==0.8.18",
]
description = "Dify ops tracing provider (LangSmith)."

View File

@ -103,7 +103,11 @@ dify-trace-weave = { workspace = true }
[tool.uv]
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
package = false
override-dependencies = ["litellm>=1.83.10,<2.0.0", "pyarrow>=23.0.1,<24.0.0"]
override-dependencies = [
"litellm>=1.83.10,<2.0.0",
"pyarrow>=23.0.1,<24.0.0",
"cryptography>=49.0.0,<50.0.0",
]
[dependency-groups]

View File

@ -1,3 +1,10 @@
"""Account, workspace, and invitation services.
Database access in this module is caller-scoped: methods that read or mutate ORM state accept an explicit
``session`` so controllers, tasks, and tests can control transaction lifetime and avoid hidden Flask-scoped session
usage inside service logic.
"""
import base64
import json
import logging
@ -235,7 +242,7 @@ class AccountService:
)
@staticmethod
def _refresh_account_last_active(account: Account) -> None:
def _refresh_account_last_active(account: Account, session: scoped_session | Session) -> None:
now = naive_utc_now()
refresh_before = now - ACCOUNT_LAST_ACTIVE_REFRESH_INTERVAL
@ -245,12 +252,12 @@ class AccountService:
if not AccountService._should_refresh_account_last_active(account.id):
return
db.session.execute(
session.execute(
update(Account)
.where(Account.id == account.id, Account.last_active_at < refresh_before)
.values(last_active_at=now, updated_at=func.current_timestamp())
)
db.session.commit()
session.commit()
@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str):
@ -295,20 +302,20 @@ class AccountService:
side-effects (current-tenant assignment, commit) are unwanted.
``session`` is injected by the caller so this service stays free
of the Flask-scoped ``db.session`` import.
of a Flask-scoped session import.
"""
return session.get(Account, account_id)
@staticmethod
def load_user(user_id: str) -> None | Account:
account = db.session.get(Account, user_id)
def load_user(user_id: str, session: scoped_session | Session) -> None | Account:
account = session.get(Account, user_id)
if not account:
return None
if account.status == AccountStatus.BANNED:
raise Unauthorized("Account is banned.")
current_tenant = db.session.scalar(
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
.limit(1)
@ -316,7 +323,7 @@ class AccountService:
if current_tenant:
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = db.session.scalar(
available_ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
@ -328,13 +335,13 @@ class AccountService:
account.set_tenant_id(available_ta.tenant_id)
available_ta.current = True
available_ta.last_opened_at = naive_utc_now()
db.session.commit()
session.commit()
AccountService._refresh_account_last_active(account)
AccountService._refresh_account_last_active(account, session)
# NOTE: make sure account is accessible outside of a db session
# This ensures that it will work correctly after upgrading to Flask version 3.1.2
db.session.refresh(account)
db.session.close()
session.refresh(account)
session.close()
return account
@staticmethod
@ -352,10 +359,12 @@ class AccountService:
return token
@staticmethod
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
def authenticate(
email: str, password: str, invite_token: str | None = None, *, session: scoped_session | Session
) -> Account:
"""authenticate account with email and password"""
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
account = session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
raise AccountPasswordError("Invalid email or password.")
@ -378,12 +387,14 @@ class AccountService:
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()
session.commit()
return account
@staticmethod
def update_account_password(account, password, new_password):
def update_account_password(
account: Account, password: str, new_password: str, *, session: scoped_session | Session
):
"""update account password"""
if account.password and not compare_password(password, account.password, account.password_salt):
raise CurrentPasswordIncorrectError("Current password is incorrect.")
@ -400,8 +411,8 @@ class AccountService:
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.add(account)
db.session.commit()
session.add(account)
session.commit()
return account
@staticmethod
@ -413,6 +424,8 @@ class AccountService:
interface_theme: str = "light",
is_setup: bool | None = False,
timezone: str | None = None,
*,
session: scoped_session | Session,
) -> Account:
"""Create an account, preferring explicit user timezone over language-derived defaults."""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
@ -458,13 +471,19 @@ class AccountService:
timezone=resolved_timezone,
)
db.session.add(account)
db.session.commit()
session.add(account)
session.commit()
return account
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None
email: str,
name: str,
interface_language: str,
password: str | None = None,
timezone: str | None = None,
*,
session: scoped_session | Session,
) -> Account:
"""Create an account and owner workspace."""
account = AccountService.create_account(
@ -473,10 +492,11 @@ class AccountService:
interface_language=interface_language,
password=password,
timezone=timezone,
session=session,
)
try:
TenantService.create_owner_tenant_if_not_exist(account=account)
TenantService.create_owner_tenant_if_not_exist(account=account, session=session)
except Exception:
# Enterprise-only side-effect should run independently from personal workspace creation.
_try_join_enterprise_default_workspace(str(account.id))
@ -536,11 +556,11 @@ class AccountService:
delete_account_task.delay(account.id)
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account):
def link_account_integrate(provider: str, open_id: str, account: Account, *, session: scoped_session | Session):
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: AccountIntegrate | None = db.session.scalar(
account_integrate: AccountIntegrate | None = session.scalar(
select(AccountIntegrate)
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
.limit(1)
@ -556,62 +576,62 @@ class AccountService:
account_integrate = AccountIntegrate(
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
)
db.session.add(account_integrate)
session.add(account_integrate)
db.session.commit()
session.commit()
logger.info("Account %s linked %s account %s.", account.id, provider, open_id)
except Exception as e:
logger.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id)
raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod
def close_account(account: Account):
def close_account(account: Account, *, session: scoped_session | Session):
"""Close account"""
account.status = AccountStatus.CLOSED
db.session.commit()
session.commit()
@staticmethod
def update_account(account, **kwargs):
def update_account(account: Account, *, session: scoped_session | Session, **kwargs):
"""Update account fields"""
account = db.session.merge(account)
account = session.merge(account)
for field, value in kwargs.items():
if hasattr(account, field):
setattr(account, field, value)
else:
raise AttributeError(f"Invalid field: {field}")
db.session.commit()
session.commit()
return account
@staticmethod
def update_account_email(account: Account, email: str) -> Account:
def update_account_email(account: Account, email: str, session: scoped_session | Session) -> Account:
"""Update account email"""
account.email = email
account_integrate = db.session.scalar(
account_integrate = session.scalar(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
)
if account_integrate:
db.session.delete(account_integrate)
db.session.add(account)
db.session.commit()
session.delete(account_integrate)
session.add(account)
session.commit()
return account
@staticmethod
def update_login_info(account: Account, *, ip_address: str):
def update_login_info(account: Account, session: scoped_session | Session, *, ip_address: str):
"""Update last login time and ip"""
account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
db.session.add(account)
db.session.commit()
session.add(account)
session.commit()
@staticmethod
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
def login(account: Account, *, session: scoped_session | Session, ip_address: str | None = None) -> TokenPair:
if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address)
AccountService.update_login_info(account=account, session=session, ip_address=ip_address)
if account.status == AccountStatus.PENDING:
account.status = AccountStatus.ACTIVE
db.session.commit()
session.commit()
access_token = AccountService.get_account_jwt_token(account=account)
refresh_token = _generate_refresh_token()
@ -628,13 +648,13 @@ class AccountService:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
@staticmethod
def refresh_token(refresh_token: str) -> TokenPair:
def refresh_token(refresh_token: str, *, session: scoped_session | Session) -> TokenPair:
# Verify the refresh token
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
if not account_id:
raise ValueError("Invalid refresh token")
account = AccountService.load_user(account_id.decode("utf-8"))
account = AccountService.load_user(account_id.decode("utf-8"), session)
if not account:
raise ValueError("Invalid account")
@ -649,8 +669,8 @@ class AccountService:
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token)
@staticmethod
def load_logged_in_account(*, account_id: str):
return AccountService.load_user(account_id)
def load_logged_in_account(*, account_id: str, session: scoped_session | Session):
return AccountService.load_user(account_id, session)
@classmethod
def send_reset_password_email(
@ -1002,7 +1022,7 @@ class AccountService:
TokenManager.revoke_token(token, "email_code_login")
@classmethod
def get_user_through_email(cls, email: str):
def get_user_through_email(cls, email: str, *, session: scoped_session | Session):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
raise AccountRegisterError(
description=(
@ -1011,7 +1031,7 @@ class AccountService:
)
)
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
account = session.scalar(select(Account).where(Account.email == email).limit(1))
if not account:
return None
@ -1210,13 +1230,19 @@ class AccountService:
return False
@staticmethod
def check_email_unique(email: str) -> bool:
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
def check_email_unique(email: str, *, session: scoped_session | Session) -> bool:
return session.scalar(select(Account).where(Account.email == email).limit(1)) is None
class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
def create_tenant(
name: str,
is_setup: bool | None = False,
is_from_dashboard: bool | None = False,
*,
session: scoped_session | Session,
) -> Tenant:
"""Create tenant"""
if (
not FeatureService.get_system_features().is_allow_create_workspace
@ -1228,8 +1254,8 @@ class TenantService:
raise NotAllowedCreateWorkspace()
tenant = Tenant(name=name)
db.session.add(tenant)
db.session.commit()
session.add(tenant)
session.commit()
for category in TenantPluginAutoUpgradeStrategy.PluginCategory:
plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy(
@ -1241,11 +1267,11 @@ class TenantService:
exclude_plugins=[],
include_plugins=[],
)
db.session.add(plugin_upgrade_strategy)
db.session.commit()
session.add(plugin_upgrade_strategy)
session.commit()
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
session.commit()
from services.credit_pool_service import CreditPoolService
@ -1254,9 +1280,11 @@ class TenantService:
return tenant
@staticmethod
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
def create_owner_tenant_if_not_exist(
account: Account, name: str | None = None, is_setup: bool | None = False, *, session: scoped_session | Session
):
"""Check if user have a workspace or not"""
available_ta = db.session.scalar(
available_ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id)
.order_by(TenantAccountJoin.id.asc())
@ -1275,10 +1303,10 @@ class TenantService:
raise WorkspacesLimitExceededError()
if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
tenant = TenantService.create_tenant(name=name, is_setup=is_setup, session=session)
else:
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup, session=session)
TenantService.create_tenant_member(tenant, account, session, role="owner")
if dify_config.RBAC_ENABLED:
owner_role_id = AccountService._resolve_legacy_role_id(str(tenant.id), account.id, TenantAccountRole.OWNER)
RBACService.MemberRoles.replace(
@ -1288,16 +1316,16 @@ class TenantService:
role_ids=[owner_role_id],
)
account.current_tenant = tenant
db.session.commit()
session.commit()
tenant_was_created.send(tenant)
@staticmethod
def create_tenant_member(
tenant: Tenant, account: Account, session: scoped_session, role: str = "normal"
tenant: Tenant, account: Account, session: scoped_session | Session, role: str = "normal"
) -> TenantAccountJoin:
"""Create tenant member"""
if role == TenantAccountRole.OWNER:
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER], session=session):
logger.error("Tenant %s has already an owner.", tenant.id)
raise Exception("Tenant already has an owner.")
@ -1318,10 +1346,10 @@ class TenantService:
return ta
@staticmethod
def get_join_tenants(account: Account) -> list[Tenant]:
def get_join_tenants(account: Account, *, session: scoped_session | Session) -> list[Tenant]:
"""Get account join tenants"""
return list(
db.session.scalars(
session.scalars(
select(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
@ -1340,7 +1368,7 @@ class TenantService:
membership + pick the default workspace.
``session`` is injected by the caller so this service stays free
of the Flask-scoped ``db.session`` import.
of a Flask-scoped session import.
No tenant-status filter: parity with the legacy controller query
(the openapi identity endpoint listed all joined tenants).
@ -1413,7 +1441,7 @@ class TenantService:
bearers (no account) collapse to the non-member path. Mirrors the
session-injection style of :meth:`account_belongs_to_tenant` rather
than :meth:`get_user_role`, which loads full ``Account``/``Tenant``
objects against the Flask-scoped ``db.session``.
objects against the Flask-scoped session.
"""
if not account_id:
return None
@ -1479,13 +1507,13 @@ class TenantService:
).first()
@staticmethod
def get_current_tenant_by_account(account: Account):
def get_current_tenant_by_account(account: Account, *, session: scoped_session | Session):
"""Get tenant by account and add the role"""
tenant = account.current_tenant
if not tenant:
raise TenantNotFoundError("Tenant not found.")
ta = db.session.scalar(
ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
@ -1497,14 +1525,14 @@ class TenantService:
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: str | None = None):
def switch_tenant(account: Account, tenant_id: str | None = None, *, session: scoped_session | Session):
"""Switch the current workspace for the account"""
# Ensure tenant_id is provided
if tenant_id is None:
raise ValueError("Tenant ID must be provided.")
tenant_account_join = db.session.scalar(
tenant_account_join = session.scalar(
select(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
.where(
@ -1518,7 +1546,7 @@ class TenantService:
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
db.session.execute(
session.execute(
update(TenantAccountJoin)
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
.values(current=False)
@ -1527,10 +1555,10 @@ class TenantService:
tenant_account_join.last_opened_at = naive_utc_now()
# Set the current tenant for the account
account.set_tenant_id(tenant_account_join.tenant_id)
db.session.commit()
session.commit()
@staticmethod
def get_tenant_members(tenant: Tenant) -> list[Account]:
def get_tenant_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]:
"""Get tenant members"""
stmt = (
select(Account, TenantAccountJoin.role)
@ -1542,14 +1570,14 @@ class TenantService:
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in db.session.execute(stmt):
for account, role in session.execute(stmt):
account.role = role
updated_accounts.append(account)
return updated_accounts
@staticmethod
def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
def get_dataset_operator_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]:
"""Get dataset admin members"""
stmt = (
select(Account, TenantAccountJoin.role)
@ -1562,20 +1590,20 @@ class TenantService:
# Initialize an empty list to store the updated accounts
updated_accounts = []
for account, role in db.session.execute(stmt):
for account, role in session.execute(stmt):
account.role = role
updated_accounts.append(account)
return updated_accounts
@staticmethod
def has_roles(tenant: Tenant, roles: list[TenantAccountRole]) -> bool:
def has_roles(tenant: Tenant, roles: list[TenantAccountRole], *, session: scoped_session | Session) -> bool:
"""Check if user has any of the given roles for a tenant"""
if not all(isinstance(role, TenantAccountRole) for role in roles):
raise ValueError("all roles must be TenantAccountRole")
return (
db.session.scalar(
session.scalar(
select(TenantAccountJoin)
.where(
TenantAccountJoin.tenant_id == tenant.id,
@ -1587,9 +1615,11 @@ class TenantService:
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
def get_user_role(
account: Account, tenant: Tenant, *, session: scoped_session | Session
) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant"""
join = db.session.scalar(
join = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
@ -1597,12 +1627,14 @@ class TenantService:
return TenantAccountRole(join.role) if join else None
@staticmethod
def get_tenant_count() -> int:
def get_tenant_count(*, session: scoped_session | Session) -> int:
"""Get tenant count"""
return cast(int, db.session.scalar(select(func.count(Tenant.id))))
return cast(int, session.scalar(select(func.count(Tenant.id))))
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
def check_member_permission(
tenant: Tenant, operator: Account, member: Account | None, action: str, *, session: scoped_session | Session
):
"""Check member permission"""
if action not in {"add", "remove", "update"}:
raise InvalidActionError("Invalid action.")
@ -1636,7 +1668,7 @@ class TenantService:
"update": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
}
ta_operator = db.session.scalar(
ta_operator = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
.limit(1)
@ -1646,7 +1678,7 @@ class TenantService:
raise NoPermissionError(f"No permission to {action} member.")
if action == "remove" and ta_operator.role == TenantAccountRole.ADMIN and member:
ta_member = db.session.scalar(
ta_member = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
.limit(1)
@ -1655,7 +1687,9 @@ class TenantService:
raise NoPermissionError(f"No permission to {action} member.")
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
def remove_member_from_tenant(
tenant: Tenant, account: Account, operator: Account, *, session: scoped_session | Session
):
"""Remove member from tenant.
Apps and datasets maintained by the removed member are reassigned to
@ -1667,9 +1701,9 @@ class TenantService:
if operator.id == account.id:
raise CannotOperateSelfError("Cannot operate self.")
TenantService.check_member_permission(tenant, operator, account, "remove")
TenantService.check_member_permission(tenant, operator, account, "remove", session=session)
ta = db.session.scalar(
ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
@ -1686,7 +1720,7 @@ class TenantService:
if dify_config.RBAC_ENABLED:
owner_id = AccountService.get_rbac_workspace_owner_account_id(str(tenant.id), str(operator.id))
else:
owner_id = db.session.scalar(
owner_id = session.scalar(
select(TenantAccountJoin.account_id)
.where(
TenantAccountJoin.tenant_id == tenant.id,
@ -1697,7 +1731,7 @@ class TenantService:
if owner_id is None:
raise ValueError(f"Workspace owner not found for tenant {tenant.id}.")
db.session.execute(
session.execute(
update(App)
.where(
App.tenant_id == tenant.id,
@ -1705,7 +1739,7 @@ class TenantService:
)
.values(maintainer=owner_id)
)
db.session.execute(
session.execute(
update(Dataset)
.where(
Dataset.tenant_id == tenant.id,
@ -1713,23 +1747,23 @@ class TenantService:
)
.values(maintainer=owner_id)
)
db.session.delete(ta)
session.delete(ta)
# Clean up orphaned pending accounts (invited but never activated)
should_delete_account = False
if account.status == AccountStatus.PENDING:
# autoflush flushes ta deletion before this query, so 0 means no remaining joins
remaining_joins = (
db.session.scalar(
session.scalar(
select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
)
or 0
)
if remaining_joins == 0:
db.session.delete(account)
session.delete(account)
should_delete_account = True
db.session.commit()
session.commit()
if should_delete_account:
logger.info(
@ -1755,12 +1789,14 @@ class TenantService:
)
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
def update_member_role(
tenant: Tenant, member: Account, new_role: str, operator: Account, *, session: scoped_session | Session
):
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
TenantService.check_member_permission(tenant, operator, member, "update", session=session)
new_tenant_role = TenantAccountRole(new_role)
target_member_join = db.session.scalar(
target_member_join = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
.limit(1)
@ -1769,7 +1805,7 @@ class TenantService:
if not target_member_join:
raise MemberNotInTenantError("Member not in tenant.")
operator_role = TenantService.get_user_role(operator, tenant)
operator_role = TenantService.get_user_role(operator, tenant, session=session)
target_role = TenantAccountRole(target_member_join.role)
if operator_role == TenantAccountRole.ADMIN and (TenantAccountRole.OWNER in {target_role, new_tenant_role}):
raise NoPermissionError("No permission to update member.")
@ -1779,7 +1815,7 @@ class TenantService:
if new_role == "owner":
# Find the current owner and change their role to 'admin'
current_owner_join = db.session.scalar(
current_owner_join = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
.limit(1)
@ -1815,7 +1851,7 @@ class TenantService:
)
else:
target_member_join.role = new_tenant_role
db.session.commit()
session.commit()
@staticmethod
def get_custom_config(tenant_id: str):
@ -1824,13 +1860,13 @@ class TenantService:
return tenant.custom_config_dict
@staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool:
return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER
def is_owner(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool:
return TenantService.get_user_role(account, tenant, session=session) == TenantAccountRole.OWNER
@staticmethod
def is_member(account: Account, tenant: Tenant) -> bool:
def is_member(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool:
"""Check if the account is a member of the tenant"""
return TenantService.get_user_role(account, tenant) is not None
return TenantService.get_user_role(account, tenant, session=session) is not None
class RegisterService:
@ -1839,7 +1875,16 @@ class RegisterService:
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None):
def setup(
cls,
email: str,
name: str,
password: str,
ip_address: str,
language: str | None,
*,
session: scoped_session | Session,
):
"""
Setup dify
@ -1856,22 +1901,23 @@ class RegisterService:
interface_language=get_valid_language(language),
password=password,
is_setup=True,
session=session,
)
account.last_login_ip = ip_address
account.initialized_at = naive_utc_now()
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True, session=session)
dify_setup = DifySetup(version=dify_config.project.version)
db.session.add(dify_setup)
db.session.commit()
session.add(dify_setup)
session.commit()
except Exception as e:
db.session.execute(delete(DifySetup))
db.session.execute(delete(TenantAccountJoin))
db.session.execute(delete(Account))
db.session.execute(delete(Tenant))
db.session.commit()
session.execute(delete(DifySetup))
session.execute(delete(TenantAccountJoin))
session.execute(delete(Account))
session.execute(delete(Tenant))
session.commit()
logger.exception("Setup account failed, email: %s, name: %s", email, name)
raise ValueError(f"Setup failed: {e}")
@ -1889,9 +1935,11 @@ class RegisterService:
is_setup: bool | None = False,
create_workspace_required: bool | None = True,
timezone: str | None = None,
*,
session: scoped_session | Session,
) -> Account:
"""Register account"""
db.session.begin_nested()
session.begin_nested()
try:
interface_language = get_valid_language(language)
account = AccountService.create_account(
@ -1901,12 +1949,13 @@ class RegisterService:
password=password,
is_setup=is_setup,
timezone=timezone,
session=session,
)
account.status = status or AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
AccountService.link_account_integrate(provider, open_id, account, session=session)
if (
FeatureService.get_system_features().is_allow_create_workspace
@ -1914,27 +1963,27 @@ class RegisterService:
and FeatureService.get_system_features().license.workspaces.is_available()
):
try:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=session)
TenantService.create_tenant_member(tenant, account, session, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
except Exception:
_try_join_enterprise_default_workspace(str(account.id))
raise
db.session.commit()
session.commit()
_try_join_enterprise_default_workspace(str(account.id))
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
session.rollback()
logger.exception("Register failed")
raise AccountRegisterError("Workspace is not allowed to create.")
except AccountRegisterError as are:
db.session.rollback()
session.rollback()
logger.exception("Register failed")
raise are
except Exception as e:
db.session.rollback()
session.rollback()
logger.exception("Register failed")
raise AccountRegisterError(f"Registration failed: {e}") from e
@ -1942,7 +1991,14 @@ class RegisterService:
@classmethod
def invite_new_member(
cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None
cls,
tenant: Tenant,
email: str,
language: str | None,
role: str = "normal",
inviter: Account | None = None,
*,
session: scoped_session | Session,
) -> str:
if not inviter:
raise ValueError("Inviter is required")
@ -1960,7 +2016,7 @@ class RegisterService:
requires_setup = False
if not account:
TenantService.check_member_permission(tenant, inviter, None, "add")
TenantService.check_member_permission(tenant, inviter, None, "add", session=session)
name = normalized_email.split("@")[0]
account = cls.register(
@ -1969,13 +2025,14 @@ class RegisterService:
language=language,
status=AccountStatus.PENDING,
is_setup=True,
session=session,
)
TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role)
TenantService.switch_tenant(account, tenant.id)
TenantService.create_tenant_member(tenant, account, session, tenant_join_role)
TenantService.switch_tenant(account, tenant.id, session=session)
requires_setup = True
else:
TenantService.check_member_permission(tenant, inviter, account, "add")
ta = db.session.scalar(
TenantService.check_member_permission(tenant, inviter, account, "add", session=session)
ta = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.limit(1)
@ -1983,7 +2040,7 @@ class RegisterService:
requires_setup = account.status == AccountStatus.PENDING
if not ta and (account.status == AccountStatus.PENDING or dify_config.RBAC_ENABLED):
TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role)
TenantService.create_tenant_member(tenant, account, session, tenant_join_role)
# Support resend invitation email when the account is pending status
if account.status != AccountStatus.PENDING:
@ -2052,20 +2109,20 @@ class RegisterService:
@classmethod
def get_invitation_if_token_valid(
cls, workspace_id: str | None, email: str | None, token: str
cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session
) -> InvitationDetailDict | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
tenant = db.session.scalar(
tenant = session.scalar(
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
)
if not tenant:
return None
account = db.session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
account = session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
if not account:
return None
@ -2105,13 +2162,13 @@ class RegisterService:
@classmethod
def get_invitation_with_case_fallback(
cls, workspace_id: str | None, email: str | None, token: str
cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session
) -> InvitationDetailDict | None:
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token, session=session)
if invitation or not email or email == email.lower():
return invitation
normalized_email = email.lower()
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token, session=session)
def _generate_refresh_token(length: int = 64):

View File

@ -4,6 +4,7 @@ from typing import TypedDict
import pandas as pd
from sqlalchemy import delete, or_, select, update
from sqlalchemy.orm import scoped_session
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@ -300,17 +301,19 @@ class AppAnnotationService:
return annotation
@classmethod
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
def update_app_annotation_directly(
cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str, session: scoped_session
):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
app = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.get(MessageAnnotation, annotation_id)
annotation = session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
@ -326,9 +329,9 @@ class AppAnnotationService:
annotation.content = answer
annotation.question = question
db.session.commit()
session.commit()
# if annotation reply is enabled , add annotation to index
app_annotation_setting = db.session.scalar(
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
@ -344,33 +347,33 @@ class AppAnnotationService:
return annotation
@classmethod
def delete_app_annotation(cls, app_id: str, annotation_id: str):
def delete_app_annotation(cls, app_id: str, annotation_id: str, session: scoped_session):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = db.session.scalar(
app = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.get(MessageAnnotation, annotation_id)
annotation = session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
db.session.delete(annotation)
session.delete(annotation)
annotation_hit_histories = db.session.scalars(
annotation_hit_histories = session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
).all()
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)
session.delete(annotation_hit_history)
db.session.commit()
session.commit()
# if annotation reply is enabled , delete annotation index
app_annotation_setting = db.session.scalar(
app_annotation_setting = session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)

View File

@ -181,6 +181,7 @@ class SystemFeatureModel(FeatureResponseModel):
enable_creators_platform: bool = False
enable_trial_app: bool = False
enable_explore_banner: bool = False
enable_learn_app: bool = True
rbac_enabled: bool = False
@ -282,6 +283,7 @@ class FeatureService:
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
system_features.enable_learn_app = dify_config.ENABLE_LEARN_APP
@classmethod
def _fulfill_trial_models_from_env(cls) -> list[str]:

View File

@ -91,4 +91,4 @@ class OAuthServerService:
user_id_str = user_account_id.decode("utf-8")
return AccountService.load_user(user_id_str)
return AccountService.load_user(user_id_str, db.session)

View File

@ -5,6 +5,7 @@ from typing import Any, override
from flask import current_app
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
@ -25,6 +26,11 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
result = self.fetch_recommended_apps_from_builtin(language)
return result
@override
def get_learn_dify_apps(self, language: str):
result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language)
return result
@override
def get_recommend_app_detail(self, app_id: str):
result = self.fetch_recommended_app_detail_from_builtin(app_id)

View File

@ -49,6 +49,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
result = self.fetch_recommended_apps_from_db(language)
return result
@override
def get_learn_dify_apps(self, language: str) -> RecommendedAppsResultDict:
result = self.fetch_learn_dify_apps_from_db(language)
return result
@override
def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None:
result = self.fetch_recommended_app_detail_from_db(app_id)

View File

@ -6,6 +6,8 @@ class RecommendAppRetrievalBase(Protocol):
def get_recommended_apps_and_categories(self, language: str) -> Any: ...
def get_learn_dify_apps(self, language: str) -> Any: ...
def get_recommend_app_detail(self, app_id: str) -> Any: ...
def get_type(self) -> str: ...

View File

@ -2,15 +2,28 @@ import logging
from typing import Any, override
import httpx
from flask import has_request_context, request
from configs import dify_config
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
logger = logging.getLogger(__name__)
def _current_origin_headers() -> dict[str, str]:
origin = request.headers.get("Origin") if has_request_context() else None
if origin:
return {"Origin": origin}
console_web_url = getattr(dify_config, "CONSOLE_WEB_URL", "")
if not isinstance(console_web_url, str) or not console_web_url:
return {}
return {"Origin": console_web_url}
class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
"""
Retrieval recommended app from dify official.
@ -37,6 +50,15 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language)
return result
@override
def get_learn_dify_apps(self, language: str):
try:
result = self.fetch_learn_dify_apps_from_dify_official(language)
except Exception as e:
logger.warning("fetch learn dify apps from dify official failed: %s, switch to database.", e)
result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language)
return result
@override
def get_type(self) -> str:
return RecommendAppType.REMOTE
@ -50,7 +72,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
"""
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/apps/{app_id}"
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0))
if response.status_code != 200:
return None
data: dict[str, Any] = response.json()
@ -65,9 +87,25 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
"""
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/apps?language={language}"
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0))
if response.status_code != 200:
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
result: dict[str, Any] = response.json()
return result
@classmethod
def fetch_learn_dify_apps_from_dify_official(cls, language: str):
"""
Fetch Learn Dify apps from dify official.
:param language: language
:return:
"""
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
url = f"{domain}/apps/learn-dify?language={language}"
response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0))
if response.status_code != 200:
raise ValueError(f"fetch learn dify apps failed, status code: {response.status_code}")
result: dict[str, Any] = response.json()
return result

View File

@ -6,7 +6,6 @@ from sqlalchemy.orm import scoped_session
from configs import dify_config
from models.model import AccountTrialAppRecord, TrialApp
from services.feature_service import FeatureService
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@ -38,11 +37,13 @@ class RecommendedAppService:
@classmethod
def get_learn_dify_apps(cls, session: scoped_session, language: str) -> dict[str, Any]:
"""
Get database-backed recommended apps marked as Learn Dify.
Get recommended apps marked for the Learn Dify section.
:param language: language
:return:
"""
result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language)
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result = retrieval_instance.get_learn_dify_apps(language)
if FeatureService.get_system_features().enable_trial_app:
for app in result["recommended_apps"]:

View File

@ -36,7 +36,9 @@ class WorkspaceService:
feature = FeatureService.get_features(tenant.id, exclude_vector_space=True)
can_replace_logo = feature.can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
if can_replace_logo and TenantService.has_roles(
tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], session=db.session
):
base_url = dify_config.FILES_URL
replace_webapp_logo = (
f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"

View File

@ -84,6 +84,7 @@ def setup_account(request) -> Generator[Account, None, None]:
password=secrets.token_hex(16),
ip_address="localhost",
language="en-US",
session=db.session,
)
with _CACHED_APP.test_request_context():

View File

@ -548,6 +548,7 @@ class TestAccountGeneration:
provider="github",
language="en-US",
timezone=None,
session=ANY,
)
else:
mock_register_service.register.assert_not_called()
@ -581,6 +582,7 @@ class TestAccountGeneration:
provider="github",
language="en-US",
timezone=None,
session=ANY,
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@ -612,6 +614,7 @@ class TestAccountGeneration:
provider="github",
language="zh-Hans",
timezone="Asia/Shanghai",
session=ANY,
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@ -643,6 +646,7 @@ class TestAccountGeneration:
provider="github",
language="zh-Hans",
timezone=None,
session=ANY,
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@ -673,7 +677,7 @@ class TestAccountGeneration:
assert result == mock_account
assert oauth_new_user is False
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace", session=ANY)
mock_tenant_service.create_tenant_member.assert_called_once_with(
mock_new_tenant, mock_account, ANY, role="owner"
)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import uuid
from collections.abc import Callable, Iterator
from collections.abc import Callable, Generator
from contextlib import contextmanager
from typing import Literal
from unittest.mock import patch
@ -12,7 +12,6 @@ from flask import Flask
from sqlalchemy.orm import Session
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
from models import Account, Tenant
from services.account_service import AccountService, TenantService
@ -46,20 +45,25 @@ def make_account(db_session_with_containers: Session) -> Callable[..., Account]:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
if with_owner_tenant:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(
account, name=fake.company(), session=db_session_with_containers
)
return account
return _make
def add_tenant_for_account(account: Account, *, role: str = "normal", name: str = "Second WS") -> Tenant:
def add_tenant_for_account(
account: Account, *, session: Session, role: str = "normal", name: str = "Second WS"
) -> Tenant:
"""Create an additional tenant and join ``account`` to it (real service calls)."""
with patch("services.account_service.FeatureService") as mock_feature_service:
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
tenant = TenantService.create_tenant(name=name)
TenantService.create_tenant_member(tenant, account, db.session, role=role)
tenant = TenantService.create_tenant(name=name, session=session)
TenantService.create_tenant_member(tenant, account, session, role=role)
return tenant
@ -93,7 +97,7 @@ def account_auth_context(
*,
token_id: uuid.UUID,
client_id: str = "integration-cli",
) -> Iterator[AuthContext]:
) -> Generator[AuthContext]:
"""Publish an account ``AuthContext`` for handlers that read ``get_auth_ctx()``.
The auth pipeline normally sets this ContextVar; the integration suite

View File

@ -4,6 +4,7 @@ from collections.abc import Callable
from inspect import unwrap
from flask import Flask
from sqlalchemy.orm import Session
from controllers.openapi.account import AccountApi
from models import Account
@ -34,11 +35,13 @@ class TestAccountInfo:
# the only workspace the account belongs to.
assert result.default_workspace_id == owner_tenant.id
def test_lists_all_joined_workspaces(self, app: Flask, make_account: Callable[..., Account]) -> None:
def test_lists_all_joined_workspaces(
self, app: Flask, db_session_with_containers: Session, make_account: Callable[..., Account]
) -> None:
account = make_account()
owner_tenant = account.current_tenant
assert owner_tenant is not None
second = add_tenant_for_account(account, role="normal", name="Second WS")
second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS")
api = AccountApi()
with app.test_request_context("/openapi/v1/account"):

View File

@ -81,8 +81,9 @@ def _app_and_account(db_session: Session, *, mode: str = "chat") -> tuple[App, A
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session)
tenant = account.current_tenant
assert tenant is not None
app_args = CreateAppParams(

View File

@ -41,7 +41,7 @@ class TestWorkspacesList:
account = make_account()
owner_tenant = account.current_tenant
assert owner_tenant is not None
second = add_tenant_for_account(account, role="normal", name="Second WS")
second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS")
api = WorkspacesApi()
with app.test_request_context("/openapi/v1/workspaces"):
@ -90,7 +90,9 @@ class TestWorkspaceSwitch:
account = make_account()
owner_tenant = account.current_tenant
assert owner_tenant is not None
target = add_tenant_for_account(account, role="normal", name="Switch Target")
target = add_tenant_for_account(
account, session=db_session_with_containers, role="normal", name="Switch Target"
)
api = WorkspaceSwitchApi()
with app.test_request_context(f"/openapi/v1/workspaces/{target.id}/switch", method="POST"):

View File

@ -27,8 +27,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -88,8 +89,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -141,8 +143,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -194,8 +197,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -257,8 +261,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -291,8 +296,11 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
account1, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
tenant1 = account1.current_tenant
account2 = AccountService.create_account(
@ -300,8 +308,11 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
account2, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
tenant2 = account2.current_tenant
# Create dataset for tenant1
@ -367,8 +378,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Don't create any datasets
@ -391,8 +403,9 @@ class TestGetAvailableDatasetsIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create multiple datasets
@ -452,8 +465,9 @@ class TestKnowledgeRetrievalIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(
@ -520,8 +534,9 @@ class TestKnowledgeRetrievalIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset but no documents
@ -568,8 +583,9 @@ class TestKnowledgeRetrievalIntegration:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
dataset = Dataset(

View File

@ -114,8 +114,9 @@ class TestAgentService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data

View File

@ -81,8 +81,9 @@ class TestAnnotationService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Setup app creation arguments
@ -280,7 +281,9 @@ class TestAnnotationService:
"question": fake.sentence(),
"answer": fake.text(max_nb_chars=200),
}
updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id)
updated_annotation = AppAnnotationService.update_app_annotation_directly(
updated_args, app.id, annotation.id, db_session_with_containers
)
# Verify annotation was updated correctly
assert updated_annotation.id == annotation.id
@ -567,7 +570,7 @@ class TestAnnotationService:
annotation_id = annotation.id
# Delete the annotation
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers)
# Verify annotation was deleted
@ -595,7 +598,7 @@ class TestAnnotationService:
# Try to delete annotation with non-existent app
with pytest.raises(NotFound, match="App not found"):
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id)
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id, db_session_with_containers)
def test_delete_app_annotation_annotation_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -609,7 +612,7 @@ class TestAnnotationService:
# Try to delete non-existent annotation
with pytest.raises(NotFound, match="Annotation not found"):
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id)
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id, db_session_with_containers)
def test_enable_app_annotation_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -1225,7 +1228,9 @@ class TestAnnotationService:
"question": fake.sentence(),
"answer": fake.text(max_nb_chars=200),
}
updated_annotation = AppAnnotationService.update_app_annotation_directly(updated_args, app.id, annotation.id)
updated_annotation = AppAnnotationService.update_app_annotation_directly(
updated_args, app.id, annotation.id, db_session_with_containers
)
# Verify annotation was updated correctly
assert updated_annotation.id == annotation.id
@ -1295,7 +1300,7 @@ class TestAnnotationService:
mock_external_service_dependencies["delete_task"].delay.reset_mock()
# Delete the annotation
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
AppAnnotationService.delete_app_annotation(app.id, annotation_id, db_session_with_containers)
# Verify annotation was deleted
deleted_annotation = (

View File

@ -57,8 +57,9 @@ class TestAPIBasedExtensionService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
return account, tenant

View File

@ -145,8 +145,11 @@ class TestAppDslService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
account, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
app_args = CreateAppParams(
name=fake.company(),

View File

@ -166,8 +166,9 @@ class TestAppGenerateService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppService, CreateAppParams

View File

@ -62,8 +62,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Setup app creation arguments
@ -119,8 +120,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -162,8 +164,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -210,8 +213,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -261,8 +265,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppListParams, AppService, CreateAppParams
@ -344,8 +349,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from models import AppStar
@ -404,8 +410,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppService, CreateAppParams, StarredAppListParams
@ -500,8 +507,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -566,14 +574,18 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
first_account, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(first_account, name=fake.company())
tenant = first_account.current_tenant
second_account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
from services.app_service import AppListParams, AppService, CreateAppParams
@ -623,8 +635,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -685,8 +698,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -755,8 +769,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppService, CreateAppParams
@ -807,8 +822,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
from services.app_service import AppService, CreateAppParams
@ -857,8 +873,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -910,8 +927,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -971,8 +989,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1030,8 +1049,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1089,8 +1109,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1139,8 +1160,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1190,8 +1212,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1249,8 +1272,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1287,8 +1311,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1326,8 +1351,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app first
@ -1375,8 +1401,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
# Import here to avoid circular dependency
from services.app_service import CreateAppParams
@ -1411,8 +1438,9 @@ class TestAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency

View File

@ -98,8 +98,9 @@ class TestMessageService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Setup app creation arguments
@ -648,8 +649,11 @@ class TestMessageService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
other_account, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company())
# Test getting message with different user
with pytest.raises(MessageNotExistsError):

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import uuid
from typing import cast
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
from uuid import uuid4
import pytest
@ -172,4 +172,4 @@ class TestOAuthServerServiceTokenOperations:
result = OAuthServerService.validate_oauth_access_token("client-1", "access-token")
assert result is expected_user
mock_load.assert_called_once_with("user-88")
mock_load.assert_called_once_with("user-88", ANY)

View File

@ -51,8 +51,9 @@ class TestOpsService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
app_service = AppService()
app = app_service.create_app(

View File

@ -267,36 +267,45 @@ class TestRecommendedAppServiceGetDetail:
class TestRecommendedAppServiceGetLearnDifyApps:
def test_returns_database_learn_dify_apps_without_remote_factory(self, monkeypatch: pytest.MonkeyPatch) -> None:
@patch("services.recommended_app_service.FeatureService", autospec=True)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config")
def test_uses_configured_retrieval_source(
self, mock_config: MagicMock, mock_factory_class: MagicMock, mock_feature_service: MagicMock
) -> None:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "remote"
mock_feature_service.get_system_features.return_value = SimpleNamespace(enable_trial_app=False)
expected_app = RecommendedAppPayload(app_id="app-1", category="Workflow")
mock_database_retrieval = MagicMock()
mock_database_retrieval.fetch_learn_dify_apps_from_db.return_value = {
mock_instance = MagicMock()
mock_instance.get_learn_dify_apps.return_value = {
"recommended_apps": [expected_app],
"categories": ["Workflow"],
}
monkeypatch.setattr(service_module, "DatabaseRecommendAppRetrieval", mock_database_retrieval)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
)
factory_mock = MagicMock()
monkeypatch.setattr(service_module.RecommendAppRetrievalFactory, "get_recommend_app_factory", factory_mock)
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = RecommendedAppService.get_learn_dify_apps(db.session, "en-US")
assert result == {"recommended_apps": [expected_app]}
mock_database_retrieval.fetch_learn_dify_apps_from_db.assert_called_once_with("en-US")
factory_mock.assert_not_called()
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
mock_instance.get_learn_dify_apps.assert_called_once_with("en-US")
def test_sets_can_trial_when_trial_feature_enabled(self, monkeypatch: pytest.MonkeyPatch) -> None:
@patch("services.recommended_app_service.dify_config")
def test_sets_can_trial_when_trial_feature_enabled(
self, mock_config: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
mock_config.HOSTED_FETCH_APP_TEMPLATES_MODE = "db"
app = RecommendedAppPayload(app_id="app-1", category="Workflow")
mock_database_retrieval = MagicMock()
mock_database_retrieval.fetch_learn_dify_apps_from_db.return_value = {
mock_retrieval_instance = MagicMock()
mock_retrieval_instance.get_learn_dify_apps.return_value = {
"recommended_apps": [app],
"categories": ["Workflow"],
}
monkeypatch.setattr(service_module, "DatabaseRecommendAppRetrieval", mock_database_retrieval)
mock_retrieval_factory = MagicMock(return_value=mock_retrieval_instance)
monkeypatch.setattr(
service_module.RecommendAppRetrievalFactory,
"get_recommend_app_factory",
MagicMock(return_value=mock_retrieval_factory),
)
monkeypatch.setattr(
service_module.FeatureService,
"get_system_features",

View File

@ -68,8 +68,9 @@ class TestSavedMessageService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data

View File

@ -55,6 +55,7 @@ class TestTriggerProviderService:
def _create_test_account_and_tenant(
self,
mock_external_service_dependencies: MockExternalServiceDependencies,
db_session_with_containers: Session,
) -> tuple[Account, Tenant]:
"""
Helper method to create a test account and tenant for testing.
@ -83,8 +84,9 @@ class TestTriggerProviderService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
assert tenant is not None
@ -164,7 +166,9 @@ class TestTriggerProviderService:
- Database state is correctly updated
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -262,7 +266,9 @@ class TestTriggerProviderService:
- Merged credentials contain only new values
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -320,7 +326,9 @@ class TestTriggerProviderService:
- Original credentials are preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -376,7 +384,9 @@ class TestTriggerProviderService:
- UNKNOWN_VALUE is used when HIDDEN_VALUE key doesn't exist in original credentials
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -434,7 +444,9 @@ class TestTriggerProviderService:
- Original subscription state is preserved
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY
@ -474,9 +486,8 @@ class TestTriggerProviderService:
assert subscription.name == original_name
assert subscription.parameters == original_parameters
@pytest.mark.usefixtures("db_session_with_containers")
def test_rebuild_trigger_subscription_subscription_not_found(
self, mock_external_service_dependencies: MockExternalServiceDependencies
self, mock_external_service_dependencies: MockExternalServiceDependencies, db_session_with_containers: Session
) -> None:
"""
Test error when subscription is not found.
@ -485,7 +496,9 @@ class TestTriggerProviderService:
- Proper error is raised when subscription doesn't exist
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
fake_subscription_id = fake.uuid4()
@ -509,7 +522,9 @@ class TestTriggerProviderService:
- Error is raised when new name conflicts with existing subscription
"""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(mock_external_service_dependencies)
account, tenant = self._create_test_account_and_tenant(
mock_external_service_dependencies, db_session_with_containers
)
provider_id = TriggerProviderID("test_org/test_plugin/test_provider")
credential_type = CredentialType.API_KEY

View File

@ -72,8 +72,9 @@ class TestWebConversationService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data

View File

@ -63,8 +63,9 @@ class TestWebhookService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
assert tenant is not None

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import json
import logging
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from uuid import uuid4
@ -355,7 +356,10 @@ class TestWebhookServiceTriggerExecutionWithContainers:
mock_mark_rate_limited.assert_called_once_with(tenant.id)
def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors(
self, db_session_with_containers: Session, flask_app_with_containers: Flask
self,
db_session_with_containers: Session,
flask_app_with_containers: Flask,
caplog: pytest.LogCaptureFixture,
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -367,13 +371,11 @@ class TestWebhookServiceTriggerExecutionWithContainers:
webhook_trigger = factory.create_webhook_trigger(
db_session_with_containers, app=app, account=account, node_id="node-1"
)
caplog.set_level(logging.ERROR, logger="services.trigger.webhook_service")
with (
patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
side_effect=RuntimeError("boom"),
),
patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
with patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type",
side_effect=RuntimeError("boom"),
):
with pytest.raises(RuntimeError, match="boom"):
WebhookService.trigger_workflow_execution(
@ -382,7 +384,7 @@ class TestWebhookServiceTriggerExecutionWithContainers:
workflow,
)
mock_logger_exception.assert_called_once()
assert caplog.messages.count(f"Failed to trigger workflow for webhook {webhook_trigger.webhook_id}") == 1
class TestWebhookServiceRelationshipSyncWithContainers:
@ -482,7 +484,10 @@ class TestWebhookServiceRelationshipSyncWithContainers:
assert cached_payload["webhook_id"] == "cache-webhook-id-00001"
def test_sync_webhook_relationships_logs_when_lock_release_fails(
self, db_session_with_containers: Session, flask_app_with_containers: Flask
self,
db_session_with_containers: Session,
flask_app_with_containers: Flask,
caplog: pytest.LogCaptureFixture,
):
del flask_app_with_containers
factory = WebhookServiceRelationshipFactory
@ -494,14 +499,12 @@ class TestWebhookServiceRelationshipSyncWithContainers:
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = RuntimeError("release failed")
caplog.set_level(logging.ERROR, logger="services.trigger.webhook_service")
with (
patch("services.trigger.webhook_service.redis_client.lock", return_value=lock),
patch("services.trigger.webhook_service.logger.exception") as mock_logger_exception,
):
with patch("services.trigger.webhook_service.redis_client.lock", return_value=lock):
WebhookService.sync_webhook_relationships(app, workflow)
mock_logger_exception.assert_called_once()
assert caplog.messages.count(f"Failed to release lock for webhook sync, app {app.id}") == 1
def _read_cache(cache_key: str) -> dict[str, str] | None:

View File

@ -78,8 +78,9 @@ class TestWorkflowAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Import here to avoid circular dependency
@ -126,8 +127,9 @@ class TestWorkflowAppService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
return tenant, account

View File

@ -74,8 +74,9 @@ class TestWorkflowRunService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data
@ -530,8 +531,9 @@ class TestWorkflowRunService:
name="Test User",
password="password123",
interface_language="en-US",
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
tenant = account.current_tenant
# Create app
@ -581,8 +583,9 @@ class TestWorkflowRunService:
name="Test User",
password="password123",
interface_language="en-US",
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
tenant = account.current_tenant
# Create app
@ -632,8 +635,9 @@ class TestWorkflowRunService:
name="Test User",
password="password123",
interface_language="en-US",
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant")
TenantService.create_owner_tenant_if_not_exist(account, name="test_tenant", session=db_session_with_containers)
tenant = account.current_tenant
# Create app

View File

@ -89,8 +89,9 @@ class TestWorkflowToolManageService:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create app with realistic data

View File

@ -90,8 +90,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -211,8 +212,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -255,8 +257,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Test different index types
@ -342,8 +345,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -424,8 +428,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -525,8 +530,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -625,8 +631,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -717,8 +724,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -820,8 +828,11 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(
account, name=fake.company(), session=db_session_with_containers
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
accounts.append(account)
tenants.append(tenant)
@ -926,8 +937,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset
@ -1031,8 +1043,9 @@ class TestCleanNotionDocumentTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
# Create dataset with built-in fields enabled

View File

@ -67,8 +67,9 @@ class TestDealDatasetVectorIndexTask:
name=fake.name(),
interface_language="en-US",
password=generate_valid_password(fake),
session=db_session_with_containers,
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
tenant = account.current_tenant
assert tenant is not None
return account, tenant

View File

@ -6,9 +6,11 @@ using TestContainers to ensure realistic database interactions and proper isolat
The task is responsible for removing document segments from the search index when they are disabled.
"""
import logging
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -533,7 +535,9 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value
mock_factory.assert_called_with(doc_form)
def test_disable_segments_performance_timing(self, db_session_with_containers: Session):
def test_disable_segments_performance_timing(
self, db_session_with_containers: Session, caplog: pytest.LogCaptureFixture
):
"""
Test that the task properly measures and logs performance timing.
@ -562,21 +566,18 @@ class TestDisableSegmentsFromIndexTask:
# Mock time.perf_counter to control timing
with patch("tasks.disable_segments_from_index_task.time.perf_counter") as mock_perf_counter:
mock_perf_counter.side_effect = [1000.0, 1000.5] # 0.5 seconds execution time
caplog.set_level(logging.INFO, logger="tasks.disable_segments_from_index_task")
# Mock logger to capture log messages
with patch("tasks.disable_segments_from_index_task.logger") as mock_logger:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Assert
assert result is None # Task should complete without returning a value
# Verify performance logging
mock_logger.info.assert_called()
log_calls = [call[0][0] for call in mock_logger.info.call_args_list]
performance_log = next((call for call in log_calls if "latency" in call), None)
assert performance_log is not None
assert "0.5" in performance_log # Should log the execution time
# Verify performance logging
performance_log = next((message for message in caplog.messages if "latency" in message), None)
assert performance_log is not None
assert "0.5" in performance_log # Should log the execution time
def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session):
"""

View File

@ -10,6 +10,7 @@ All tests use the testcontainers infrastructure to ensure proper database isolat
and realistic testing scenarios with actual PostgreSQL and Redis instances.
"""
import logging
from unittest.mock import MagicMock, patch
import pytest
@ -543,7 +544,10 @@ class TestSendEmailCodeLoginMailTask:
redis_client.delete(cache_key)
def test_send_email_code_login_mail_task_error_handling_comprehensive(
self, db_session_with_containers: Session, mock_external_service_dependencies
self,
db_session_with_containers: Session,
mock_external_service_dependencies,
caplog: pytest.LogCaptureFixture,
):
"""
Test comprehensive error handling for email code login mail task.
@ -559,6 +563,7 @@ class TestSendEmailCodeLoginMailTask:
test_email = fake.email()
test_code = "123456"
test_language = "en-US"
caplog.set_level(logging.ERROR, logger="tasks.mail_email_code_login")
# Test different exception types
exception_types = [
@ -571,31 +576,26 @@ class TestSendEmailCodeLoginMailTask:
for error_name, exception in exception_types:
# Reset mocks for each test case
caplog.clear()
mock_email_service_instance = mock_external_service_dependencies["email_service_instance"]
mock_email_service_instance.reset_mock()
mock_email_service_instance.send_email.side_effect = exception
# Mock logging to capture error messages
with patch("tasks.mail_email_code_login.logger", autospec=True) as mock_logger:
# Act: Execute the task - it should handle the exception gracefully
send_email_code_login_mail_task(
language=test_language,
to=test_email,
code=test_code,
)
# Act: Execute the task - it should handle the exception gracefully
send_email_code_login_mail_task(
language=test_language,
to=test_email,
code=test_code,
)
# Assert: Verify error handling
# Verify email service was called (and failed)
mock_email_service_instance.send_email.assert_called_once()
# Assert: Verify error handling
# Verify email service was called (and failed)
mock_email_service_instance.send_email.assert_called_once()
# Verify error was logged
error_calls = [
call
for call in mock_logger.exception.call_args_list
if f"Send email code login mail to {test_email} failed" in str(call)
]
# Check if any exception call was made (the exact message format may vary)
assert mock_logger.exception.call_count >= 1, f"Error should be logged for {error_name}"
# Verify error was logged
assert f"Send email code login mail to {test_email} failed" in caplog.messages, (
f"Error should be logged for {error_name}"
)
# Reset side effect for next iteration
mock_email_service_instance.send_email.side_effect = None

View File

@ -11,6 +11,7 @@ and realistic testing scenarios with actual PostgreSQL and Redis instances.
"""
import json
import logging
import uuid
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
@ -295,7 +296,10 @@ class TestMailInviteMemberTask:
mock_email_service.send_email.assert_not_called()
def test_send_invite_member_mail_email_service_exception(
self, db_session_with_containers: Session, mock_external_service_dependencies
self,
db_session_with_containers: Session,
mock_external_service_dependencies,
caplog: pytest.LogCaptureFixture,
):
"""
Test error handling when email service raises an exception.
@ -308,21 +312,19 @@ class TestMailInviteMemberTask:
# Arrange: Setup email service to raise exception
mock_email_service = mock_external_service_dependencies["email_service"]
mock_email_service.send_email.side_effect = Exception("Email service failed")
caplog.set_level(logging.ERROR, logger="tasks.mail_invite_member_task")
# Act & Assert: Execute task and verify exception is handled
with patch("tasks.mail_invite_member_task.logger", autospec=True) as mock_logger:
send_invite_member_mail_task(
language="en-US",
to="test@example.com",
token="test-token",
inviter_name="Test User",
workspace_name="Test Workspace",
)
send_invite_member_mail_task(
language="en-US",
to="test@example.com",
token="test-token",
inviter_name="Test User",
workspace_name="Test Workspace",
)
# Verify error was logged
mock_logger.exception.assert_called_once()
error_call = mock_logger.exception.call_args[0][0]
assert "Send invite member mail to %s failed" in error_call
# Verify error was logged
assert caplog.messages.count("Send invite member mail to test@example.com failed") == 1
def test_send_invite_member_mail_template_context_validation(
self, db_session_with_containers: Session, mock_external_service_dependencies

View File

@ -5,6 +5,7 @@ This module provides integration tests for email registration tasks
using TestContainers to ensure real database and service interactions.
"""
import logging
from unittest.mock import patch
import pytest
@ -68,7 +69,10 @@ class TestMailRegisterTask:
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_email_register_mail_task_exception_handling(
self, db_session_with_containers: Session, mock_mail_dependencies
self,
db_session_with_containers: Session,
mock_mail_dependencies,
caplog: pytest.LogCaptureFixture,
):
"""Test email registration task exception handling."""
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
@ -76,10 +80,11 @@ class TestMailRegisterTask:
fake = Faker()
to_email = fake.email()
code = fake.numerify("######")
caplog.set_level(logging.ERROR, logger="tasks.mail_register_task")
with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger:
send_email_register_mail_task(language="en-US", to=to_email, code=code)
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)
send_email_register_mail_task(language="en-US", to=to_email, code=code)
assert caplog.messages.count(f"Send email register mail to {to_email} failed") == 1
def test_send_email_register_mail_task_when_account_exist_success(
self, db_session_with_containers: Session, mock_mail_dependencies
@ -121,7 +126,10 @@ class TestMailRegisterTask:
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_email_register_mail_task_when_account_exist_exception_handling(
self, db_session_with_containers: Session, mock_mail_dependencies
self,
db_session_with_containers: Session,
mock_mail_dependencies,
caplog: pytest.LogCaptureFixture,
):
"""Test account exist email task exception handling."""
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
@ -129,7 +137,8 @@ class TestMailRegisterTask:
fake = Faker()
to_email = fake.email()
account_name = fake.name()
caplog.set_level(logging.ERROR, logger="tasks.mail_register_task")
with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger:
send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name)
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)
send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name)
assert caplog.messages.count(f"Send email register mail to {to_email} failed") == 1

View File

@ -8,7 +8,7 @@ This module tests the account activation mechanism including:
- Initial login after activation
"""
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -41,7 +41,7 @@ class TestActivateCheckApi:
}
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_valid_invitation_token(self, mock_get_invitation, app, mock_invitation):
def test_check_valid_invitation_token(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock):
"""
Test checking valid invitation token.
@ -67,7 +67,9 @@ class TestActivateCheckApi:
assert response["data"]["email"] == "invitee@example.com"
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_valid_invitation_token_includes_account_status(self, mock_get_invitation, app, mock_invitation):
def test_check_valid_invitation_token_includes_account_status(
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
):
mock_account = MagicMock()
mock_account.status = AccountStatus.ACTIVE
mock_invitation["account"] = mock_account
@ -103,7 +105,9 @@ class TestActivateCheckApi:
assert response["is_valid"] is False
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_workspace_id(self, mock_get_invitation, app, mock_invitation):
def test_check_token_without_workspace_id(
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
):
"""
Test checking token without workspace ID.
@ -121,10 +125,10 @@ class TestActivateCheckApi:
# Assert
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token")
mock_get_invitation.assert_called_once_with(None, "invitee@example.com", "valid_token", session=ANY)
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_without_email(self, mock_get_invitation, app, mock_invitation):
def test_check_token_without_email(self, mock_get_invitation: MagicMock, app: Flask, mock_invitation):
"""
Test checking token without email parameter.
@ -142,10 +146,12 @@ class TestActivateCheckApi:
# Assert
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token")
mock_get_invitation.assert_called_once_with("workspace-123", None, "valid_token", session=ANY)
@patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback")
def test_check_token_normalizes_email_to_lowercase(self, mock_get_invitation, app, mock_invitation):
def test_check_token_normalizes_email_to_lowercase(
self, mock_get_invitation: MagicMock, app: Flask, mock_invitation: MagicMock
):
"""Ensure token validation uses lowercase emails."""
mock_get_invitation.return_value = mock_invitation
@ -156,7 +162,7 @@ class TestActivateCheckApi:
response = api.get()
assert response["is_valid"] is True
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY)
class TestActivateApi:
@ -554,7 +560,7 @@ class TestActivateApi:
response = api.post()
assert response["result"] == "success"
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token")
mock_get_invitation.assert_called_once_with("workspace-123", "Invitee@Example.com", "valid_token", session=ANY)
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.TenantService.create_tenant_member")
@ -593,7 +599,7 @@ class TestActivateApi:
mock_create_tenant_member.assert_called_once_with(
mock_invitation["tenant"], mock_account, mock_db.session, role=TenantAccountRole.ADMIN
)
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id)
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY)
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")
@patch("controllers.console.auth.activate.TenantService.create_tenant_member")
@ -628,5 +634,5 @@ class TestActivateApi:
assert response["result"] == "success"
mock_create_tenant_member.assert_not_called()
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id)
mock_switch_tenant.assert_called_once_with(mock_account, mock_invitation["tenant"].id, session=ANY)
mock_revoke_token.assert_called_once_with("workspace-123", "invitee@example.com", "valid_token")

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from pydantic import ValidationError
@ -25,6 +25,7 @@ def test_create_new_account_uses_requested_language(mock_create_account):
password="ValidPass123!",
interface_language="zh-Hans",
timezone="Asia/Shanghai",
session=ANY,
)

View File

@ -9,7 +9,7 @@ This module tests the email code login mechanism including:
"""
import base64
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -368,6 +368,7 @@ class TestEmailCodeLoginApi:
name="newuser@example.com",
interface_language="en-US",
timezone="Asia/Shanghai",
session=ANY,
)
@patch("controllers.console.wraps.db")

View File

@ -9,7 +9,7 @@ This module tests the core authentication endpoints including:
"""
import base64
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import ANY, MagicMock, Mock, patch
import pytest
from flask import Flask
@ -129,7 +129,7 @@ class TestLoginApi:
response = login_api.post()
# Assert
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None)
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", None, session=ANY)
mock_login.assert_called_once()
mock_reset_rate_limit.assert_called_once_with("test@example.com")
assert response.json["result"] == "success"
@ -184,7 +184,7 @@ class TestLoginApi:
response = login_api.post()
# Assert
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token")
mock_authenticate.assert_called_once_with("test@example.com", "ValidPass123!", "valid_token", session=ANY)
assert response.json["result"] == "success"
@patch("controllers.console.wraps.db")
@ -407,13 +407,13 @@ class TestLoginApi:
@patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit")
def test_login_retries_with_lowercase_email(
self,
mock_reset_rate_limit,
mock_login_service,
mock_get_tenants,
mock_add_rate_limit,
mock_authenticate,
mock_get_invitation,
mock_is_rate_limit,
mock_reset_rate_limit: MagicMock,
mock_login_service: MagicMock,
mock_get_tenants: MagicMock,
mock_add_rate_limit: MagicMock,
mock_authenticate: MagicMock,
mock_get_invitation: MagicMock,
mock_is_rate_limit: MagicMock,
mock_db,
app: Flask,
mock_account,
@ -435,8 +435,8 @@ class TestLoginApi:
assert response.json["result"] == "success"
assert mock_authenticate.call_args_list == [
(("Upper@Example.com", "ValidPass123!", None), {}),
(("upper@example.com", "ValidPass123!", None), {}),
(("Upper@Example.com", "ValidPass123!", None), {"session": ANY}),
(("upper@example.com", "ValidPass123!", None), {"session": ANY}),
]
mock_add_rate_limit.assert_not_called()
mock_reset_rate_limit.assert_called_once_with("upper@example.com")
@ -447,10 +447,10 @@ class TestLoginApi:
@patch("controllers.console.auth.login._get_account_with_case_fallback")
def test_email_code_login_logs_banned_account(
self,
mock_get_account,
mock_revoke_token,
mock_get_token_data,
mock_db,
mock_get_account: MagicMock,
mock_revoke_token: MagicMock,
mock_get_token_data: MagicMock,
mock_db: MagicMock,
app: Flask,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
@ -491,7 +491,9 @@ class TestLogoutApi:
@patch("controllers.console.auth.login.AccountService.logout")
@patch("controllers.console.auth.login.flask_login.logout_user")
def test_successful_logout(self, mock_logout_user, mock_service_logout, app: Flask, mock_account):
def test_successful_logout(
self, mock_logout_user: MagicMock, mock_service_logout: MagicMock, app: Flask, mock_account
):
"""
Test successful logout flow.

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -66,8 +66,9 @@ def test_generate_account_registers_with_browser_timezone(
provider="github",
language="zh-Hans",
timezone="Asia/Shanghai",
session=ANY,
)
mock_link_account.assert_called_once_with("github", "github-123", account)
mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY)
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
@ -97,8 +98,9 @@ def test_generate_account_prefers_state_language_over_accept_language(
provider="github",
language="zh-Hans",
timezone=None,
session=ANY,
)
mock_link_account.assert_called_once_with("github", "github-123", account)
mock_link_account.assert_called_once_with("github", "github-123", account, session=ANY)
@patch("controllers.console.auth.oauth.dify_config")

View File

@ -8,7 +8,7 @@ This module tests the token refresh mechanism including:
- Error handling for invalid tokens
"""
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -70,7 +70,7 @@ class TestRefreshTokenApi:
# Assert
mock_extract_token.assert_called_once()
mock_refresh_token.assert_called_once_with("valid_refresh_token")
mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY)
assert response.json["result"] == "success"
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@ -191,7 +191,7 @@ class TestRefreshTokenApi:
# Assert
assert response.json["result"] == "success"
# Verify new token pair was generated
mock_refresh_token.assert_called_once_with("valid_refresh_token")
mock_refresh_token.assert_called_once_with("valid_refresh_token", session=ANY)
# In real implementation, cookies would be set with new values
assert mock_token_pair.access_token == "new_access_token"
assert mock_token_pair.refresh_token == "new_refresh_token"

View File

@ -94,7 +94,7 @@ class TestSystemFeatureApi:
"controllers.console.feature.current_account_with_tenant_optional",
return_value=(account, "tenant-123"),
)
system_features = SystemFeatureModel(is_allow_register=True)
system_features = SystemFeatureModel(is_allow_register=True, enable_learn_app=True)
get_system_features = mocker.patch(
"controllers.console.feature.FeatureService.get_system_features",
return_value=system_features,
@ -104,6 +104,7 @@ class TestSystemFeatureApi:
result = api.get()
assert result == system_features.model_dump()
assert result["enable_learn_app"] is True
current_account.assert_called_once_with()
get_system_features.assert_called_once_with(is_authenticated=True)

View File

@ -38,7 +38,7 @@ def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 1)
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
@ -48,7 +48,7 @@ def test_validate_init_password_already_setup(app: Flask, monkeypatch: pytest.Mo
def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
@ -60,7 +60,7 @@ def test_validate_init_password_wrong_password(app: Flask, monkeypatch: pytest.M
def test_validate_init_password_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda *, session: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"

View File

@ -1,6 +1,6 @@
import inspect
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -394,12 +394,12 @@ class TestChangeEmailReset:
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_new_email_before_update(
self,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_is_freeze: MagicMock,
mock_check_unique: MagicMock,
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_update_account: MagicMock,
mock_send_notify: MagicMock,
app: Flask,
):
current_user = _build_account("old@example.com", "acc3")
@ -424,9 +424,9 @@ class TestChangeEmailReset:
method(api, current_user)
mock_is_freeze.assert_called_once_with("new@example.com")
mock_check_unique.assert_called_once_with("new@example.com")
mock_check_unique.assert_called_once_with("new@example.com", session=ANY)
mock_revoke_token.assert_called_once_with("token-123")
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
mock_update_account.assert_called_once_with(current_user, email="new@example.com", session=ANY)
mock_send_notify.assert_called_once_with(email="new@example.com")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@ -437,12 +437,12 @@ class TestChangeEmailReset:
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_reject_reset_when_token_phase_is_not_new_verified(
self,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_is_freeze: MagicMock,
mock_check_unique: MagicMock,
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_update_account: MagicMock,
mock_send_notify: MagicMock,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
@ -480,12 +480,12 @@ class TestChangeEmailReset:
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
self,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_is_freeze: MagicMock,
mock_check_unique: MagicMock,
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_update_account: MagicMock,
mock_send_notify: MagicMock,
app: Flask,
):
"""A verified token for address A must not be replayed to change to address B."""
@ -523,12 +523,12 @@ class TestChangeEmailReset:
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_reject_reset_when_token_account_id_does_not_match_current_user(
self,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_is_freeze: MagicMock,
mock_check_unique: MagicMock,
mock_get_data: MagicMock,
mock_revoke_token: MagicMock,
mock_update_account: MagicMock,
mock_send_notify: MagicMock,
app: Flask,
):
from controllers.console.auth.error import InvalidTokenError
@ -575,9 +575,9 @@ class TestAccountServiceSendChangeEmailEmail:
@patch("services.account_service.AccountService.generate_change_email_token")
def test_should_bind_account_id_and_target_email_into_generated_token(
self,
mock_generate_token,
mock_rate_limiter,
mock_mail_task,
mock_generate_token: MagicMock,
mock_rate_limiter: MagicMock,
mock_mail_task: MagicMock,
):
mock_rate_limiter.is_rate_limited.return_value = False
mock_generate_token.return_value = "the-token"
@ -665,7 +665,7 @@ class TestAccountDeletionFeedback:
class TestCheckEmailUnique:
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, app: Flask):
def test_should_normalize_email(self, mock_is_freeze: MagicMock, mock_check_unique: MagicMock, app: Flask):
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
@ -680,7 +680,7 @@ class TestCheckEmailUnique:
assert response == {"result": "success"}
mock_is_freeze.assert_called_once_with("case@test.com")
mock_check_unique.assert_called_once_with("case@test.com")
mock_check_unique.assert_called_once_with("case@test.com", session=ANY)
def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup():

View File

@ -8,7 +8,7 @@ handler tests use inspect.unwrap() to bypass them and focus on business logic.
import inspect
from datetime import datetime
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pytest
from flask import Flask
@ -115,7 +115,7 @@ class TestEnterpriseWorkspace:
assert result["message"] == "enterprise workspace created."
assert result["tenant"]["id"] == "tenant-id"
assert result["tenant"]["name"] == "My Workspace"
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY)
mock_tenant_svc.create_tenant_member.assert_called_once_with(
mock_tenant, mock_account, mock_db.session, role="owner"
)
@ -183,5 +183,5 @@ class TestEnterpriseWorkspaceNoOwnerEmail:
assert result["tenant"]["id"] == "tenant-id"
assert result["tenant"]["encrypt_public_key"] == "pub-key"
assert result["tenant"]["custom_config"] == {}
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True)
mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True, session=ANY)
mock_event.send.assert_called_once_with(mock_tenant)

View File

@ -29,6 +29,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueuePingEvent,
QueueReasoningChunkEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -46,6 +47,7 @@ from core.app.entities.task_entities import (
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
ReasoningChunkStreamResponse,
)
from core.base.tts.app_generator_tts_publisher import AudioTrunk
from core.workflow.system_variables import build_system_variables
@ -196,6 +198,42 @@ class TestAdvancedChatGenerateTaskPipeline:
assert pipeline._task_state.answer == "hi"
assert responses
def test_handle_reasoning_chunk_event_emits_on_nonempty(self):
pipeline = _make_pipeline()
event = QueueReasoningChunkEvent(reasoning="pondering", from_node_id="llm-1", is_final=False)
responses = list(pipeline._handle_reasoning_chunk_event(event))
assert len(responses) == 1
response = responses[0]
assert isinstance(response, ReasoningChunkStreamResponse)
assert response.data.message_id == pipeline._message_id
assert response.data.reasoning == "pondering"
assert response.data.node_id == "llm-1"
assert response.data.is_final is False
# reasoning never touches the answer stream
assert pipeline._task_state.answer == ""
def test_handle_reasoning_chunk_event_drops_empty_nonfinal(self):
pipeline = _make_pipeline()
event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm-1", is_final=False)
responses = list(pipeline._handle_reasoning_chunk_event(event))
assert responses == []
def test_handle_reasoning_chunk_event_emits_empty_final_marker(self):
pipeline = _make_pipeline()
event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm-1", is_final=True)
responses = list(pipeline._handle_reasoning_chunk_event(event))
assert len(responses) == 1
response = responses[0]
assert isinstance(response, ReasoningChunkStreamResponse)
assert response.data.reasoning == ""
assert response.data.is_final is True
def test_listen_audio_msg_returns_audio_stream(self):
pipeline = _make_pipeline()
publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data"))
@ -319,6 +357,43 @@ class TestAdvancedChatGenerateTaskPipeline:
assert responses == ["done"]
assert pipeline._recorded_files
def test_handle_node_succeeded_event_records_llm_reasoning(self):
pipeline = _make_pipeline()
pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: []
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
pipeline._save_output_for_event = lambda event, node_execution_id: None
event = SimpleNamespace(
node_type=BuiltinNodeTypes.LLM,
outputs={"reasoning_content": "first pass "},
node_execution_id="exec",
node_id="llm-1",
)
list(pipeline._handle_node_succeeded_event(event))
assert pipeline._task_state.metadata.reasoning == {"llm-1": "first pass "}
def test_handle_node_succeeded_event_accumulates_reasoning_across_passes(self):
pipeline = _make_pipeline()
pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: []
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"
pipeline._save_output_for_event = lambda event, node_execution_id: None
def _llm_event(reasoning: str):
return SimpleNamespace(
node_type=BuiltinNodeTypes.LLM,
outputs={"reasoning_content": reasoning},
node_execution_id="exec",
node_id="llm-1",
)
# Same node id across iteration/loop passes must accumulate, not overwrite.
list(pipeline._handle_node_succeeded_event(_llm_event("pass one ")))
list(pipeline._handle_node_succeeded_event(_llm_event("pass two")))
assert pipeline._task_state.metadata.reasoning == {"llm-1": "pass one pass two"}
def test_iteration_and_loop_handlers(self):
pipeline = _make_pipeline()
pipeline._workflow_run_id = "run-id"

View File

@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeSucceededEvent,
QueueReasoningChunkEvent,
QueueTextChunkEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
@ -34,6 +35,7 @@ from graphon.graph_events import (
NodeRunHumanInputFormFilledEvent,
NodeRunIterationSucceededEvent,
NodeRunLoopFailedEvent,
NodeRunReasoningChunkEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
@ -395,6 +397,17 @@ class TestWorkflowBasedAppRunner:
is_final=False,
),
)
runner._handle_event(
workflow_entry,
NodeRunReasoningChunkEvent(
id="exec",
node_id="node",
node_type=BuiltinNodeTypes.LLM,
selector=["node", "reasoning_content"],
chunk="thinking",
is_final=False,
),
)
runner._handle_event(
workflow_entry,
NodeRunAgentLogEvent(
@ -442,6 +455,7 @@ class TestWorkflowBasedAppRunner:
)
assert any(isinstance(event, QueueTextChunkEvent) for event in published)
assert any(isinstance(event, QueueReasoningChunkEvent) for event in published)
assert any(isinstance(event, QueueAgentLogEvent) for event in published)
assert any(isinstance(event, QueueIterationCompletedEvent) for event in published)
assert any(isinstance(event, QueueLoopCompletedEvent) for event in published)

View File

@ -26,6 +26,7 @@ from core.app.entities.queue_entities import (
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueuePingEvent,
QueueReasoningChunkEvent,
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
@ -40,6 +41,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
ReasoningChunkStreamResponse,
WorkflowAppPausedBlockingResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
@ -265,6 +267,41 @@ class TestWorkflowGenerateTaskPipeline:
assert responses[0].data.text == "hi"
assert published == [queue_message]
def test_handle_reasoning_chunk_event_emits_on_nonempty(self):
pipeline = _make_pipeline()
event = QueueReasoningChunkEvent(reasoning="pondering", from_node_id="llm-1", is_final=False)
responses = list(pipeline._handle_reasoning_chunk_event(event))
assert len(responses) == 1
response = responses[0]
assert isinstance(response, ReasoningChunkStreamResponse)
# workflow runs have no message, so the id is omitted
assert response.data.message_id is None
assert response.data.reasoning == "pondering"
assert response.data.node_id == "llm-1"
assert response.data.is_final is False
def test_handle_reasoning_chunk_event_drops_empty_nonfinal(self):
pipeline = _make_pipeline()
event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm-1", is_final=False)
responses = list(pipeline._handle_reasoning_chunk_event(event))
assert responses == []
def test_handle_reasoning_chunk_event_emits_empty_final_marker(self):
pipeline = _make_pipeline()
event = QueueReasoningChunkEvent(reasoning="", from_node_id="llm-1", is_final=True)
responses = list(pipeline._handle_reasoning_chunk_event(event))
assert len(responses) == 1
response = responses[0]
assert isinstance(response, ReasoningChunkStreamResponse)
assert response.data.reasoning == ""
assert response.data.is_final is True
def test_dispatch_event_handles_node_failed(self):
pipeline = _make_pipeline()
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done"

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import logging
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any
@ -1704,13 +1705,14 @@ def test_get_specific_provider_credential_decrypts_and_obfuscates_credentials()
assert credentials == {"openai_api_key": "raw-secret", "region": "us"}
def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None:
def test_get_specific_provider_credential_logs_when_decrypt_fails(caplog: pytest.LogCaptureFixture) -> None:
configuration = _build_provider_configuration()
configuration.provider.provider_credential_schema = _build_secret_provider_schema()
session = Mock()
session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(
encrypted_config='{"openai_api_key":"enc-secret"}'
)
caplog.set_level(logging.ERROR, logger="core.entities.provider_configuration")
with _patched_session(session):
with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None):
@ -1718,16 +1720,15 @@ def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None:
"core.entities.provider_configuration.encrypter.decrypt_token",
side_effect=RuntimeError("boom"),
):
with patch("core.entities.provider_configuration.logger.exception") as mock_logger:
with patch.object(
ProviderConfiguration,
"obfuscated_credentials",
side_effect=lambda credentials, credential_form_schemas: credentials,
):
credentials = configuration._get_specific_provider_credential("cred-1")
with patch.object(
ProviderConfiguration,
"obfuscated_credentials",
side_effect=lambda credentials, credential_form_schemas: credentials,
):
credentials = configuration._get_specific_provider_credential("cred-1")
assert credentials == {"openai_api_key": "enc-secret"}
mock_logger.assert_called_once()
assert caplog.messages.count("Failed to decrypt credential secret variable openai_api_key") == 1
def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None:
@ -1831,7 +1832,7 @@ def test_switch_active_provider_credential_rolls_back_on_error() -> None:
session.rollback.assert_called_once()
def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None:
def test_get_specific_custom_model_credential_logs_when_decrypt_fails(caplog: pytest.LogCaptureFixture) -> None:
configuration = _build_provider_configuration()
configuration.provider.model_credential_schema = _build_secret_model_schema()
session = Mock()
@ -1840,19 +1841,19 @@ def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None:
credential_name="Main",
encrypted_config='{"openai_api_key":"enc-secret"}',
)
caplog.set_level(logging.ERROR, logger="core.entities.provider_configuration")
with _patched_session(session):
with patch("core.entities.provider_configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")):
with patch("core.entities.provider_configuration.logger.exception") as mock_logger:
with patch.object(
ProviderConfiguration,
"obfuscated_credentials",
side_effect=lambda credentials, credential_form_schemas: credentials,
):
result = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1")
with patch.object(
ProviderConfiguration,
"obfuscated_credentials",
side_effect=lambda credentials, credential_form_schemas: credentials,
):
result = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1")
assert result["credentials"] == {"openai_api_key": "enc-secret"}
mock_logger.assert_called_once()
assert caplog.messages.count("Failed to decrypt model credential secret variable openai_api_key") == 1
def test_validate_custom_model_credentials_handles_invalid_original_json() -> None:

View File

@ -407,7 +407,7 @@ class TestCacheEmbeddingDocuments:
assert len(calls[1].kwargs["texts"]) == 10
assert len(calls[2].kwargs["texts"]) == 5
def test_embed_documents_nan_handling(self, mock_model_instance, caplog):
def test_embed_documents_nan_handling(self, mock_model_instance, caplog: pytest.LogCaptureFixture):
"""Test handling of NaN values in embeddings.
Verifies:

View File

@ -385,7 +385,7 @@ class TestParagraphIndexProcessor:
with pytest.raises(ValueError, match="model_name and model_provider_name"):
ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True})
def test_generate_summary_text_only_flow(self, caplog) -> None:
def test_generate_summary_text_only_flow(self, caplog: pytest.LogCaptureFixture) -> None:
model_instance = Mock()
model_instance.credentials = {"k": "v"}
model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[])
@ -459,7 +459,7 @@ class TestParagraphIndexProcessor:
assert summary == "vision summary"
mock_extract_text.assert_not_called()
def test_generate_summary_fallbacks_for_prompt_and_result_types(self, caplog) -> None:
def test_generate_summary_fallbacks_for_prompt_and_result_types(self, caplog: pytest.LogCaptureFixture) -> None:
model_instance = Mock()
model_instance.credentials = {"k": "v"}
model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(
@ -503,7 +503,7 @@ class TestParagraphIndexProcessor:
"Failed to convert image file to prompt message content" in record.message for record in caplog.records
)
def test_extract_images_from_text_handles_patterns_and_build_errors(self, caplog) -> None:
def test_extract_images_from_text_handles_patterns_and_build_errors(self, caplog: pytest.LogCaptureFixture) -> None:
text = (
"![img](/files/11111111-1111-1111-1111-111111111111/image-preview) "
"![img2](/files/22222222-2222-2222-2222-222222222222/file-preview) "
@ -554,7 +554,7 @@ class TestParagraphIndexProcessor:
session.scalars.return_value = scalars_result
assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here", session) == []
def test_extract_images_from_text_logs_when_build_fails(self, caplog) -> None:
def test_extract_images_from_text_logs_when_build_fails(self, caplog: pytest.LogCaptureFixture) -> None:
text = "![img](/files/11111111-1111-1111-1111-111111111111/image-preview)"
image_upload = SimpleNamespace(
id="11111111-1111-1111-1111-111111111111",
@ -583,7 +583,7 @@ class TestParagraphIndexProcessor:
assert files == []
assert sum(1 for r in caplog.records if r.levelno == logging.WARNING) == 1
def test_extract_images_from_segment_attachments(self, caplog) -> None:
def test_extract_images_from_segment_attachments(self, caplog: pytest.LogCaptureFixture) -> None:
image_upload = SimpleNamespace(
id="file-1",
name="image",

View File

@ -351,7 +351,9 @@ class TestQAIndexProcessor:
assert all_qa_documents[0].metadata["answer"] == "A test."
assert all_qa_documents[1].metadata["answer"] == "Coverage."
def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app, caplog) -> None:
def test_format_qa_document_logs_errors(
self, processor: QAIndexProcessor, fake_flask_app, caplog: pytest.LogCaptureFixture
) -> None:
all_qa_documents: list[Document] = []
source_document = Document(page_content="source text", metadata={"origin": "doc-1"})

View File

@ -4,6 +4,8 @@ from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
@ -122,7 +124,7 @@ def test_precheck_ignores_non_quota_node() -> None:
mock_check.assert_not_called()
def test_quota_error_is_handled_in_layer(caplog) -> None:
def test_quota_error_is_handled_in_layer(caplog: pytest.LogCaptureFixture) -> None:
layer = LLMQuotaLayer(tenant_id="tenant-id")
stop_event = threading.Event()
layer.command_channel = MagicMock()

View File

@ -7,6 +7,8 @@ from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from configs.enterprise import EnterpriseTelemetryConfig
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
from enterprise.telemetry.exporter import EnterpriseExporter, _datetime_to_ns, _parse_otlp_headers
@ -533,7 +535,7 @@ def test_export_span_cross_workflow_parent_context() -> None:
assert kwargs["context"] is not None
def test_export_span_logs_exception_on_error(caplog) -> None:
def test_export_span_logs_exception_on_error(caplog: pytest.LogCaptureFixture) -> None:
"""If the span block raises, the exception is logged and context is still cleared."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()
@ -546,7 +548,7 @@ def test_export_span_logs_exception_on_error(caplog) -> None:
assert "bad.span" in caplog.text
def test_export_span_invalid_trace_correlation_logs_warning(caplog) -> None:
def test_export_span_invalid_trace_correlation_logs_warning(caplog: pytest.LogCaptureFixture) -> None:
"""Invalid UUID for trace_correlation_override triggers a warning log."""
exporter, mock_tracer, mock_span = _make_exporter_with_mock_tracer()

View File

@ -4,6 +4,7 @@ from types import SimpleNamespace
from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy import create_engine, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@ -122,7 +123,9 @@ def test_message_created_paid_credit_accounting_uses_paid_pool() -> None:
)
def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(caplog) -> None:
def test_capped_credit_pool_accounting_skips_exhaustion_warning_when_full_amount_is_deducted(
caplog: pytest.LogCaptureFixture,
) -> None:
with patch(
"services.credit_pool_service.CreditPoolService.deduct_credits_capped",
return_value=3,

View File

@ -56,18 +56,19 @@ def mock_response_receiver(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
return mock_log_request_finished
@pytest.fixture
def mock_logger(monkeypatch: pytest.MonkeyPatch) -> logging.Logger:
_logger = mock.MagicMock(spec=logging.Logger)
monkeypatch.setattr(ext_request_logging, "logger", _logger)
return _logger
@pytest.fixture
def enable_request_logging(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", True)
def _captured_records(caplog: pytest.LogCaptureFixture, level: int) -> list[logging.LogRecord]:
return [
record
for record in caplog.records
if record.name == ext_request_logging.logger.name and record.levelno == level
]
class TestRequestLoggingExtension:
def test_receiver_should_not_be_invoked_if_configuration_is_disabled(
self,
@ -108,67 +109,74 @@ class TestRequestLoggingExtension:
class TestLoggingLevel:
@pytest.mark.usefixtures("enable_request_logging")
def test_logging_should_be_skipped_if_level_is_above_debug(self, enable_request_logging, mock_logger):
mock_logger.isEnabledFor.return_value = False
def test_logging_should_be_skipped_if_level_is_above_debug(
self, enable_request_logging, caplog: pytest.LogCaptureFixture
):
caplog.set_level(logging.INFO, logger=ext_request_logging.logger.name)
app = _get_test_app()
init_app(app)
with app.test_client() as client:
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
mock_logger.debug.assert_not_called()
assert not _captured_records(caplog, logging.DEBUG)
class TestRequestReceiverLogging:
@pytest.mark.usefixtures("enable_request_logging")
def test_non_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
mock_logger.isEnabledFor.return_value = True
def test_non_json_request(self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver):
caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name)
app = _get_test_app()
init_app(app)
with app.test_client() as client:
client.post("/", data="plain text")
assert mock_logger.debug.call_count == 1
call_args = mock_logger.debug.call_args[0]
assert "Received Request" in call_args[0]
assert call_args[1] == "POST"
assert call_args[2] == "/"
assert "Request Body" not in call_args[0]
debug_records = _captured_records(caplog, logging.DEBUG)
assert len(debug_records) == 1
record = debug_records[0]
assert "Received Request" in record.msg
assert record.args == ("POST", "/")
assert "Request Body" not in record.msg
@pytest.mark.usefixtures("enable_request_logging")
def test_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
mock_logger.isEnabledFor.return_value = True
def test_json_request(self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver):
caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name)
app = _get_test_app()
init_app(app)
with app.test_client() as client:
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
assert mock_logger.debug.call_count == 1
call_args = mock_logger.debug.call_args[0]
assert "Received Request" in call_args[0]
assert "Request Body" in call_args[0]
assert call_args[1] == "POST"
assert call_args[2] == "/"
assert _KEY_NEEDLE in call_args[3]
debug_records = _captured_records(caplog, logging.DEBUG)
assert len(debug_records) == 1
record = debug_records[0]
assert "Received Request" in record.msg
assert "Request Body" in record.msg
assert record.args[0] == "POST"
assert record.args[1] == "/"
assert _KEY_NEEDLE in record.args[2]
@pytest.mark.usefixtures("enable_request_logging")
def test_json_request_with_empty_body(self, enable_request_logging, mock_logger, mock_response_receiver):
mock_logger.isEnabledFor.return_value = True
def test_json_request_with_empty_body(
self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver
):
caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name)
app = _get_test_app()
init_app(app)
with app.test_client() as client:
client.post("/", headers={"Content-Type": "application/json"})
assert mock_logger.debug.call_count == 1
call_args = mock_logger.debug.call_args[0]
assert "Received Request" in call_args[0]
assert "Request Body" not in call_args[0]
assert call_args[1] == "POST"
assert call_args[2] == "/"
debug_records = _captured_records(caplog, logging.DEBUG)
assert len(debug_records) == 1
record = debug_records[0]
assert "Received Request" in record.msg
assert "Request Body" not in record.msg
assert record.args == ("POST", "/")
@pytest.mark.usefixtures("enable_request_logging")
def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
mock_logger.isEnabledFor.return_value = True
def test_json_request_with_invalid_json_as_body(
self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver
):
caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name)
app = _get_test_app()
init_app(app)
@ -178,50 +186,53 @@ class TestRequestReceiverLogging:
headers={"Content-Type": "application/json"},
data="{",
)
assert mock_logger.debug.call_count == 0
assert mock_logger.exception.call_count == 1
exception_call_args = mock_logger.exception.call_args[0]
assert exception_call_args[0] == "Failed to parse JSON request"
assert not _captured_records(caplog, logging.DEBUG)
error_records = _captured_records(caplog, logging.ERROR)
assert len(error_records) == 1
assert error_records[0].message == "Failed to parse JSON request"
class TestResponseReceiverLogging:
@pytest.mark.usefixtures("enable_request_logging")
def test_non_json_response(self, enable_request_logging, mock_logger):
mock_logger.isEnabledFor.return_value = True
def test_non_json_response(self, enable_request_logging, caplog: pytest.LogCaptureFixture):
caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name)
app = _get_test_app()
response = Response(
"OK",
headers={"Content-Type": "text/plain"},
)
_log_request_finished(app, response)
assert mock_logger.debug.call_count == 1
call_args = mock_logger.debug.call_args[0]
assert "Response" in call_args[0]
assert "200" in call_args[1]
assert call_args[2] == "text/plain"
assert "Response Body" not in call_args[0]
debug_records = _captured_records(caplog, logging.DEBUG)
assert len(debug_records) == 1
record = debug_records[0]
assert "Response" in record.msg
assert "200" in record.args[0]
assert record.args[1] == "text/plain"
assert "Response Body" not in record.msg
@pytest.mark.usefixtures("enable_request_logging")
def test_json_response(self, enable_request_logging, mock_logger, mock_response_receiver):
mock_logger.isEnabledFor.return_value = True
def test_json_response(self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver):
caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name)
app = _get_test_app()
response = Response(
json.dumps({_KEY_NEEDLE: _VALUE_NEEDLE}),
headers={"Content-Type": "application/json"},
)
_log_request_finished(app, response)
assert mock_logger.debug.call_count == 1
call_args = mock_logger.debug.call_args[0]
assert "Response" in call_args[0]
assert "Response Body" in call_args[0]
assert "200" in call_args[1]
assert call_args[2] == "application/json"
assert _KEY_NEEDLE in call_args[3]
debug_records = _captured_records(caplog, logging.DEBUG)
assert len(debug_records) == 1
record = debug_records[0]
assert "Response" in record.msg
assert "Response Body" in record.msg
assert "200" in record.args[0]
assert record.args[1] == "application/json"
assert _KEY_NEEDLE in record.args[2]
@pytest.mark.usefixtures("enable_request_logging")
def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
mock_logger.isEnabledFor.return_value = True
def test_json_request_with_invalid_json_as_body(
self, enable_request_logging, caplog: pytest.LogCaptureFixture, mock_response_receiver
):
caplog.set_level(logging.DEBUG, logger=ext_request_logging.logger.name)
app = _get_test_app()
response = Response(
@ -229,11 +240,10 @@ class TestResponseReceiverLogging:
headers={"Content-Type": "application/json"},
)
_log_request_finished(app, response)
assert mock_logger.debug.call_count == 0
assert mock_logger.exception.call_count == 1
exception_call_args = mock_logger.exception.call_args[0]
assert exception_call_args[0] == "Failed to parse JSON response"
assert not _captured_records(caplog, logging.DEBUG)
error_records = _captured_records(caplog, logging.ERROR)
assert len(error_records) == 1
assert error_records[0].message == "Failed to parse JSON response"
class TestResponseUnmodified:
@ -267,7 +277,7 @@ class TestResponseUnmodified:
class TestRequestFinishedInfoAccessLine:
def test_info_access_log_includes_method_path_status_duration_trace_id(
self, monkeypatch: pytest.MonkeyPatch, caplog
self, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
):
"""Ensure INFO access line contains expected fields with computed duration and trace id."""
app = _get_test_app()

View File

@ -5,6 +5,7 @@ This module covers the pre-uninstall plugin hook behavior:
- API failure: soft-fail (logs and does not re-raise)
"""
import logging
from unittest.mock import patch
import pytest
@ -43,18 +44,16 @@ class TestTryPreUninstallPlugin:
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
def test_try_pre_uninstall_plugin_http_error_soft_fails(self):
def test_try_pre_uninstall_plugin_http_error_soft_fails(self, caplog: pytest.LogCaptureFixture):
body = PreUninstallPluginRequest(
tenant_id="tenant-456",
plugin_unique_identifier="com.example.other_plugin",
)
caplog.set_level(logging.ERROR, logger="services.enterprise.plugin_manager_service")
with (
patch(
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
) as mock_send_request,
patch("services.enterprise.plugin_manager_service.logger") as mock_logger,
):
with patch(
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
) as mock_send_request:
mock_send_request.side_effect = HTTPStatusError(
"502 Bad Gateway",
request=None,
@ -69,20 +68,22 @@ class TestTryPreUninstallPlugin:
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()
assert len(caplog.records) == 1
assert caplog.messages[0] == (
"failed to perform pre uninstall plugin hook. tenant_id: tenant-456, "
"plugin_unique_identifier: com.example.other_plugin"
)
def test_try_pre_uninstall_plugin_generic_exception_soft_fails(self):
def test_try_pre_uninstall_plugin_generic_exception_soft_fails(self, caplog: pytest.LogCaptureFixture):
body = PreUninstallPluginRequest(
tenant_id="tenant-789",
plugin_unique_identifier="com.example.failing_plugin",
)
caplog.set_level(logging.ERROR, logger="services.enterprise.plugin_manager_service")
with (
patch(
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
) as mock_send_request,
patch("services.enterprise.plugin_manager_service.logger") as mock_logger,
):
with patch(
"services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request"
) as mock_send_request:
mock_send_request.side_effect = ConnectionError("network unreachable")
PluginManagerService.try_pre_uninstall_plugin(body)
@ -93,7 +94,11 @@ class TestTryPreUninstallPlugin:
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()
assert len(caplog.records) == 1
assert caplog.messages[0] == (
"failed to perform pre uninstall plugin hook. tenant_id: tenant-789, "
"plugin_unique_identifier: com.example.failing_plugin"
)
class TestCheckCredentialPolicyCompliance:

View File

@ -42,6 +42,16 @@ class TestBuildInRecommendAppRetrieval:
mock_fetch.assert_called_once_with("en-US")
assert result == {"apps": []}
@patch("services.recommend_app.buildin.buildin_retrieval.DatabaseRecommendAppRetrieval")
def test_get_learn_dify_apps_delegates_to_database(self, mock_database_retrieval):
expected = {"recommended_apps": [{"id": "learn-dify-app"}]}
mock_database_retrieval.fetch_learn_dify_apps_from_db.return_value = expected
result = BuildInRecommendAppRetrieval().get_learn_dify_apps("en-US")
assert result == expected
mock_database_retrieval.fetch_learn_dify_apps_from_db.assert_called_once_with("en-US")
def test_get_recommend_app_detail_delegates(self):
with patch.object(
BuildInRecommendAppRetrieval,

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from services.recommend_app.recommend_app_type import RecommendAppType
from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval
@ -58,6 +59,32 @@ class TestRemoteRecommendAppRetrieval:
result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US")
assert result == {"recommended_apps": [{"id": "builtin"}]}
@patch.object(
RemoteRecommendAppRetrieval,
"fetch_learn_dify_apps_from_dify_official",
return_value={"recommended_apps": [{"id": "learn-dify-app"}]},
)
def test_get_learn_dify_apps_success(self, mock_fetch):
result = RemoteRecommendAppRetrieval().get_learn_dify_apps("en-US")
assert result == {"recommended_apps": [{"id": "learn-dify-app"}]}
mock_fetch.assert_called_once_with("en-US")
@patch(
"services.recommend_app.remote.remote_retrieval.DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db",
return_value={"recommended_apps": [{"id": "db-fallback"}]},
)
@patch.object(
RemoteRecommendAppRetrieval,
"fetch_learn_dify_apps_from_dify_official",
side_effect=ValueError("server error"),
)
def test_get_learn_dify_apps_falls_back_to_database_on_error(self, mock_fetch, mock_database):
result = RemoteRecommendAppRetrieval().get_learn_dify_apps("en-US")
assert result == {"recommended_apps": [{"id": "db-fallback"}]}
mock_database.assert_called_once_with("en-US")
class TestFetchFromDifyOfficial:
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@ -118,3 +145,84 @@ class TestFetchFromDifyOfficial:
result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
assert "categories" not in result
assert mock_get.call_args.kwargs["headers"] == {}
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_apps_forwards_request_origin_header(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_config.CONSOLE_WEB_URL = "https://saas.dify.dev"
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = {"recommended_apps": []}
mock_get.return_value = mock_response
flask_app = Flask(__name__)
with flask_app.test_request_context(headers={"Origin": "https://cloud.example.com"}):
RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
assert mock_get.call_args.kwargs["headers"] == {"Origin": "https://cloud.example.com"}
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_apps_falls_back_to_console_web_url_origin(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_config.CONSOLE_WEB_URL = "https://saas.dify.dev/console"
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = {"recommended_apps": []}
mock_get.return_value = mock_response
flask_app = Flask(__name__)
with flask_app.test_request_context():
RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
assert mock_get.call_args.kwargs["headers"] == {"Origin": "https://saas.dify.dev/console"}
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_apps_falls_back_to_console_web_url_without_request_context(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_config.CONSOLE_WEB_URL = "http://localhost:3000/console"
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = {"recommended_apps": []}
mock_get.return_value = mock_response
RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
assert mock_get.call_args.kwargs["headers"] == {"Origin": "http://localhost:3000/console"}
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_apps_uses_console_web_url_without_scheme(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_config.CONSOLE_WEB_URL = "saas.dify.dev"
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = {"recommended_apps": []}
mock_get.return_value = mock_response
flask_app = Flask(__name__)
with flask_app.test_request_context():
RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US")
assert mock_get.call_args.kwargs["headers"] == {"Origin": "saas.dify.dev"}
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_learn_dify_apps_returns_json_on_200(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_response = MagicMock(status_code=200)
mock_response.json.return_value = {"recommended_apps": [{"id": "learn-dify-app"}]}
mock_get.return_value = mock_response
result = RemoteRecommendAppRetrieval.fetch_learn_dify_apps_from_dify_official("en-US")
assert result == {"recommended_apps": [{"id": "learn-dify-app"}]}
assert mock_get.call_args.args[0] == "https://example.com/apps/learn-dify?language=en-US"
@patch("services.recommend_app.remote.remote_retrieval.dify_config")
@patch("services.recommend_app.remote.remote_retrieval.httpx.get")
def test_learn_dify_apps_raises_on_non_200(self, mock_get, mock_config):
mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com"
mock_get.return_value = MagicMock(status_code=500)
with pytest.raises(ValueError, match="fetch learn dify apps failed"):
RemoteRecommendAppRetrieval.fetch_learn_dify_apps_from_dify_official("en-US")

Some files were not shown because too many files have changed in this diff Show More