mirror of
https://github.com/langgenius/dify.git
synced 2026-06-26 14:51:13 +08:00
Merge remote-tracking branch 'upstream/main' into feat/dataset-api-key-scope
This commit is contained in:
commit
fd70b21d3b
@ -102,11 +102,11 @@ describe('ComponentName', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// Props tests (REQUIRED)
|
||||
// Props tests (REQUIRED when props change observable behavior)
|
||||
describe('Props', () => {
|
||||
it('should apply custom className', () => {
|
||||
render(<Component className="custom" />)
|
||||
expect(screen.getByRole('button')).toHaveClass('custom')
|
||||
it('should disable the action when disabled', () => {
|
||||
render(<Component disabled />)
|
||||
expect(screen.getByRole('button')).toBeDisabled()
|
||||
})
|
||||
})
|
||||
|
||||
@ -220,6 +220,7 @@ Every test should clearly separate:
|
||||
### 2. Black-Box Testing
|
||||
|
||||
- Test observable behavior, not implementation details
|
||||
- Test product contracts, not cosmetic implementation. Do not add or expand unit tests only to lock pure style classes, spacing, colors, backgrounds, or layout micro-adjustments. Cover visual-only fixes with browser/manual verification, screenshots, or E2E/visual checks when risk justifies it. Add unit tests only when the change affects user-observable behavior, accessibility semantics, state, data flow, routing, or a stable component API contract.
|
||||
- Use semantic queries (`getByRole` with accessible `name`, `getByLabelText`, `getByPlaceholderText`, `getByText`, and scoped `within(...)`)
|
||||
- Treat `getByTestId` as a last resort. If a control cannot be found by role/name, label, landmark, or dialog scope, fix the component accessibility first instead of adding or relying on `data-testid`.
|
||||
- Remove production `data-testid` attributes when semantic selectors can cover the behavior. Keep them only for non-visual mocked boundaries, editor/browser shims such as Monaco, canvas/chart output, or third-party widgets with no accessible DOM in the test environment.
|
||||
@ -273,7 +274,7 @@ it('should disable input when isReadOnly is true')
|
||||
### Always Required (All Components)
|
||||
|
||||
1. **Rendering**: Component renders without crashing
|
||||
1. **Props**: Required props, optional props, default values
|
||||
1. **Props**: Required props, optional props, default values that change observable behavior. Do not test pass-through styling props such as `className` unless they are an explicit, stable component API whose absence would break a real integration contract.
|
||||
1. **Edge Cases**: null, undefined, empty values, boundary conditions
|
||||
|
||||
### Conditional (When Present)
|
||||
|
||||
@ -1,65 +1,59 @@
|
||||
---
|
||||
name: how-to-write-component
|
||||
description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around abstraction choices, props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling.
|
||||
description: Use when writing, refactoring, or reviewing React/TypeScript components in Dify web, especially decisions about component ownership, props/types, URL/query state, Jotai state, async state, generated API contracts, queries/mutations, overlays, effects, navigation, performance, and empty states.
|
||||
---
|
||||
|
||||
# How To Write A Component
|
||||
|
||||
Use this as the decision guide for React/TypeScript component structure. Existing code is reference material, not automatic precedent; when it conflicts with these rules, adapt the approach instead of reproducing the violation.
|
||||
Use this as the component decision guide for Dify web. Existing code is reference material, not automatic precedent; if touched code violates these rules, adapt it and fix equivalent patterns in the same feature branch.
|
||||
|
||||
## First Decisions
|
||||
|
||||
| Question | Default | Promote or extract only when |
|
||||
| --- | --- | --- |
|
||||
| Where should code live? | Keep it local to the feature workflow, route, or owner. | Multiple verticals need the same stable primitive. |
|
||||
| Who owns state, data, and handlers? | The lowest component that uses them. | A parent coordinates shared loading, errors, empty UI, selection, submission, navigation, or one consistent snapshot. |
|
||||
| Should this become Jotai state? | Keep synchronous UI/form state in component or DOM state. | Siblings need one source of truth, the value drives atoms, or scoped workflow state must survive hidden/unmounted steps. |
|
||||
| Should URL state enter Jotai? | Let Next.js route params and `nuqs` own URL state and updates. | Query atoms or shared derived atoms need a read-only bridge hydrated at the route/surface boundary. |
|
||||
| Should this query/mutation become an atom? | Use TanStack Query hooks at the lowest owner. | It reads atom state, feeds derived atoms, or participates in shared Jotai workflow orchestration. |
|
||||
| Should this be a helper/wrapper? | Prefer direct readable code at the use site. | The name captures a stable domain rule or the wrapper owns real behavior, validation, state, error handling, or semantics. |
|
||||
| Is an Effect needed? | No. Derive during render or handle the user action in the event handler. | It synchronizes with an external system such as browser APIs, subscriptions, timers, analytics, or imperative DOM/non-React widgets. |
|
||||
|
||||
## Core Defaults
|
||||
|
||||
- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
|
||||
- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
|
||||
- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature.
|
||||
- Prefer local code and purpose-named helpers over catch-all utility modules; do not group workflow-specific defaults, validation, payload shaping, or metadata merging in a generic utils file just because they share a DTO.
|
||||
- Keep source/default selection, validation, and payload shaping close to the workflow that owns the behavior. Do not extract a shared helper just because two flows read the same DTO when their priority order, fallback behavior, or submit semantics differ.
|
||||
- Prefer direct, readable conditionals at the use site for small branch-specific decisions, especially form source selection and request payload assembly. Extract only when the helper name captures a stable domain rule and removes repeated complexity without hiding flow-specific behavior.
|
||||
- When fixing an invalid pattern, scan the touched feature or branch for equivalent patterns and fix them together.
|
||||
- Follow Dify's CSS-first Tailwind v4 contract from `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md`. Prefer design-system tokens, utilities, and radius mappings over generic Tailwind guidance.
|
||||
- Search before adding UI, hooks, helpers, query utilities, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
|
||||
- Follow Dify's CSS-first Tailwind v4 contract from `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md`. Prefer design-system tokens, utilities, and radius mappings over generic Tailwind choices.
|
||||
- Group feature code by workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
|
||||
- Keep source/default selection, validation, dirty checks, and payload shaping close to the workflow that owns submit behavior. Do not hide flow-specific priority order, fallback behavior, or submit semantics in generic utilities.
|
||||
- Prefer direct conditionals for small branch-specific decisions, especially form source selection and request payload assembly.
|
||||
- Loading states for page sections, cards, lists, tables, forms, and drawers should be skeletons scoped to the content being loaded. Use spinners only for small inline busy indicators.
|
||||
|
||||
## Feature Workflow Layout
|
||||
## Layout And Ownership
|
||||
|
||||
- State-heavy wizards, drawers, modals, and secondary workflows work best as a small feature surface with route/entry files, a single feature-local state file, and feature-local UI.
|
||||
- Keep `ui/` shallow with owner files that map to the workflow's real composition boundaries and major visual regions.
|
||||
- Owner files contain the section components, field components, skeletons, and one-off helper components that belong to their visual region.
|
||||
- Folders represent groups of related files with a shared owner and a stable reason to change together.
|
||||
- The entry file handles route integration, provider wiring, close behavior, and feature surface mounting. The composition owner handles high-level workflow branching, and the closest visual owner handles section branching.
|
||||
- State-heavy wizards, drawers, modals, and secondary workflows can be a small feature surface: an entry file, one feature-local state file when Jotai is actually needed, and shallow `ui/` owners that match real visual regions.
|
||||
- The entry file handles route integration, provider wiring, close behavior, and surface mounting. The composition owner handles high-level workflow branching. The closest visual owner handles section branching.
|
||||
- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data; TanStack Query deduplicates and shares cache.
|
||||
- Pass stable domain identity across boundaries. Do not forward derived presentation state when the receiver can derive it from its own data source.
|
||||
- A component that owns a visual surface should also own data access, loading, empty, and error states for content rendered inside it unless a parent truly coordinates that state.
|
||||
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in Query and API flow.
|
||||
- Do not replace prop drilling with one large view-model hook threaded through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it.
|
||||
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child, menu, or row own the action.
|
||||
|
||||
## Ownership
|
||||
## Feature-Scoped Jotai
|
||||
|
||||
- Put local state, queries, mutations, handlers, and derived UI data in the lowest component that uses them. Extract a purpose-built owner component only when the logic has no natural home.
|
||||
- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data. Do not hoist a query only because it is duplicated; TanStack Query handles deduplication and cache sharing.
|
||||
- Hoist state, queries, or callbacks to a parent only when the parent consumes the data, coordinates shared loading/error/empty UI, needs one consistent snapshot, or owns a workflow spanning children.
|
||||
- Pass stable domain identity across boundaries; avoid forwarding derived presentation state when the receiver can derive it from its own data source. A component that owns a visual surface should also own the data access, loading, empty, and error states for content rendered inside it unless a parent truly coordinates that state.
|
||||
- Loading states for visual surfaces should use skeleton placeholders scoped to the content that is actually loading, with shape, density, and dimensions close to the final UI. Avoid generic loading text or centered spinners for page sections, cards, lists, tables, forms, and drawers; reserve spinners for small inline busy indicators such as an in-progress status icon.
|
||||
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
|
||||
- Do not replace prop drilling with one top-level hook that returns a large view model and then thread that object through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it, or use feature-scoped Jotai atoms for simple shared form/UI state when siblings need the same source of truth.
|
||||
- When using feature-scoped Jotai state for a form, drawer, or other secondary surface, scope the store to that surface instance when stale cross-instance state is possible. Initialize stable config at the owning boundary, then let descendants read only the atoms or purpose-named hooks they actually need.
|
||||
- For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. The lowest-owner rule still applies to independent visual surfaces that do not participate in shared state.
|
||||
- For repeated row/menu action surfaces that need reset, hydrate the stable identity at the surface entry and scope only the primitives that truly need per-instance reset, such as open flags, drafts, or selected local options.
|
||||
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
|
||||
- Prefer uncontrolled DOM state and CSS variables before adding controlled props.
|
||||
|
||||
## Feature-Scoped Jotai State
|
||||
|
||||
- A module's feature-local state lives in one state file for Jotai-backed features: primitive atoms, query atoms, derived atoms, write-only action atoms, mutation atoms, submission orchestration, provider exports, and optional scope configuration.
|
||||
- Keep state local when one component owns it, even inside Jotai-backed features. Dialog open flags, menu/popover visibility, confirmation visibility, form/input drafts, row-local pending flags, and in-flight refs usually belong in component state.
|
||||
- Promote UI state to an atom only when siblings need the same source of truth, the value drives a query or mutation atom, a parent workflow coordinates the state, or the state intentionally persists across hidden or unmounted descendants within a scoped surface.
|
||||
- Reflect atom-backed surface-wide locks or invariants in every affected trigger. If only one row, menu, or dialog should be disabled, keep the pending or lock state local to that row, menu, or dialog.
|
||||
- Atom order in the state file follows the dependency graph: types/constants, editable primitives, query atoms, query-data derived atoms, readiness/business derived atoms, write actions, mutation atoms, submission orchestration, provider exports.
|
||||
- Derived atom names read as business facts. Write atom names read as user or workflow commands.
|
||||
- UI components read and write the exact atom they use with `useAtomValue` or `useSetAtom`. Repeated workflow semantics live in named derived atoms or write atoms.
|
||||
- Non-query derived atoms return a narrow value with a clear domain name; avoid pass-through aliases or bundling unrelated UI facts. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract.
|
||||
- Write-only atoms own synchronous state transitions that update multiple primitives, reset dependent state, or advance the workflow. Async work with loading, error, caching, retry, or stale-result concerns should be modeled as query or mutation atoms, with write atoms only changing the inputs that drive them.
|
||||
- Avoid feature hooks that aggregate form values, query results, derived state, and commands for sibling components. Prefer named derived atoms and write atoms so UI components read the exact shared fact or command they need.
|
||||
- When a form library owns validation, keep submit orchestration in feature state when post-submit result or error state is shared by the surface. Avoid duplicating validation gates or request shaping in UI hooks.
|
||||
- `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state when atoms are the feature's local state surface.
|
||||
- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query and mutation atoms keep shared cache behavior through the shared QueryClient.
|
||||
- Do not put `atomWithQuery`, `atomWithInfiniteQuery`, `atomWithMutation`, or broad derived orchestration atoms in a `ScopeProvider` just to reset a surface. Scoped derived atoms implicitly scope their dependencies, which can duplicate query client access and break shared invalidation. Leave query/mutation atoms unscoped; let them read scoped primitive inputs.
|
||||
- Scope providers should list resettable primitive atoms and explicit hydration tuples. If a derived atom must be scoped, confirm that every dependency it implicitly scopes is meant to be private to that surface.
|
||||
- Keep independent dialog lifecycles separate. Avoid a single discriminated "current action dialog" atom when edit, delete, and other dialogs have their own open state, loading guard, or reset behavior.
|
||||
- Route-derived stable identities that do not need instance reset or scoped isolation can be hydrated at the route or layout boundary into a feature route atom. Use scoped atoms only when stale cross-instance state or per-surface reset semantics are needed.
|
||||
- A Jotai-backed feature has one feature-local state file for shared primitive atoms, query atoms, derived atoms, write-only actions, mutation atoms, submission orchestration, provider exports, and optional scope configuration.
|
||||
- Keep component-owned synchronous UI state local even inside Jotai features: dialog open flags, menus/popovers, confirmations, field drafts, and selected local options usually belong in component state.
|
||||
- Use uncontrolled `@langgenius/dify-ui/form` and `@langgenius/dify-ui/field` controls for edit/create forms whose fields are read only at submit time. Initialize query-backed defaults with `defaultValue` and keyed remounts.
|
||||
- Promote form state to atoms only when another component must react to in-progress values, a draft must survive unmount/remount in the scoped workflow, or multiple steps share the same editable draft before submit.
|
||||
- Treat `useParams`, route args, and `nuqs` query state as framework-owned state. When atom logic needs those values, hydrate primitive atoms at the route or surface boundary, such as with `useHydrateAtoms(..., { dangerouslyForceHydrate: true })`; keep URL updates in the route/query-state APIs instead of write atoms.
|
||||
- For async work tied to atom state, use `atomWithQuery` or `atomWithMutation`; write atoms should update only the inputs that drive those atoms. This applies to pure frontend async work as well as network requests, so do not hand-roll loading/error/in-flight state with `useState` or `useRef` for atom-orchestrated async behavior. For component-owned remote work, use `useQuery` or `useMutation` directly.
|
||||
- Row-local async state belongs to the row owner unless it participates in a shared Jotai workflow or needs atom-scoped reset semantics.
|
||||
- Leave query and mutation atoms unscoped so they keep shared QueryClient cache and invalidation behavior. Scope resettable primitives and explicit hydration tuples; scope a derived atom only when every dependency should be private to that surface.
|
||||
- For scoped primitives that are always hydrated by `ScopeProvider`, prefer `atomWithLazy<T>(() => { throw new Error(...) })` when consumers should see a non-null type.
|
||||
- Order state files by dependency graph: types/constants, primitives, query atoms, query-data derived atoms, business/readiness derived atoms, write actions, mutation atoms, submission orchestration, provider exports.
|
||||
- Name derived atoms as business facts and write atoms as user or workflow commands. Components should read or write the exact atom they need with `useAtomValue` or `useSetAtom`.
|
||||
- Menu/dialog `open` state usually stays local, but a scoped atom is acceptable when a composed menu plus secondary surface would otherwise pass confusing `open`/`onClose` props through unrelated layers. Scope that primitive with the surface instance so reset behavior stays local.
|
||||
- Keep independent dialog lifecycles separate. Avoid one discriminated "current action dialog" atom when dialogs have separate open state, loading guards, or reset behavior.
|
||||
|
||||
## Components, Props, And Types
|
||||
|
||||
@ -68,80 +62,56 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Prefer named exports. Use default exports only where the framework requires them, such as Next.js route files.
|
||||
- Type simple one-off props inline. Use a named `Props` type only when reused, exported, complex, or clearer.
|
||||
- Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers and one-off UI extensions beside the component that needs them.
|
||||
- Do not create type aliases that only rename another type. Use an alias only when it encodes a real UI concept, refinement, or reusable local contract.
|
||||
- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially persistent IDs and route params. Normalize framework or route params at the boundary.
|
||||
- Keep fallback and invariant checks at the lowest component that already handles that state; avoid defensive fallbacks that mask impossible states.
|
||||
- Do not extract fallback helpers whose only behavior is hiding missing display data. The component that renders the surface owns the empty, disabled, hidden, or placeholder state.
|
||||
- Do not create type aliases that only rename another type. Use aliases only for real UI concepts, refinements, or reusable local contracts.
|
||||
- Name values by their domain role and backend API contract, especially persistent IDs and route params. Normalize framework or route params at the boundary.
|
||||
- Put fallback and invariant checks in the lowest component that already handles that state. Do not extract helpers whose only behavior is hiding missing display data.
|
||||
|
||||
## Generated API Contracts
|
||||
## Generated API And Nullable Data
|
||||
|
||||
- Treat generated contracts as authoritative at API, query, mutation, cache, and service boundaries. For enterprise APIs, use `packages/contracts/generated/enterprise/*`.
|
||||
- Do not hand-write local request/response/reply/page/cache-data types that mirror generated DTOs. Import or infer the generated type.
|
||||
- Do not widen generated fields or enums for compatibility. Normalize legacy input at the boundary, then return the generated field type.
|
||||
- Do not repair generated or API-returned contract fields in components unless the API contract or product requirement says they need normalization. Treat enums, statuses, and presence flags as exact contract values.
|
||||
- Use generated enum objects and union types directly in props, comparisons, status logic, and i18n keys. Do not add local enum constants or parallel frontend enum/status layers unless they model real product state not represented by the API. Presentation-only tone maps should be keyed by the generated enum.
|
||||
- Normalize or coerce only at a real boundary, such as user-entered forms, search, URL/query params, file names, DOM IDs, or legacy adapters. Preserve user-entered values when whitespace or formatting can be meaningful.
|
||||
- Do not coerce nullable or optional API strings to `''` in query, derived model, or payload-building code. Keep `undefined` or `null` until the final boundary that requires a string.
|
||||
- Do not use `value || undefined` for mutation payload fields where an empty string means "clear this value". Trim or normalize at the form boundary, then preserve `''` when the API contract treats it as an intentional update.
|
||||
- Local UI models are fine for presentation, form state, select options, or guarded required-field refinements. Name them as UI concepts, not generated DTO mirrors.
|
||||
- Required-value refinements are allowed only after same-branch filtering or early return. Prefer nullable-tolerant props for render-only data.
|
||||
- When a component needs a stricter shape than a generated DTO, refine once at the API/query-to-UI boundary into a purpose-named UI type instead of hiding missing fields with generic fallback or coercion helpers.
|
||||
|
||||
## Nullable API Data
|
||||
|
||||
- Prefer nullable-tolerant call boundaries. Pass API-returned types through for render-only rows, and let the component render fallback, disabled state, or nothing.
|
||||
- Narrow only where a real value is required, such as mutation params, route hrefs, select values, or query input. Build that target model with `flatMap`, a local loop, or an early return so the required value is captured in the same branch.
|
||||
- If design says a field is the display value, use that field. Only the final component should decide whether a nullable display value renders a placeholder, hides content, or disables an action.
|
||||
- Do not wrap required arrays or fields in null-fallback helpers. Use empty collection fallbacks only for not-yet-loaded query data or genuinely nullable collections at the owning render boundary.
|
||||
- Do not drop rows only to satisfy props or React keys; use a stable fallback key when possible.
|
||||
- Use conditional spreads or explicit pushes for conditional array items instead of `undefined` placeholders followed by a narrowing filter.
|
||||
- Avoid truthiness type guards, `filter(Boolean)`, `filter(item => item.id)`, and `!` after those filters.
|
||||
- Use type guards only for meaningful domain or runtime validation, such as enum membership, object shape, or a reusable business invariant.
|
||||
- Do not hand-write DTO mirrors, widen generated fields/enums, or add parallel frontend enum/status layers unless they model product state not represented by the API.
|
||||
- Use generated enum objects and union types directly in props, comparisons, status logic, and i18n keys. Presentation-only tone maps should be keyed by generated enums.
|
||||
- Normalize or coerce only at real boundaries: user-entered forms, search, URL/query params, file names, DOM IDs, or legacy adapters.
|
||||
- Do not coerce nullable or optional API strings to `''` in query, derived model, or payload-building code. Keep `null` or `undefined` until the final boundary requiring a string.
|
||||
- Do not use `value || undefined` for mutation fields where `''` means "clear this value". Trim or normalize at the form boundary, then preserve intentional empty strings.
|
||||
- Prefer nullable-tolerant render props for API-returned rows. Narrow only where a real value is required, such as mutation params, route hrefs, select values, query input, or required React keys.
|
||||
- Build required values in the same branch that proves them, using `flatMap`, a local loop, or an early return. Avoid truthiness guards, `filter(Boolean)`, `filter(item => item.id)`, and `!` after filters.
|
||||
- Use conditional spreads or explicit pushes for conditional array items instead of `undefined` placeholders followed by narrowing filters.
|
||||
- Empty collection fallbacks are for not-yet-loaded query data or genuinely nullable collections at the owning render boundary, not for hiding required API fields.
|
||||
|
||||
## Queries And Mutations
|
||||
|
||||
- Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape.
|
||||
- Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`.
|
||||
- In `atomWithQuery` and `atomWithInfiniteQuery`, return generated `queryOptions()` or `infiniteOptions()` directly. Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into that call instead of spreading generated options into a hand-built object.
|
||||
- In `atomWithMutation`, return generated `mutationOptions()` directly when using generated clients. Put request shaping and submit orchestration in write atoms; do not rebuild mutation option objects just to pass through the generated mutation function.
|
||||
- For custom query functions that do not come from generated clients, wrap the options object with TanStack `queryOptions(...)` so query atoms still return a query options contract.
|
||||
- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it.
|
||||
- Keep feature hooks for real orchestration, workflow state, or shared domain behavior.
|
||||
- For TanStack cache data, use generated or query-derived types; do not create local wrappers for `getQueryData` or `getQueriesData`.
|
||||
- For generated oRPC `queryOptions()` / `infiniteOptions()`, keep returning the generated options directly. When required input is missing, use a whole-input branch such as `input: condition ? validInput : skipToken` together with `enabled: Boolean(condition)` so no request runs and no fake payload is built.
|
||||
- Do not put `skipToken` inside a nested placeholder payload, such as `{ params: { appInstanceId: skipToken } }`. Do not create hand-written "missing queryOptions" objects or coerce required IDs to `''`.
|
||||
- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows.
|
||||
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules.
|
||||
- Component or atom mutation callbacks can handle local UI feedback such as toasts, closing dialogs, or navigation. They should not replace shared invalidation or add local cache patches for shared server state.
|
||||
- Keep `web/contract/*` as the API shape source of truth and follow the `{ params, query?, body? }` input shape.
|
||||
- Consume generated queries with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`.
|
||||
- Consume owner-local mutations with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))` when pending/error state is not consumed by feature atoms.
|
||||
- In `atomWithQuery`, `atomWithInfiniteQuery`, and `atomWithMutation`, return generated `queryOptions()`, `infiniteOptions()`, or `mutationOptions()` directly. Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into the generated call instead of spreading options into a hand-built object.
|
||||
- For generated oRPC options with missing required input, branch the whole input with `input: condition ? validInput : skipToken` and `enabled: Boolean(condition)`. Never place `skipToken` inside a nested placeholder payload or coerce required IDs to `''`.
|
||||
- When prefetch and render use the same request, extract local query options or a query-options atom so `prefetchQuery` and `useQuery`/`atomWithQuery` share the exact options.
|
||||
- For custom query or mutation functions, wrap options with TanStack `queryOptions(...)` or `mutationOptions(...)`.
|
||||
- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename generated options. Keep feature hooks for real orchestration, workflow state, or shared domain behavior.
|
||||
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`. Component or atom callbacks may handle local toasts, closing dialogs, and navigation, but should not replace shared invalidation or patch shared server state locally.
|
||||
- For overlays that may open heavier secondary content, prefetch from the trigger/menu open event with `queryClient.prefetchQuery(queryOptions)` when `onOpenChange` is available. Do not mount hidden subscribers just to warm cache.
|
||||
- Do not use deprecated `useInvalid` or `useReset`.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`.
|
||||
|
||||
## Component Boundaries
|
||||
## Boundaries And Overlays
|
||||
|
||||
- Use the first level below a page or tab to organize independent page sections when it adds real structure. This layer is layout/semantic first, not automatically the data owner.
|
||||
- Treat component names, semantic roles, and user- or design-marked visual regions as boundary constraints. Do not expand a child component's responsibility just because its data is useful nearby; keep adjacent UI as a sibling owner or introduce a correctly named broader owner.
|
||||
- Split deeper components by the data and state each layer actually needs. Each component should access only necessary data, and ownership should stay at the lowest consumer.
|
||||
- Use the first level below a page or tab to organize independent page sections when it adds structure. This layer is layout/semantic first, not automatically the data owner.
|
||||
- Treat component names, semantic roles, and user- or design-marked visual regions as boundary constraints. Keep adjacent UI as a sibling owner or introduce a correctly named broader owner.
|
||||
- Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary.
|
||||
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
|
||||
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
|
||||
- When a dialog, dropdown, or popover component already accepts controlled `open` state, mount the surface unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset for reset behavior instead of `{open && <Surface />}` wrappers.
|
||||
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
|
||||
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, children-as-pass-through composition, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook, forwards props, or passes trigger/content through to one child, move the logic into that child or make the wrapper own a real surface.
|
||||
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component when hidden content would obscure the parent.
|
||||
- Preserve composability by separating behavior ownership from placement ownership: an action can own trigger/open/menu content while the caller owns slots, offsets, and alignment.
|
||||
- When a dialog, dropdown, or popover accepts controlled `open`, mount it unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset instead of `{open && <Surface />}` wrappers.
|
||||
- When opening a dialog from a menu item, keep the menu and dialog as sibling surfaces. Let the menu command open the dialog, and mount the dialog outside menu popup content.
|
||||
- For dialogs and alert dialogs, keep the root responsible for `open` wiring and put query/mutation hooks inside the content component when work should mount only after the overlay opens.
|
||||
- Prefer uncontrolled overlay roots when the library can own open state. Use `onOpenChange` for side effects and CSS/data selectors for open-state styling.
|
||||
- Avoid wrapper DOM unless it provides layout, semantics, accessibility, state ownership, or library integration. Avoid shallow wrappers, hook-to-props adapters, layout-only render props, children pass-through wrappers, and prop renaming unless they add real behavior or a real boundary.
|
||||
|
||||
## You Might Not Need An Effect
|
||||
|
||||
- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration.
|
||||
- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive.
|
||||
- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known.
|
||||
- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render.
|
||||
- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary.
|
||||
- For forms initialized from query data, prefer keyed remounts or surface-entry hydration of form/field atoms over an Effect that copies query data into form state.
|
||||
- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components.
|
||||
- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow.
|
||||
|
||||
## Navigation And Performance
|
||||
## Effects, Navigation, And Performance
|
||||
|
||||
- Use Effects only to synchronize with external systems. Do not use Effects to transform props/state for rendering, handle user actions, copy state, reset state from props, or fetch data.
|
||||
- For forms initialized from query data, prefer keyed remounts or surface-entry atom hydration over Effects that copy query data into form state.
|
||||
- Prefer framework data APIs or TanStack Query for data fetching.
|
||||
- Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission.
|
||||
- Before reaching for `memo`, first try moving changing state down to the smallest component that actually uses it so unrelated sibling trees stay untouched.
|
||||
- If changing state must wrap other content, lift the unchanged content up and pass it as `children` so the stateful wrapper can update without React visiting that subtree.
|
||||
- Before using `memo`, move changing state down to the smallest component that uses it. If state must wrap stable content, lift the stable content up and pass it as `children`.
|
||||
- Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason.
|
||||
|
||||
8
.github/workflows/style.yml
vendored
8
.github/workflows/style.yml
vendored
@ -105,6 +105,14 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: vp run knip
|
||||
|
||||
- name: Web dead code check production
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: vp run knip:production
|
||||
|
||||
- name: Web production unused declarations check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: vp run knip:production-unused-check
|
||||
|
||||
ts-common-style:
|
||||
name: TS Common
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
@ -116,7 +116,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
||||
## Using Dify
|
||||
|
||||
- **Cloud <br/>**
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. If you run into issues with Dify Cloud, [contact our Cloud support team](mailto:cloud@dify.ai?subject=%5BGitHub%5DDify%20Cloud%20Support).
|
||||
|
||||
- **Self-hosting Dify Community Edition<br/>**
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
|
||||
@ -36,6 +36,9 @@ FILES_ACCESS_TIMEOUT=300
|
||||
# Collaboration mode toggle
|
||||
ENABLE_COLLABORATION_MODE=true
|
||||
|
||||
# Learn app feature toggle
|
||||
ENABLE_LEARN_APP=true
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ def reset_password(email, new_password, password_confirm):
|
||||
return
|
||||
normalized_email = email.strip().lower()
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip())
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
@ -67,7 +67,7 @@ def reset_email(email, new_email, email_confirm):
|
||||
return
|
||||
normalized_new_email = new_email.strip().lower()
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email.strip())
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email.strip())
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
@ -133,8 +133,9 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
password=new_password,
|
||||
language=language,
|
||||
create_workspace_required=False,
|
||||
session=db.session,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name, session=db.session)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
|
||||
@ -1073,6 +1073,12 @@ class MailConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class HomepageConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for homepage feature toggles exposed through system features.
|
||||
"""
|
||||
|
||||
ENABLE_TRIAL_APP: bool = Field(
|
||||
description="Enable trial app",
|
||||
default=False,
|
||||
@ -1083,6 +1089,11 @@ class MailConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENABLE_LEARN_APP: bool = Field(
|
||||
description="Enable Learn App",
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
@ -1489,6 +1500,7 @@ class FeatureConfig(
|
||||
EndpointConfig,
|
||||
FileAccessConfig,
|
||||
FileUploadConfig,
|
||||
HomepageConfig,
|
||||
HttpConfig,
|
||||
InnerAPIConfig,
|
||||
IndexingConfig,
|
||||
|
||||
@ -28,9 +28,9 @@ from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
from services.entities.agent_entities import ComposerSavePayload, WorkflowComposerCopyFromRosterPayload
|
||||
|
||||
register_schema_models(console_ns, ComposerSavePayload)
|
||||
register_schema_models(console_ns, ComposerSavePayload, WorkflowComposerCopyFromRosterPayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentAppComposerResponse,
|
||||
@ -91,6 +91,38 @@ class WorkflowAgentComposerApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/copy-from-roster")
|
||||
class WorkflowAgentComposerCopyFromRosterApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowComposerCopyFromRosterPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow roster agent copied to inline agent",
|
||||
console_ns.models[WorkflowAgentComposerResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT)
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = WorkflowComposerCopyFromRosterPayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.copy_workflow_composer_from_roster(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
source_agent_id=payload.source_agent_id,
|
||||
source_snapshot_id=payload.source_snapshot_id,
|
||||
idempotency_key=payload.idempotency_key,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
|
||||
class WorkflowAgentComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
|
||||
@ -91,7 +91,10 @@ class AgentAppFeatureConfigResource(Resource):
|
||||
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
new_app_model_config = AgentAppFeatureConfigService.update_features(
|
||||
app_model=app_model, account=current_user, config=args.model_dump(exclude_none=True), session=db.session
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
config=args.model_dump(exclude_none=True),
|
||||
session=db.session,
|
||||
)
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
@ -19,6 +19,7 @@ from controllers.console.wraps import (
|
||||
rbac_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
Annotation,
|
||||
@ -388,7 +389,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
update_args["answer"] = args.answer
|
||||
if args.question is not None:
|
||||
update_args["question"] = args.question
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, str(app_id), str(annotation_id))
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
update_args, str(app_id), str(annotation_id), db.session
|
||||
)
|
||||
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@ -398,7 +401,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_EDIT)
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
def delete(self, app_id: UUID, annotation_id: UUID):
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id), db.session)
|
||||
return "", 204
|
||||
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
@ -142,6 +143,7 @@ class ChatMessageTextApi(Resource):
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=payload.text,
|
||||
voice=payload.voice,
|
||||
message_id=payload.message_id,
|
||||
|
||||
@ -341,8 +341,8 @@ class MessageFeedbackExportApi(Resource):
|
||||
|
||||
try:
|
||||
export_data = FeedbackService.export_feedbacks(
|
||||
db.session(),
|
||||
app_id=app_model.id,
|
||||
app_model.id,
|
||||
session=db.session(),
|
||||
from_source=args.from_source,
|
||||
rating=args.rating,
|
||||
has_comment=args.has_comment,
|
||||
|
||||
@ -17,6 +17,7 @@ from controllers.console.wraps import (
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import AccountWithRole
|
||||
from libs.helper import build_avatar_url, dump_response, to_timestamp
|
||||
@ -489,7 +490,7 @@ class WorkflowCommentMentionUsersApi(Resource):
|
||||
current_tenant = current_user.current_tenant # need the tenant object here
|
||||
if current_tenant is None:
|
||||
raise ValueError("current tenant is required")
|
||||
members = TenantService.get_tenant_members(current_tenant)
|
||||
members = TenantService.get_tenant_members(current_tenant, session=db.session)
|
||||
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersPayload(users=users)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@ -89,7 +89,9 @@ class ActivateCheckApi(Resource):
|
||||
workspaceId = args.workspace_id
|
||||
token = args.token
|
||||
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(
|
||||
workspaceId, args.email, token, session=db.session
|
||||
)
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
tenant = invitation.get("tenant", None)
|
||||
@ -137,7 +139,9 @@ class ActivateApi(Resource):
|
||||
args = ActivatePayload.model_validate(console_ns.payload)
|
||||
|
||||
normalized_request_email = args.email.lower() if args.email else None
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
|
||||
invitation = RegisterService.get_invitation_with_case_fallback(
|
||||
args.workspace_id, args.email, args.token, session=db.session
|
||||
)
|
||||
if invitation is None:
|
||||
raise AlreadyActivateError()
|
||||
|
||||
@ -184,6 +188,6 @@ class ActivateApi(Resource):
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
TenantService.switch_tenant(account, tenant.id, session=db.session)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@ -58,7 +59,7 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(db.session(), current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
"sources": [
|
||||
@ -92,7 +93,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
|
||||
ApiKeyAuthService.create_provider_auth(db.session(), current_tenant_id, data)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
@ -109,6 +110,6 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, binding_id: UUID):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
|
||||
ApiKeyAuthService.delete_provider_auth(db.session(), current_tenant_id, str(binding_id))
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console.auth.error import (
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
@ -100,7 +101,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email)
|
||||
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@ -175,7 +176,7 @@ class EmailRegisterResetApi(Resource):
|
||||
email = register_data.get("email", "")
|
||||
normalized_email = email.lower()
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
@ -186,7 +187,7 @@ class EmailRegisterResetApi(Resource):
|
||||
timezone=args.timezone,
|
||||
language=args.language,
|
||||
)
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
@ -205,6 +206,7 @@ class EmailRegisterResetApi(Resource):
|
||||
password=password,
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=timezone,
|
||||
session=db.session,
|
||||
)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
@ -82,7 +82,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, args.email)
|
||||
|
||||
token = AccountService.send_reset_password_email(
|
||||
account=account,
|
||||
@ -180,7 +180,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
password_hashed = hash_password(args.new_password, salt)
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
|
||||
if account:
|
||||
account = db.session.merge(account)
|
||||
@ -198,10 +198,10 @@ class ForgotPasswordResetApi(Resource):
|
||||
|
||||
# Create workspace if needed
|
||||
if (
|
||||
not TenantService.get_join_tenants(account)
|
||||
not TenantService.get_join_tenants(account, session=db.session)
|
||||
and FeatureService.get_system_features().is_allow_create_workspace
|
||||
):
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
@ -119,7 +119,9 @@ class LoginApi(Resource):
|
||||
invite_token = args.invite_token
|
||||
invitation_data: InvitationDetailDict | None = None
|
||||
if invite_token:
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(
|
||||
None, request_email, invite_token, session=db.session
|
||||
)
|
||||
if invitation_data is None:
|
||||
invite_token = None
|
||||
|
||||
@ -145,7 +147,7 @@ class LoginApi(Resource):
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
tenants = TenantService.get_join_tenants(account, session=db.session)
|
||||
if len(tenants) == 0:
|
||||
system_features = FeatureService.get_system_features()
|
||||
|
||||
@ -157,7 +159,7 @@ class LoginApi(Resource):
|
||||
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||
}
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
token_pair = AccountService.login(account=account, session=db.session, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
@ -291,7 +293,7 @@ class EmailCodeLoginApi(Resource):
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
tenants = TenantService.get_join_tenants(account, session=db.session)
|
||||
if not tenants:
|
||||
workspaces = FeatureService.get_system_features().license.workspaces
|
||||
if not workspaces.is_available():
|
||||
@ -299,7 +301,7 @@ class EmailCodeLoginApi(Resource):
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
else:
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
|
||||
TenantService.create_tenant_member(new_tenant, account, db.session, role="owner")
|
||||
account.current_tenant = new_tenant
|
||||
tenant_was_created.send(new_tenant)
|
||||
@ -311,6 +313,7 @@ class EmailCodeLoginApi(Resource):
|
||||
name=user_email,
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=args.timezone,
|
||||
session=db.session,
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
@ -319,7 +322,7 @@ class EmailCodeLoginApi(Resource):
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
token_pair = AccountService.login(account, session=db.session, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(user_email)
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
@ -343,7 +346,7 @@ class RefreshTokenApi(Resource):
|
||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(refresh_token)
|
||||
new_token_pair = AccountService.refresh_token(refresh_token, session=db.session)
|
||||
|
||||
# Create response with new cookies
|
||||
response = make_response({"result": "success"})
|
||||
@ -358,22 +361,22 @@ class RefreshTokenApi(Resource):
|
||||
|
||||
|
||||
def _get_account_with_case_fallback(email: str):
|
||||
account = AccountService.get_user_through_email(email)
|
||||
account = AccountService.get_user_through_email(email, session=db.session)
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return AccountService.get_user_through_email(email.lower())
|
||||
return AccountService.get_user_through_email(email.lower(), session=db.session)
|
||||
|
||||
|
||||
def _authenticate_account_with_case_fallback(
|
||||
original_email: str, normalized_email: str, password: str, invite_token: str | None
|
||||
):
|
||||
try:
|
||||
return AccountService.authenticate(original_email, password, invite_token)
|
||||
return AccountService.authenticate(original_email, password, invite_token, session=db.session)
|
||||
except services.errors.account.AccountPasswordError:
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
return AccountService.authenticate(normalized_email, password, invite_token, session=db.session)
|
||||
|
||||
|
||||
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
|
||||
@ -195,7 +195,7 @@ class OAuthCallback(Resource):
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, session=db.session)
|
||||
except Unauthorized:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
@ -206,6 +206,7 @@ class OAuthCallback(Resource):
|
||||
|
||||
token_pair = AccountService.login(
|
||||
account=account,
|
||||
session=db.session,
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
@ -224,7 +225,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
account: Account | None = Account.get_by_openid(provider, user_info.id)
|
||||
|
||||
if not account:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(user_info.email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, user_info.email)
|
||||
|
||||
return account
|
||||
|
||||
@ -240,12 +241,12 @@ def _generate_account(
|
||||
oauth_new_user = False
|
||||
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
tenants = TenantService.get_join_tenants(account, session=db.session)
|
||||
if not tenants:
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise WorkSpaceNotAllowedCreateError()
|
||||
else:
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=db.session)
|
||||
TenantService.create_tenant_member(new_tenant, account, db.session, role="owner")
|
||||
account.current_tenant = new_tenant
|
||||
tenant_was_created.send(new_tenant)
|
||||
@ -272,9 +273,10 @@ def _generate_account(
|
||||
provider=provider,
|
||||
language=interface_language,
|
||||
timezone=timezone,
|
||||
session=db.session,
|
||||
)
|
||||
|
||||
# Link account
|
||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
AccountService.link_account_integrate(provider, user_info.id, account, session=db.session)
|
||||
|
||||
return account, oauth_new_user
|
||||
|
||||
@ -17,6 +17,7 @@ from controllers.console.wraps import (
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
@ -65,7 +66,9 @@ class DatasetMetadataCreateApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args, current_user, current_tenant_id)
|
||||
metadata = MetadataService.create_metadata(
|
||||
db.session(), dataset_id_str, metadata_args, current_user, current_tenant_id
|
||||
)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
|
||||
@setup_required
|
||||
@ -81,7 +84,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
metadata = MetadataService.get_dataset_metadatas(dataset)
|
||||
metadata = MetadataService.get_dataset_metadatas(db.session(), dataset)
|
||||
return dump_response(DatasetMetadataListResponse, metadata), 200
|
||||
|
||||
|
||||
@ -108,7 +111,7 @@ class DatasetMetadataApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(
|
||||
dataset_id_str, metadata_id_str, name, current_user, current_tenant_id
|
||||
db.session(), dataset_id_str, metadata_id_str, name, current_user, current_tenant_id
|
||||
)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
|
||||
@ -127,7 +130,7 @@ class DatasetMetadataApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
MetadataService.delete_metadata(db.session(), dataset_id_str, metadata_id_str)
|
||||
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
@ -166,9 +169,9 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
MetadataService.enable_built_in_field(db.session(), dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
MetadataService.disable_built_in_field(db.session(), dataset)
|
||||
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
@ -195,7 +198,7 @@ class DocumentMetadataEditApi(Resource):
|
||||
|
||||
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args, current_user)
|
||||
MetadataService.update_documents_metadata(db.session(), dataset, metadata_args, current_user)
|
||||
|
||||
# Frontend callers only await success and invalidate caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
@ -20,6 +20,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import InstalledApp
|
||||
from services.audio_service import AudioService
|
||||
@ -99,7 +100,13 @@ class ChatTextApi(InstalledAppResource):
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=text,
|
||||
voice=voice,
|
||||
message_id=message_id,
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
|
||||
@ -175,7 +175,7 @@ class InstalledAppsListApi(Resource):
|
||||
|
||||
if current_user.current_tenant is None:
|
||||
raise ValueError("current_user.current_tenant must not be None")
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
|
||||
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant, session=db.session)
|
||||
installed_app_list: list[dict[str, Any]] = []
|
||||
for installed_app, app_model in installed_apps:
|
||||
installed_app_list.append(
|
||||
|
||||
@ -27,7 +27,11 @@ from controllers.console.wraps import with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from fields.message_fields import (
|
||||
ExploreMessageInfiniteScrollPagination,
|
||||
ExploreMessageListItem,
|
||||
SuggestedQuestionsResponse,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models import Account
|
||||
@ -56,7 +60,7 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
GeneratedAppResponse,
|
||||
MessageInfiniteScrollPagination,
|
||||
ExploreMessageInfiniteScrollPagination,
|
||||
ResultResponse,
|
||||
SuggestedQuestionsResponse,
|
||||
)
|
||||
@ -68,7 +72,7 @@ register_response_schema_models(
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@console_ns.doc(params=query_params_from_model(MessageListQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[MessageInfiniteScrollPagination.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ExploreMessageInfiniteScrollPagination.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
@ -88,9 +92,9 @@ class MessageListApi(InstalledAppResource):
|
||||
str(args.first_id) if args.first_id else None,
|
||||
args.limit,
|
||||
)
|
||||
adapter = TypeAdapter(MessageListItem)
|
||||
adapter = TypeAdapter(ExploreMessageListItem)
|
||||
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
|
||||
return MessageInfiniteScrollPagination(
|
||||
return ExploreMessageInfiniteScrollPagination(
|
||||
limit=pagination.limit,
|
||||
has_more=pagination.has_more,
|
||||
data=items,
|
||||
|
||||
@ -419,7 +419,13 @@ class TrialChatTextApi(TrialAppResource):
|
||||
app_id = app_model.id
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=text,
|
||||
voice=voice,
|
||||
message_id=message_id,
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
|
||||
@ -50,7 +50,7 @@ def get_init_status() -> InitStatusResponse:
|
||||
@only_edition_self_hosted
|
||||
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
|
||||
"""Validate initialization password."""
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
tenant_count = TenantService.get_tenant_count(session=db.session)
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
|
||||
@ -79,7 +79,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
if get_setup_status():
|
||||
raise AlreadySetupError()
|
||||
|
||||
tenant_count = TenantService.get_tenant_count()
|
||||
tenant_count = TenantService.get_tenant_count(session=db.session)
|
||||
if tenant_count > 0:
|
||||
raise AlreadySetupError()
|
||||
|
||||
@ -94,6 +94,7 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
password=payload.password,
|
||||
ip_address=extract_remote_ip(request),
|
||||
language=payload.language,
|
||||
session=db.session,
|
||||
)
|
||||
|
||||
return SetupResponse(result="success")
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
@ -43,7 +44,7 @@ def socket_connect(sid, environ, auth):
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
user = AccountService.load_logged_in_account(account_id=user_id, session=db.session)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
@ -328,7 +328,7 @@ class AccountNameApi(Resource):
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
updated_account = AccountService.update_account(current_user, session=db.session, name=args.name)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -374,7 +374,7 @@ class AccountAvatarApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountAvatarPayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
|
||||
updated_account = AccountService.update_account(current_user, session=db.session, avatar=args.avatar)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -391,7 +391,9 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceLanguagePayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
|
||||
updated_account = AccountService.update_account(
|
||||
current_user, session=db.session, interface_language=args.interface_language
|
||||
)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -408,7 +410,9 @@ class AccountInterfaceThemeApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceThemePayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
|
||||
updated_account = AccountService.update_account(
|
||||
current_user, session=db.session, interface_theme=args.interface_theme
|
||||
)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -425,7 +429,7 @@ class AccountTimezoneApi(Resource):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountTimezonePayload.model_validate(payload)
|
||||
|
||||
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
|
||||
updated_account = AccountService.update_account(current_user, session=db.session, timezone=args.timezone)
|
||||
|
||||
return _serialize_account(updated_account)
|
||||
|
||||
@ -443,7 +447,8 @@ class AccountPasswordApi(Resource):
|
||||
args = AccountPasswordPayload.model_validate(payload)
|
||||
|
||||
try:
|
||||
AccountService.update_account_password(current_user, args.password, args.new_password)
|
||||
assert args.password is not None
|
||||
AccountService.update_account_password(current_user, args.password, args.new_password, session=db.session)
|
||||
except ServiceCurrentPasswordIncorrectError:
|
||||
raise CurrentPasswordIncorrectError()
|
||||
|
||||
@ -731,7 +736,7 @@ class ChangeEmailResetApi(Resource):
|
||||
if AccountService.is_account_in_freeze(normalized_new_email):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
if not AccountService.check_email_unique(normalized_new_email, session=db.session):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
@ -755,7 +760,9 @@ class ChangeEmailResetApi(Resource):
|
||||
# legitimately verified token.
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
updated_account = AccountService.update_account_email(
|
||||
current_user, email=normalized_new_email, session=db.session
|
||||
)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
email=normalized_new_email,
|
||||
@ -775,6 +782,6 @@ class CheckEmailUnique(Resource):
|
||||
normalized_email = args.email.lower()
|
||||
if AccountService.is_account_in_freeze(normalized_email):
|
||||
raise AccountInFreezeError()
|
||||
if not AccountService.check_email_unique(normalized_email):
|
||||
if not AccountService.check_email_unique(normalized_email, session=db.session):
|
||||
raise EmailAlreadyInUseError()
|
||||
return {"result": "success"}
|
||||
|
||||
@ -131,7 +131,7 @@ def _normalize_invitee_emails(emails: list[str]) -> list[str]:
|
||||
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
|
||||
new_member_count = 0
|
||||
for email in emails:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
if not account:
|
||||
new_member_count += 1
|
||||
continue
|
||||
@ -186,7 +186,7 @@ class MemberListApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant, session=db.session)
|
||||
if dify_config.RBAC_ENABLED:
|
||||
member_ids = [member.id for member in members]
|
||||
member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get(
|
||||
@ -273,6 +273,7 @@ class MemberInviteEmailApi(Resource):
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
session=db.session,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(invitee_email)
|
||||
invitation_results.append(
|
||||
@ -317,7 +318,9 @@ class MemberCancelInviteApi(Resource):
|
||||
abort(404)
|
||||
else:
|
||||
try:
|
||||
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
|
||||
TenantService.remove_member_from_tenant(
|
||||
current_user.current_tenant, member, current_user, session=db.session
|
||||
)
|
||||
except services.errors.account.CannotOperateSelfError as e:
|
||||
return {"code": "cannot-operate-self", "message": str(e)}, 400
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
@ -360,7 +363,9 @@ class MemberUpdateRoleApi(Resource):
|
||||
|
||||
try:
|
||||
assert member is not None, "Member not found"
|
||||
TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user)
|
||||
TenantService.update_member_role(
|
||||
current_user.current_tenant, member, new_role, current_user, session=db.session
|
||||
)
|
||||
except services.errors.account.CannotOperateSelfError as e:
|
||||
return {"code": "cannot-operate-self", "message": str(e)}, 400
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
@ -387,7 +392,7 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
def get(self, current_user: Account):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant, session=db.session)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@ -413,7 +418,7 @@ class SendOwnerTransferEmailApi(Resource):
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
|
||||
raise NotOwnerError()
|
||||
|
||||
if args.language is not None and args.language == "zh-Hans":
|
||||
@ -448,7 +453,7 @@ class OwnerTransferCheckApi(Resource):
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
|
||||
raise NotOwnerError()
|
||||
|
||||
user_email = current_user.email
|
||||
@ -494,7 +499,7 @@ class OwnerTransfer(Resource):
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant, session=db.session):
|
||||
raise NotOwnerError()
|
||||
|
||||
if current_user.id == str(member_id):
|
||||
@ -516,12 +521,14 @@ class OwnerTransfer(Resource):
|
||||
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_member(member, current_user.current_tenant):
|
||||
if not TenantService.is_member(member, current_user.current_tenant, session=db.session):
|
||||
raise MemberNotInTenantError()
|
||||
|
||||
try:
|
||||
assert member is not None, "Member not found"
|
||||
TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user)
|
||||
TenantService.update_member_role(
|
||||
current_user.current_tenant, member, "owner", current_user, session=db.session
|
||||
)
|
||||
|
||||
AccountService.send_new_owner_transfer_notify_email(
|
||||
account=member,
|
||||
|
||||
@ -201,21 +201,23 @@ def _legacy_workspace_roles(
|
||||
This keeps the new `/rbac/roles` endpoint compatible with the original
|
||||
Dify role model when enterprise RBAC is disabled.
|
||||
"""
|
||||
|
||||
legacy_roles = [
|
||||
svc.RBACRole(
|
||||
id=role_name,
|
||||
tenant_id="",
|
||||
type=svc.RBACRoleType.WORKSPACE.value,
|
||||
category="global_system_default",
|
||||
name=role_name,
|
||||
description="",
|
||||
is_builtin=True,
|
||||
permission_keys=list(dict.fromkeys(_LEGACY_ROLE_PERMISSION_KEYS[role_name])),
|
||||
role_tag="owner" if role_name == "owner" else "",
|
||||
legacy_roles = []
|
||||
for role_name in ("owner", "admin", "editor", "normal", "dataset_operator"):
|
||||
if not dify_config.DATASET_OPERATOR_ENABLED and role_name == "dataset_operator":
|
||||
continue
|
||||
legacy_roles.append(
|
||||
svc.RBACRole(
|
||||
id=role_name,
|
||||
tenant_id="",
|
||||
type=svc.RBACRoleType.WORKSPACE.value,
|
||||
category="global_system_default",
|
||||
name=role_name,
|
||||
description="",
|
||||
is_builtin=True,
|
||||
permission_keys=list(dict.fromkeys(_LEGACY_ROLE_PERMISSION_KEYS[role_name])),
|
||||
role_tag="owner" if role_name == "owner" else "",
|
||||
)
|
||||
)
|
||||
for role_name in ("owner", "admin", "editor", "normal", "dataset_operator")
|
||||
]
|
||||
|
||||
if not include_owner:
|
||||
legacy_roles = [r for r in legacy_roles if r.name != "owner"]
|
||||
|
||||
@ -325,10 +325,10 @@ class TenantApi(Resource):
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenants = TenantService.get_join_tenants(current_user, session=db.session)
|
||||
# if there is any tenant, switch to the first one
|
||||
if len(tenants) > 0:
|
||||
TenantService.switch_tenant(current_user, tenants[0].id)
|
||||
TenantService.switch_tenant(current_user, tenants[0].id, session=db.session)
|
||||
tenant = tenants[0]
|
||||
# else, raise Unauthorized
|
||||
else:
|
||||
@ -351,7 +351,7 @@ class SwitchWorkspaceApi(Resource):
|
||||
|
||||
# check if tenant_id is valid, 403 if not
|
||||
try:
|
||||
TenantService.switch_tenant(current_user, args.tenant_id)
|
||||
TenantService.switch_tenant(current_user, args.tenant_id, session=db.session)
|
||||
except Exception:
|
||||
raise AccountNotLinkTenantError("Account not link tenant")
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ from controllers.common.schema import register_response_schema_models, register_
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||
from core.workflow.nodes.knowledge_retrieval import exc as retrieval_exc
|
||||
from extensions.ext_database import db
|
||||
from libs.exception import BaseHTTPException
|
||||
from services.entities.knowledge_retrieval_inner import InnerKnowledgeRetrieveRequest, InnerKnowledgeRetrieveResponse
|
||||
from services.errors.knowledge_retrieval import ExternalKnowledgeRetrievalError, InnerKnowledgeRetrievalServiceError
|
||||
@ -81,7 +82,7 @@ class InnerKnowledgeRetrieveApi(Resource):
|
||||
) from exc
|
||||
|
||||
try:
|
||||
response = InnerKnowledgeRetrievalService().retrieve(payload)
|
||||
response = InnerKnowledgeRetrievalService().retrieve(payload, session=db.session)
|
||||
except InnerKnowledgeRetrievalServiceError as exc:
|
||||
raise InnerKnowledgeRetrievalHttpError(
|
||||
error_code=exc.error_code,
|
||||
|
||||
@ -47,7 +47,7 @@ class EnterpriseWorkspace(Resource):
|
||||
if account is None:
|
||||
return {"message": "owner account not found."}, 404
|
||||
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
@ -84,7 +84,7 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource):
|
||||
def post(self):
|
||||
args = WorkspaceOwnerlessPayload.model_validate(inner_api_ns.payload or {})
|
||||
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True)
|
||||
tenant = TenantService.create_tenant(args.name, is_from_dashboard=True, session=db.session)
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
|
||||
@ -128,7 +128,7 @@ class WorkspaceSwitchApi(Resource):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
try:
|
||||
TenantService.switch_tenant(account, workspace_id)
|
||||
TenantService.switch_tenant(account, workspace_id, session=db.session)
|
||||
except AccountNotLinkTenantError:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
@ -152,7 +152,7 @@ class WorkspaceMembersApi(Resource):
|
||||
@accepts(query=MemberListQuery)
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData, query: MemberListQuery):
|
||||
tenant = _load_tenant(workspace_id)
|
||||
members = TenantService.get_tenant_members(tenant)
|
||||
members = TenantService.get_tenant_members(tenant, session=db.session)
|
||||
total = len(members)
|
||||
start = (query.page - 1) * query.limit
|
||||
page_items = members[start : start + query.limit]
|
||||
@ -184,6 +184,7 @@ class WorkspaceMembersApi(Resource):
|
||||
language=None,
|
||||
role=body.role,
|
||||
inviter=inviter,
|
||||
session=db.session,
|
||||
)
|
||||
except AccountAlreadyInTenantError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
@ -193,7 +194,7 @@ class WorkspaceMembersApi(Resource):
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
normalized_email = body.email.lower()
|
||||
member = AccountService.get_account_by_email_with_case_fallback(normalized_email)
|
||||
member = AccountService.get_account_by_email_with_case_fallback(db.session, normalized_email)
|
||||
if member is None:
|
||||
# invite_new_member just created or fetched this account.
|
||||
raise RuntimeError("invited member missing from DB after invite")
|
||||
@ -232,7 +233,7 @@ class WorkspaceMemberApi(Resource):
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.remove_member_from_tenant(tenant, member, operator)
|
||||
TenantService.remove_member_from_tenant(tenant, member, operator, session=db.session)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
@ -266,7 +267,7 @@ class WorkspaceMemberRoleApi(Resource):
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.update_member_role(tenant, member, body.role, operator)
|
||||
TenantService.update_member_role(tenant, member, body.role, operator, session=db.session)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.common.schema import query_params_from_model, register_response
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import Annotation, AnnotationList
|
||||
from fields.base import ResponseModel
|
||||
@ -281,7 +282,9 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
"""Update an existing annotation."""
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(
|
||||
update_args, app_model.id, str(annotation_id), db.session
|
||||
)
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -310,5 +313,5 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, annotation_id: UUID):
|
||||
"""Delete an annotation."""
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id), db.session)
|
||||
return "", 204
|
||||
|
||||
@ -23,6 +23,7 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.schema import binary_response, expect_with_user, multipart_file_params
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, EndUser
|
||||
from services.audio_service import AudioService
|
||||
@ -177,7 +178,12 @@ class TextApi(Resource):
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=text,
|
||||
voice=voice,
|
||||
end_user=end_user.external_user_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -8,6 +8,7 @@ from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataActionResponse,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
@ -85,7 +86,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
metadata = MetadataService.create_metadata(db.session(), dataset_id_str, metadata_args)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
|
||||
@service_api_ns.doc(
|
||||
@ -118,7 +119,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
metadata = MetadataService.get_dataset_metadatas(dataset)
|
||||
metadata = MetadataService.get_dataset_metadatas(db.session(), dataset)
|
||||
return dump_response(DatasetMetadataListResponse, metadata), 200
|
||||
|
||||
|
||||
@ -158,7 +159,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
||||
metadata = MetadataService.update_metadata_name(db.session(), dataset_id_str, metadata_id_str, payload.name)
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
|
||||
@service_api_ns.doc(
|
||||
@ -193,7 +194,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
MetadataService.delete_metadata(db.session(), dataset_id_str, metadata_id_str)
|
||||
return "", 204
|
||||
|
||||
|
||||
@ -263,9 +264,9 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
|
||||
match action:
|
||||
case "enable":
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
MetadataService.enable_built_in_field(db.session(), dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
MetadataService.disable_built_in_field(db.session(), dataset)
|
||||
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
|
||||
|
||||
|
||||
@ -309,6 +310,6 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
|
||||
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
MetadataService.update_documents_metadata(db.session(), dataset, metadata_args)
|
||||
|
||||
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
|
||||
|
||||
@ -22,6 +22,7 @@ from controllers.web.error import (
|
||||
)
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, EndUser
|
||||
@ -130,7 +131,12 @@ class TextApi(WebApiResource):
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
app_model=app_model,
|
||||
session=db.session,
|
||||
text=text,
|
||||
voice=voice,
|
||||
end_user=end_user.external_user_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@ -32,6 +32,7 @@ from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -202,6 +203,12 @@ class ChatApi(WebApiResource):
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
# Eagerly validate conversation to avoid hanging on invalid conversation_id
|
||||
if payload.conversation_id:
|
||||
ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=payload.conversation_id, user=end_user
|
||||
)
|
||||
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming
|
||||
)
|
||||
|
||||
@ -69,7 +69,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, request_email)
|
||||
if account is None:
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
@ -168,7 +168,7 @@ class ForgotPasswordResetApi(Resource):
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
|
||||
if account:
|
||||
account = db.session.merge(account)
|
||||
|
||||
@ -41,6 +41,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueReasoningChunkEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
@ -62,6 +63,7 @@ from core.app.entities.task_entities import (
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
ReasoningChunkStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowTaskState,
|
||||
@ -473,6 +475,17 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
|
||||
)
|
||||
|
||||
# Collect terminal reasoning (separated mode) per LLM node id for persistence. This is the
|
||||
# authoritative source (outputs.reasoning_content), decoupled from the live delta stream.
|
||||
# Accumulate across iteration/loop passes (same node_id) to match the live stream, which
|
||||
# appends every pass under the same key — overwriting would keep only the last pass.
|
||||
if event.node_type == BuiltinNodeTypes.LLM:
|
||||
reasoning_content = (event.outputs or {}).get("reasoning_content")
|
||||
if isinstance(reasoning_content, str) and reasoning_content:
|
||||
self._task_state.metadata.reasoning[event.node_id] = (
|
||||
self._task_state.metadata.reasoning.get(event.node_id, "") + reasoning_content
|
||||
)
|
||||
|
||||
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -535,6 +548,27 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
|
||||
def _handle_reasoning_chunk_event(
|
||||
self, event: QueueReasoningChunkEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle out-of-band reasoning chunk events.
|
||||
|
||||
Pure emit: reasoning is streamed on its own channel and never written to the
|
||||
answer. The terminal marker (is_final) may carry an empty reasoning string, in
|
||||
which case it is still forwarded as the "thinking finished" signal.
|
||||
"""
|
||||
if not event.reasoning and not event.is_final:
|
||||
return
|
||||
yield ReasoningChunkStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=ReasoningChunkStreamResponse.Data(
|
||||
message_id=self._message_id,
|
||||
reasoning=event.reasoning,
|
||||
node_id=event.from_node_id,
|
||||
is_final=event.is_final,
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_iteration_start_event(
|
||||
self, event: QueueIterationStartEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
@ -872,6 +906,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueuePingEvent: self._handle_ping_event,
|
||||
QueueErrorEvent: self._handle_error_event,
|
||||
QueueTextChunkEvent: self._handle_text_chunk_event,
|
||||
QueueReasoningChunkEvent: self._handle_reasoning_chunk_event,
|
||||
# Workflow events
|
||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
|
||||
@ -197,7 +197,7 @@ class AgentAppRuntimeRequestBuilder:
|
||||
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
|
||||
"""Return the transport plugin id expected by plugin-daemon headers."""
|
||||
if plugin_id.count("/") == 1:
|
||||
return plugin_id
|
||||
return plugin_id.split(":", 1)[0].split("@", 1)[0]
|
||||
if plugin_id:
|
||||
return ModelProviderID(plugin_id).plugin_id
|
||||
return ModelProviderID(model_provider).plugin_id
|
||||
|
||||
@ -31,6 +31,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueReasoningChunkEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
@ -47,6 +48,7 @@ from core.app.entities.task_entities import (
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
PingStreamResponse,
|
||||
ReasoningChunkStreamResponse,
|
||||
StreamResponse,
|
||||
TextChunkStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
@ -571,6 +573,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
|
||||
|
||||
def _handle_reasoning_chunk_event(
|
||||
self, event: QueueReasoningChunkEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle reasoning chunk events."""
|
||||
# is_final with empty reasoning is still forwarded as the "thinking finished" signal
|
||||
if not event.reasoning and not event.is_final:
|
||||
return
|
||||
yield ReasoningChunkStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=ReasoningChunkStreamResponse.Data(
|
||||
reasoning=event.reasoning,
|
||||
node_id=event.from_node_id,
|
||||
is_final=event.is_final,
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle agent log events."""
|
||||
yield self._workflow_response_converter.handle_agent_log(
|
||||
@ -600,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueuePingEvent: self._handle_ping_event,
|
||||
QueueErrorEvent: self._handle_error_event,
|
||||
QueueTextChunkEvent: self._handle_text_chunk_event,
|
||||
QueueReasoningChunkEvent: self._handle_reasoning_chunk_event,
|
||||
# Workflow events
|
||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
|
||||
@ -24,6 +24,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueReasoningChunkEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
@ -74,6 +75,7 @@ from graphon.graph_events import (
|
||||
NodeRunLoopNextEvent,
|
||||
NodeRunLoopStartedEvent,
|
||||
NodeRunLoopSucceededEvent,
|
||||
NodeRunReasoningChunkEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
NodeRunStartedEvent,
|
||||
@ -576,6 +578,16 @@ class WorkflowBasedAppRunner:
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
case NodeRunReasoningChunkEvent():
|
||||
self._publish_event(
|
||||
QueueReasoningChunkEvent(
|
||||
reasoning=event.chunk,
|
||||
from_node_id=event.node_id,
|
||||
is_final=event.is_final,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
case NodeRunRetrieverResourceEvent():
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
|
||||
@ -40,6 +40,7 @@ class QueueEvent(StrEnum):
|
||||
NODE_FAILED = "node_failed"
|
||||
NODE_EXCEPTION = "node_exception"
|
||||
RETRIEVER_RESOURCES = "retriever_resources"
|
||||
REASONING_CHUNK = "reasoning_chunk"
|
||||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
MESSAGE_FILE = "message_file"
|
||||
@ -197,6 +198,26 @@ class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueReasoningChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueReasoningChunkEvent entity
|
||||
|
||||
Out-of-band reasoning (chain-of-thought) delta from an LLM node in "separated"
|
||||
mode. It never touches the answer; it is emitted on a dedicated channel.
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.REASONING_CHUNK
|
||||
reasoning: str
|
||||
from_node_id: str | None = None
|
||||
"""id of the LLM node that produced this reasoning"""
|
||||
is_final: bool = False
|
||||
"""marks the terminal reasoning chunk for the node run (may carry empty reasoning)"""
|
||||
in_iteration_id: str | None = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueAgentMessageEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueMessageEvent entity
|
||||
|
||||
@ -27,6 +27,9 @@ class TaskStateMetadata(BaseModel):
|
||||
annotation_reply: AnnotationReply | None = None
|
||||
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
|
||||
usage: LLMUsage | None = None
|
||||
reasoning: dict[str, str] = Field(default_factory=dict)
|
||||
"""reasoning_content per LLM node id (separated mode), accumulated across iteration/loop
|
||||
passes for that node; persisted to message_metadata"""
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
@ -85,6 +88,7 @@ class StreamEvent(StrEnum):
|
||||
LOOP_COMPLETED = "loop_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
REASONING_CHUNK = "reasoning_chunk"
|
||||
AGENT_LOG = "agent_log"
|
||||
HUMAN_INPUT_REQUIRED = "human_input_required"
|
||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
||||
@ -726,6 +730,29 @@ class TextChunkStreamResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class ReasoningChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
ReasoningChunkStreamResponse entity
|
||||
|
||||
Out-of-band reasoning (chain-of-thought) delta, parallel to text_chunk. Only
|
||||
emitted in "separated" mode; the answer/message stream stays free of <think>.
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
# chat apps set this; workflow runs have no message
|
||||
message_id: str | None = None
|
||||
reasoning: str
|
||||
node_id: str | None = None
|
||||
is_final: bool = False
|
||||
|
||||
event: StreamEvent = StreamEvent.REASONING_CHUNK
|
||||
data: Data
|
||||
|
||||
|
||||
class TextReplaceStreamResponse(StreamResponse):
|
||||
"""
|
||||
TextReplaceStreamResponse entity
|
||||
|
||||
@ -226,7 +226,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
all_multimodal_documents.append(file_document)
|
||||
doc.attachments = attachments
|
||||
else:
|
||||
account = AccountService.load_user(document.created_by)
|
||||
account = AccountService.load_user(document.created_by, db.session)
|
||||
if not account:
|
||||
raise ValueError("Invalid account")
|
||||
doc.attachments = self._get_content_files(doc, current_user=account)
|
||||
|
||||
@ -291,7 +291,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
attachments.append(file_document)
|
||||
doc.attachments = attachments
|
||||
else:
|
||||
account = AccountService.load_user(document.created_by)
|
||||
account = AccountService.load_user(document.created_by, db.session)
|
||||
if not account:
|
||||
raise ValueError("Invalid account")
|
||||
doc.attachments = self._get_content_files(doc, current_user=account)
|
||||
|
||||
@ -265,7 +265,7 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
def _plugin_daemon_plugin_id(*, plugin_id: str, model_provider: str) -> str:
|
||||
"""Return the transport plugin id expected by plugin-daemon headers."""
|
||||
if plugin_id.count("/") == 1:
|
||||
return plugin_id
|
||||
return plugin_id.split(":", 1)[0].split("@", 1)[0]
|
||||
if plugin_id:
|
||||
return ModelProviderID(plugin_id).plugin_id
|
||||
return ModelProviderID(model_provider).plugin_id
|
||||
|
||||
@ -124,6 +124,7 @@ else
|
||||
exec python -m app
|
||||
else
|
||||
exec gunicorn \
|
||||
--no-control-socket \
|
||||
--bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
|
||||
--workers ${SERVER_WORKER_AMOUNT:-1} \
|
||||
--worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \
|
||||
|
||||
@ -84,7 +84,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
||||
if not user_id:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, session=db.session)
|
||||
return logged_in_account
|
||||
elif request.blueprint == "openapi":
|
||||
# Account-branch device-flow approval routes (approve / deny /
|
||||
@ -103,7 +103,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
||||
source = decoded.get("token_source")
|
||||
if source or not user_id:
|
||||
return None
|
||||
return AccountService.load_logged_in_account(account_id=user_id)
|
||||
return AccountService.load_logged_in_account(account_id=user_id, session=db.session)
|
||||
elif request.blueprint == "web":
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||
|
||||
@ -77,6 +77,13 @@ class WebMessageListItem(MessageListItem):
|
||||
)
|
||||
|
||||
|
||||
class ExploreMessageListItem(MessageListItem):
|
||||
metadata: JSONValueType | None = Field(
|
||||
default=None,
|
||||
validation_alias="message_metadata_dict",
|
||||
)
|
||||
|
||||
|
||||
class MessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
@ -89,6 +96,12 @@ class WebMessageInfiniteScrollPagination(ResponseModel):
|
||||
data: list[WebMessageListItem]
|
||||
|
||||
|
||||
class ExploreMessageInfiniteScrollPagination(ResponseModel):
|
||||
limit: int
|
||||
has_more: bool
|
||||
data: list[ExploreMessageListItem]
|
||||
|
||||
|
||||
class SavedMessageItem(ResponseModel):
|
||||
id: str
|
||||
inputs: dict[str, JSONValueType]
|
||||
|
||||
@ -16,7 +16,7 @@ def valid_password(password):
|
||||
raise ValueError("Password must contain letters and numbers, and the length must be at least 8 characters.")
|
||||
|
||||
|
||||
def hash_password(password_str, salt_byte):
|
||||
def hash_password(password_str: str, salt_byte: bytes):
|
||||
dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
|
||||
return binascii.hexlify(dk)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""add dataset_id to api_tokens
|
||||
|
||||
Revision ID: e4f8a2c61d35
|
||||
Revises: c8f4a6b2d3e1
|
||||
Revises: d9e8f7a6b5c4
|
||||
Create Date: 2026-06-17 09:00:00.000000
|
||||
|
||||
Reintroduces the nullable `dataset_id` column on `api_tokens` (it was dropped in
|
||||
@ -23,7 +23,7 @@ import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e4f8a2c61d35"
|
||||
down_revision = "c8f4a6b2d3e1"
|
||||
down_revision = "d9e8f7a6b5c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
@ -0,0 +1,26 @@
|
||||
"""add cloud only flag to recommended apps
|
||||
|
||||
Revision ID: d9e8f7a6b5c4
|
||||
Revises: c8f4a6b2d3e1
|
||||
Create Date: 2026-06-23 18:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d9e8f7a6b5c4"
|
||||
down_revision = "c8f4a6b2d3e1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("recommended_apps", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("is_cloud_only", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("recommended_apps", schema=None) as batch_op:
|
||||
batch_op.drop_column("is_cloud_only")
|
||||
@ -925,6 +925,9 @@ class RecommendedApp(TypeBase):
|
||||
is_learn_dify: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false"), default=False
|
||||
)
|
||||
is_cloud_only: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=sa.text("false"), default=False
|
||||
)
|
||||
install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
|
||||
language: Mapped[str] = mapped_column(
|
||||
String(255),
|
||||
|
||||
@ -3807,6 +3807,26 @@ Submit human input form preview for workflow
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Workflow agent composer candidates | **application/json**: [AgentComposerCandidatesResponse](#agentcomposercandidatesresponse)<br> |
|
||||
|
||||
### [POST] /apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/copy-from-roster
|
||||
#### Parameters
|
||||
|
||||
| Name | Located in | Description | Required | Schema |
|
||||
| ---- | ---------- | ----------- | -------- | ------ |
|
||||
| app_id | path | | Yes | string (uuid) |
|
||||
| node_id | path | | Yes | string |
|
||||
|
||||
#### Request Body
|
||||
|
||||
| Required | Schema |
|
||||
| -------- | ------ |
|
||||
| Yes | **application/json**: [WorkflowComposerCopyFromRosterPayload](#workflowcomposercopyfromrosterpayload)<br> |
|
||||
|
||||
#### Responses
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Workflow roster agent copied to inline agent | **application/json**: [WorkflowAgentComposerResponse](#workflowagentcomposerresponse)<br> |
|
||||
|
||||
### [POST] /apps/{app_id}/workflows/draft/nodes/{node_id}/agent-composer/impact
|
||||
#### Parameters
|
||||
|
||||
@ -6574,7 +6594,7 @@ Request body:
|
||||
|
||||
| Code | Description | Schema |
|
||||
| ---- | ----------- | ------ |
|
||||
| 200 | Success | **application/json**: [MessageInfiniteScrollPagination](#messageinfinitescrollpagination)<br> |
|
||||
| 200 | Success | **application/json**: [ExploreMessageInfiniteScrollPagination](#exploremessageinfinitescrollpagination)<br> |
|
||||
|
||||
### [POST] /installed-apps/{installed_app_id}/messages/{message_id}/feedbacks
|
||||
#### Parameters
|
||||
@ -14386,9 +14406,14 @@ Button styles for user actions.
|
||||
| agent_soul | [AgentSoulConfig](#agentsoulconfig) | | No |
|
||||
| binding | [ComposerBindingPayload](#composerbindingpayload) | | No |
|
||||
| client_revision_id | string | | No |
|
||||
| description | string | | No |
|
||||
| icon | string | | No |
|
||||
| icon_background | string | | No |
|
||||
| icon_type | [AgentIconType](#agenticontype) | | No |
|
||||
| idempotency_key | string | | No |
|
||||
| new_agent_name | string | | No |
|
||||
| node_job | [WorkflowNodeJobConfig](#workflownodejobconfig) | | No |
|
||||
| role | string | | No |
|
||||
| save_strategy | [ComposerSaveStrategy](#composersavestrategy) | | Yes |
|
||||
| soul_lock | [ComposerSoulLockPayload](#composersoullockpayload) | | No |
|
||||
| variant | [ComposerVariant](#composervariant) | | Yes |
|
||||
@ -15997,6 +16022,34 @@ Request payload for bulk downloading documents as a zip archive.
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| tool_icons | object | | No |
|
||||
|
||||
#### ExploreMessageInfiniteScrollPagination
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| data | [ [ExploreMessageListItem](#exploremessagelistitem) ] | | Yes |
|
||||
| has_more | boolean | | Yes |
|
||||
| limit | integer | | Yes |
|
||||
|
||||
#### ExploreMessageListItem
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| agent_thoughts | [ [AgentThought](#agentthought) ] | | Yes |
|
||||
| answer | string | | Yes |
|
||||
| conversation_id | string | | Yes |
|
||||
| created_at | integer | | No |
|
||||
| error | string | | No |
|
||||
| extra_contents | [ [HumanInputContent](#humaninputcontent) ] | | Yes |
|
||||
| feedback | [SimpleFeedback](#simplefeedback) | | No |
|
||||
| id | string | | Yes |
|
||||
| inputs | object | | Yes |
|
||||
| message_files | [ [MessageFile](#messagefile) ] | | Yes |
|
||||
| metadata | [JSONValueType](#jsonvaluetype) | | No |
|
||||
| parent_message_id | string | | No |
|
||||
| query | string | | Yes |
|
||||
| retriever_resources | [ [RetrieverResource](#retrieverresource) ] | | Yes |
|
||||
| status | string | | Yes |
|
||||
|
||||
#### ExternalApiTemplateListQuery
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
@ -17071,14 +17124,6 @@ Enum class for large language model mode.
|
||||
| upload_file_id | string | | No |
|
||||
| url | string | | No |
|
||||
|
||||
#### MessageInfiniteScrollPagination
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| data | [ [MessageListItem](#messagelistitem) ] | | Yes |
|
||||
| has_more | boolean | | Yes |
|
||||
| limit | integer | | Yes |
|
||||
|
||||
#### MessageInfiniteScrollPaginationResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
@ -17087,25 +17132,6 @@ Enum class for large language model mode.
|
||||
| has_more | boolean | | Yes |
|
||||
| limit | integer | | Yes |
|
||||
|
||||
#### MessageListItem
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| agent_thoughts | [ [AgentThought](#agentthought) ] | | Yes |
|
||||
| answer | string | | Yes |
|
||||
| conversation_id | string | | Yes |
|
||||
| created_at | integer | | No |
|
||||
| error | string | | No |
|
||||
| extra_contents | [ [HumanInputContent](#humaninputcontent) ] | | Yes |
|
||||
| feedback | [SimpleFeedback](#simplefeedback) | | No |
|
||||
| id | string | | Yes |
|
||||
| inputs | object | | Yes |
|
||||
| message_files | [ [MessageFile](#messagefile) ] | | Yes |
|
||||
| parent_message_id | string | | No |
|
||||
| query | string | | Yes |
|
||||
| retriever_resources | [ [RetrieverResource](#retrieverresource) ] | | Yes |
|
||||
| status | string | | Yes |
|
||||
|
||||
#### MessageListQuery
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
@ -19580,6 +19606,7 @@ Model class for provider system configuration response.
|
||||
| enable_email_code_login | boolean | | Yes |
|
||||
| enable_email_password_login | boolean, <br>**Default:** true | | Yes |
|
||||
| enable_explore_banner | boolean | | Yes |
|
||||
| enable_learn_app | boolean, <br>**Default:** true | | Yes |
|
||||
| enable_marketplace | boolean | | Yes |
|
||||
| enable_social_oauth_login | boolean | | Yes |
|
||||
| enable_trial_app | boolean | | Yes |
|
||||
@ -20560,6 +20587,14 @@ How a workflow node is bound to an Agent.
|
||||
| position_x | number | Comment X position | No |
|
||||
| position_y | number | Comment Y position | No |
|
||||
|
||||
#### WorkflowComposerCopyFromRosterPayload
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| idempotency_key | string | | No |
|
||||
| source_agent_id | string | | Yes |
|
||||
| source_snapshot_id | string | | No |
|
||||
|
||||
#### WorkflowConversationVariableResponse
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
|
||||
@ -1603,6 +1603,7 @@ Default configuration for form inputs.
|
||||
| enable_email_code_login | boolean | | Yes |
|
||||
| enable_email_password_login | boolean, <br>**Default:** true | | Yes |
|
||||
| enable_explore_banner | boolean | | Yes |
|
||||
| enable_learn_app | boolean, <br>**Default:** true | | Yes |
|
||||
| enable_marketplace | boolean | | Yes |
|
||||
| enable_social_oauth_login | boolean | | Yes |
|
||||
| enable_trial_app | boolean | | Yes |
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
name = "dify-trace-langsmith"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"langsmith==0.8.5",
|
||||
"langsmith==0.8.18",
|
||||
]
|
||||
description = "Dify ops tracing provider (LangSmith)."
|
||||
|
||||
|
||||
@ -181,30 +181,34 @@ class TestTencentDataTrace:
|
||||
mock_trace_utils.convert_to_trace_id.return_value = 123
|
||||
mock_trace_utils.create_link.return_value = "link"
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
|
||||
with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc:
|
||||
with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur:
|
||||
mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()]
|
||||
with (
|
||||
patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"),
|
||||
patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc,
|
||||
patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur,
|
||||
):
|
||||
mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()]
|
||||
|
||||
tencent_data_trace.workflow_trace(trace_info)
|
||||
tencent_data_trace.workflow_trace(trace_info)
|
||||
|
||||
mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id")
|
||||
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
|
||||
mock_span_builder.build_workflow_spans.assert_called_once()
|
||||
assert tencent_data_trace.trace_client.add_span.call_count == 2
|
||||
mock_proc.assert_called_once_with(trace_info, 123)
|
||||
mock_dur.assert_called_once_with(trace_info)
|
||||
mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id")
|
||||
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
|
||||
mock_span_builder.build_workflow_spans.assert_called_once()
|
||||
assert tencent_data_trace.trace_client.add_span.call_count == 2
|
||||
mock_proc.assert_called_once_with(trace_info, 123)
|
||||
mock_dur.assert_called_once_with(trace_info)
|
||||
|
||||
def test_workflow_trace_exception(self, tencent_data_trace):
|
||||
def test_workflow_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.workflow_run_id = "run-id"
|
||||
|
||||
with patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
|
||||
with (
|
||||
patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
|
||||
),
|
||||
caplog.at_level(logging.ERROR),
|
||||
):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.workflow_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace")
|
||||
tencent_data_trace.workflow_trace(trace_info)
|
||||
assert "[Tencent APM] Failed to process workflow trace" in caplog.text
|
||||
|
||||
def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
@ -214,29 +218,33 @@ class TestTencentDataTrace:
|
||||
mock_trace_utils.convert_to_trace_id.return_value = 123
|
||||
mock_trace_utils.create_link.return_value = "link"
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"):
|
||||
with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics:
|
||||
with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur:
|
||||
mock_span_builder.build_message_span.return_value = MagicMock()
|
||||
with (
|
||||
patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"),
|
||||
patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics,
|
||||
patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur,
|
||||
):
|
||||
mock_span_builder.build_message_span.return_value = MagicMock()
|
||||
|
||||
tencent_data_trace.message_trace(trace_info)
|
||||
tencent_data_trace.message_trace(trace_info)
|
||||
|
||||
mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
|
||||
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
|
||||
mock_span_builder.build_message_span.assert_called_once()
|
||||
tencent_data_trace.trace_client.add_span.assert_called_once()
|
||||
mock_metrics.assert_called_once_with(trace_info)
|
||||
mock_dur.assert_called_once_with(trace_info)
|
||||
mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id")
|
||||
mock_trace_utils.create_link.assert_called_once_with("parent-trace-id")
|
||||
mock_span_builder.build_message_span.assert_called_once()
|
||||
tencent_data_trace.trace_client.add_span.assert_called_once()
|
||||
mock_metrics.assert_called_once_with(trace_info)
|
||||
mock_dur.assert_called_once_with(trace_info)
|
||||
|
||||
def test_message_trace_exception(self, tencent_data_trace):
|
||||
def test_message_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
|
||||
with patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
|
||||
with (
|
||||
patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error")
|
||||
),
|
||||
caplog.at_level(logging.ERROR),
|
||||
):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.message_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace")
|
||||
tencent_data_trace.message_trace(trace_info)
|
||||
assert "[Tencent APM] Failed to process message trace" in caplog.text
|
||||
|
||||
def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=ToolTraceInfo)
|
||||
@ -259,16 +267,18 @@ class TestTencentDataTrace:
|
||||
tencent_data_trace.tool_trace(trace_info)
|
||||
tencent_data_trace.trace_client.add_span.assert_not_called()
|
||||
|
||||
def test_tool_trace_exception(self, tencent_data_trace):
|
||||
def test_tool_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=ToolTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
|
||||
with patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
|
||||
with (
|
||||
patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
|
||||
),
|
||||
caplog.at_level(logging.ERROR),
|
||||
):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.tool_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace")
|
||||
tencent_data_trace.tool_trace(trace_info)
|
||||
assert "[Tencent APM] Failed to process tool trace" in caplog.text
|
||||
|
||||
def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
@ -291,29 +301,34 @@ class TestTencentDataTrace:
|
||||
tencent_data_trace.dataset_retrieval_trace(trace_info)
|
||||
tencent_data_trace.trace_client.add_span.assert_not_called()
|
||||
|
||||
def test_dataset_retrieval_trace_exception(self, tencent_data_trace):
|
||||
def test_dataset_retrieval_trace_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=DatasetRetrievalTraceInfo)
|
||||
trace_info.message_id = "msg-id"
|
||||
|
||||
with patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
|
||||
with (
|
||||
patch(
|
||||
"dify_trace_tencent.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error")
|
||||
),
|
||||
caplog.at_level(logging.ERROR),
|
||||
):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.dataset_retrieval_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace")
|
||||
tencent_data_trace.dataset_retrieval_trace(trace_info)
|
||||
assert "[Tencent APM] Failed to process dataset retrieval trace" in caplog.text
|
||||
|
||||
def test_suggested_question_trace(self, tencent_data_trace):
|
||||
def test_suggested_question_trace(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.info") as mock_log:
|
||||
with caplog.at_level(logging.INFO):
|
||||
tencent_data_trace.suggested_question_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace")
|
||||
assert "[Tencent APM] Processing suggested question trace" in caplog.text
|
||||
|
||||
def test_suggested_question_trace_exception(self, tencent_data_trace):
|
||||
def test_suggested_question_trace_exception(
|
||||
self, tencent_data_trace, monkeypatch, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
trace_info = MagicMock(spec=SuggestedQuestionTraceInfo)
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.info", side_effect=Exception("error")):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace.suggested_question_trace(trace_info)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace")
|
||||
target_logger = logging.getLogger("dify_trace_tencent.tencent_trace")
|
||||
monkeypatch.setattr(target_logger, "info", MagicMock(side_effect=Exception("error")))
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace.suggested_question_trace(trace_info)
|
||||
assert "[Tencent APM] Failed to process suggested question trace" in caplog.text
|
||||
|
||||
def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -327,35 +342,42 @@ class TestTencentDataTrace:
|
||||
node2.id = "n2"
|
||||
node2.node_type = BuiltinNodeTypes.TOOL
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]):
|
||||
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]):
|
||||
with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics:
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
with (
|
||||
patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]),
|
||||
patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]),
|
||||
patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics,
|
||||
):
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
|
||||
assert tencent_data_trace.trace_client.add_span.call_count == 2
|
||||
mock_metrics.assert_called_once_with(node1)
|
||||
assert tencent_data_trace.trace_client.add_span.call_count == 2
|
||||
mock_metrics.assert_called_once_with(node1)
|
||||
|
||||
def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils):
|
||||
def test_process_workflow_nodes_node_exception(
|
||||
self, tencent_data_trace, mock_trace_utils, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
mock_trace_utils.convert_to_span_id.return_value = 111
|
||||
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.id = "n1"
|
||||
|
||||
with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]):
|
||||
with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
# The exception should be caught by the outer handler since convert_to_span_id is called first
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
|
||||
with (
|
||||
patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]),
|
||||
patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")),
|
||||
caplog.at_level(logging.ERROR),
|
||||
):
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
assert "[Tencent APM] Failed to process workflow nodes" in caplog.text
|
||||
|
||||
def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils):
|
||||
def test_process_workflow_nodes_exception(
|
||||
self, tencent_data_trace, mock_trace_utils, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error")
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace._process_workflow_nodes(trace_info, 123)
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes")
|
||||
assert "[Tencent APM] Failed to process workflow nodes" in caplog.text
|
||||
|
||||
def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -377,16 +399,18 @@ class TestTencentDataTrace:
|
||||
assert result == "span"
|
||||
builder_method.assert_called_once_with(123, 456, trace_info, node)
|
||||
|
||||
def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder):
|
||||
def test_build_workflow_node_span_exception(
|
||||
self, tencent_data_trace, mock_span_builder, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.node_type = BuiltinNodeTypes.LLM
|
||||
node.id = "n1"
|
||||
mock_span_builder.build_workflow_llm_span.side_effect = Exception("error")
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456)
|
||||
assert result is None
|
||||
mock_log.assert_called_once()
|
||||
assert result is None
|
||||
assert len([r for r in caplog.records if r.levelno == logging.DEBUG]) >= 1
|
||||
|
||||
def test_get_workflow_node_executions(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -419,16 +443,16 @@ class TestTencentDataTrace:
|
||||
assert results == mock_executions
|
||||
account.set_tenant_id.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace):
|
||||
def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {}
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
results = tencent_data_trace._get_workflow_node_executions(trace_info)
|
||||
assert results == []
|
||||
mock_log.assert_called_once()
|
||||
assert results == []
|
||||
assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1
|
||||
|
||||
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace):
|
||||
def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.metadata = {"app_id": "app-1"}
|
||||
|
||||
@ -439,23 +463,25 @@ class TestTencentDataTrace:
|
||||
session = mock_session_ctx.return_value.__enter__.return_value
|
||||
session.scalar.return_value = None
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
results = tencent_data_trace._get_workflow_node_executions(trace_info)
|
||||
assert results == []
|
||||
mock_log.assert_called_once()
|
||||
assert results == []
|
||||
assert len([r for r in caplog.records if r.levelno == logging.ERROR]) >= 1
|
||||
|
||||
def test_get_user_id_workflow(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.tenant_id = "tenant-1"
|
||||
trace_info.metadata = {"user_id": "user-1"}
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")):
|
||||
with patch("dify_trace_tencent.tencent_trace.db") as mock_db:
|
||||
mock_db.init_app = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
with (
|
||||
patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("Database error")),
|
||||
patch("dify_trace_tencent.tencent_trace.db") as mock_db,
|
||||
):
|
||||
mock_db.init_app = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "unknown"
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "unknown"
|
||||
|
||||
def test_get_user_id_only_user_id(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
@ -471,16 +497,18 @@ class TestTencentDataTrace:
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "anonymous"
|
||||
|
||||
def test_get_user_id_exception(self, tencent_data_trace):
|
||||
def test_get_user_id_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.tenant_id = "t"
|
||||
trace_info.metadata = {"user_id": "u"}
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")):
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "unknown"
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID")
|
||||
with (
|
||||
patch("dify_trace_tencent.tencent_trace.sessionmaker", side_effect=Exception("error")),
|
||||
caplog.at_level(logging.ERROR),
|
||||
):
|
||||
user_id = tencent_data_trace._get_user_id(trace_info)
|
||||
assert user_id == "unknown"
|
||||
assert "[Tencent APM] Failed to get user ID" in caplog.text
|
||||
|
||||
def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
@ -514,14 +542,14 @@ class TestTencentDataTrace:
|
||||
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
|
||||
tencent_data_trace.trace_client.record_token_usage.assert_called_once()
|
||||
|
||||
def test_record_llm_metrics_exception(self, tencent_data_trace):
|
||||
def test_record_llm_metrics_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
node = MagicMock(spec=WorkflowNodeExecution)
|
||||
node.process_data = None
|
||||
node.outputs = None
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
tencent_data_trace._record_llm_metrics(node)
|
||||
# Should not crash
|
||||
# Should not crash
|
||||
|
||||
def test_record_message_llm_metrics(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
@ -553,13 +581,13 @@ class TestTencentDataTrace:
|
||||
tencent_data_trace._record_message_llm_metrics(trace_info)
|
||||
tencent_data_trace.trace_client.record_llm_duration.assert_called_once()
|
||||
|
||||
def test_record_message_llm_metrics_exception(self, tencent_data_trace):
|
||||
def test_record_message_llm_metrics_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.metadata = None
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
tencent_data_trace._record_message_llm_metrics(trace_info)
|
||||
# Should not crash
|
||||
# Should not crash
|
||||
|
||||
def test_record_workflow_trace_duration(self, tencent_data_trace):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
@ -605,11 +633,11 @@ class TestTencentDataTrace:
|
||||
attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {}
|
||||
assert attributes["has_conversation"] == "false"
|
||||
|
||||
def test_record_workflow_trace_duration_exception(self, tencent_data_trace):
|
||||
def test_record_workflow_trace_duration_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=WorkflowTraceInfo)
|
||||
trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
tencent_data_trace._record_workflow_trace_duration(trace_info)
|
||||
|
||||
def test_record_message_trace_duration(self, tencent_data_trace):
|
||||
@ -627,11 +655,11 @@ class TestTencentDataTrace:
|
||||
2.0, {"conversation_mode": "chat", "stream": "true"}
|
||||
)
|
||||
|
||||
def test_record_message_trace_duration_exception(self, tencent_data_trace):
|
||||
def test_record_message_trace_duration_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
trace_info = MagicMock(spec=MessageTraceInfo)
|
||||
trace_info.start_time = None
|
||||
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.debug") as mock_log:
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
tencent_data_trace._record_message_trace_duration(trace_info)
|
||||
|
||||
def test_close(self, tencent_data_trace):
|
||||
@ -647,11 +675,11 @@ class TestTencentDataTrace:
|
||||
|
||||
client.shutdown.assert_called_once()
|
||||
|
||||
def test_close_exception(self, tencent_data_trace):
|
||||
def test_close_exception(self, tencent_data_trace, caplog: pytest.LogCaptureFixture):
|
||||
tencent_data_trace.trace_client.shutdown.side_effect = Exception("error")
|
||||
with patch("dify_trace_tencent.tencent_trace.logger.exception") as mock_log:
|
||||
with caplog.at_level(logging.ERROR):
|
||||
tencent_data_trace.close()
|
||||
mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup")
|
||||
assert "[Tencent APM] Failed to shutdown trace client during cleanup" in caplog.text
|
||||
|
||||
def test_close_handles_async_shutdown_mock(self, tencent_data_trace):
|
||||
shutdown = AsyncMock()
|
||||
|
||||
@ -113,7 +113,7 @@ class LindormVectorStore(BaseVector):
|
||||
)
|
||||
def _bulk_with_retry(actions):
|
||||
try:
|
||||
response = self._client.bulk(actions, timeout=timeout)
|
||||
response = self._client.bulk(body=actions, timeout=timeout)
|
||||
if response["errors"]:
|
||||
error_items = [item for item in response["items"] if "error" in item["index"]]
|
||||
error_msg = f"Bulk indexing had {len(error_items)} errors"
|
||||
@ -231,7 +231,7 @@ class LindormVectorStore(BaseVector):
|
||||
routing_filter_query = {
|
||||
"query": {"bool": {"must": [{"term": {f"{ROUTING_FIELD}.keyword": self._routing}}]}}
|
||||
}
|
||||
self._client.delete_by_query(self._collection_name, body=routing_filter_query)
|
||||
self._client.delete_by_query(index=self._collection_name, body=routing_filter_query)
|
||||
self.refresh()
|
||||
else:
|
||||
if self._client.indices.exists(index=self._collection_name):
|
||||
|
||||
@ -127,7 +127,7 @@ def test_create_refresh_and_add_texts_success(lindorm_module, monkeypatch: pytes
|
||||
vector.add_texts(docs, embeddings, batch_size=2, timeout=9)
|
||||
|
||||
assert vector._client.bulk.call_count == 2
|
||||
actions = vector._client.bulk.call_args_list[0].args[0]
|
||||
actions = vector._client.bulk.call_args_list[0].kwargs["body"]
|
||||
assert actions[0]["index"]["routing"] == "route"
|
||||
assert actions[1][lindorm_module.ROUTING_FIELD] == "route"
|
||||
vector.refresh()
|
||||
|
||||
@ -268,9 +268,11 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
wv._client = MagicMock()
|
||||
wv._client.collections.exists.side_effect = RuntimeError("create failed")
|
||||
|
||||
with patch.object(weaviate_vector_module.logger, "exception") as mock_exception:
|
||||
with pytest.raises(RuntimeError, match="create failed"):
|
||||
wv._create_collection()
|
||||
with (
|
||||
patch.object(weaviate_vector_module.logger, "exception") as mock_exception,
|
||||
pytest.raises(RuntimeError, match="create failed"),
|
||||
):
|
||||
wv._create_collection()
|
||||
|
||||
mock_exception.assert_called_once()
|
||||
|
||||
@ -835,9 +837,11 @@ class TestWeaviateVector(unittest.TestCase):
|
||||
wv._client.collections.use.return_value = mock_col
|
||||
mock_col.data.delete_by_id.side_effect = FakeUnexpectedStatusCodeError(500)
|
||||
|
||||
with patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError):
|
||||
with pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"):
|
||||
wv.delete_by_ids(["bad-id"])
|
||||
with (
|
||||
patch.object(weaviate_vector_module, "UnexpectedStatusCodeError", FakeUnexpectedStatusCodeError),
|
||||
pytest.raises(FakeUnexpectedStatusCodeError, match="status=500"),
|
||||
):
|
||||
wv.delete_by_ids(["bad-id"])
|
||||
|
||||
def test_json_serializable_converts_datetime(self):
|
||||
wv = WeaviateVector.__new__(WeaviateVector)
|
||||
|
||||
@ -5,7 +5,7 @@ requires-python = "~=3.12.0"
|
||||
|
||||
dependencies = [
|
||||
# Legacy: mature and widely deployed
|
||||
"bleach>=6.3.0,<7.0.0",
|
||||
"bleach>=6.4.0,<7.0.0",
|
||||
"boto3>=1.43.24,<2.0.0",
|
||||
"celery>=5.6.3,<6.0.0",
|
||||
"croniter>=6.2.2,<7.0.0",
|
||||
@ -44,7 +44,7 @@ dependencies = [
|
||||
"resend>=2.27.0,<3.0.0",
|
||||
# Emerging: newer and fast-moving, use compatible pins
|
||||
"fastopenapi[flask]==0.7.0",
|
||||
"graphon==0.5.2",
|
||||
"graphon==0.5.3",
|
||||
"httpx-sse==0.4.3",
|
||||
"json-repair==0.59.4",
|
||||
]
|
||||
@ -103,7 +103,11 @@ dify-trace-weave = { workspace = true }
|
||||
[tool.uv]
|
||||
default-groups = ["storage", "tools", "vdb-all", "trace-all"]
|
||||
package = false
|
||||
override-dependencies = ["litellm>=1.83.10,<2.0.0", "pyarrow>=23.0.1,<24.0.0"]
|
||||
override-dependencies = [
|
||||
"litellm>=1.83.10,<2.0.0",
|
||||
"pyarrow>=23.0.1,<24.0.0",
|
||||
"cryptography>=49.0.0,<50.0.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
|
||||
@ -1,3 +1,10 @@
|
||||
"""Account, workspace, and invitation services.
|
||||
|
||||
Database access in this module is caller-scoped: methods that read or mutate ORM state accept an explicit
|
||||
``session`` so controllers, tasks, and tests can control transaction lifetime and avoid hidden Flask-scoped session
|
||||
usage inside service logic.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@ -14,7 +21,6 @@ from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import get_valid_language, language_timezone_mapping
|
||||
from core.db.session_factory import session_factory
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
@ -236,7 +242,7 @@ class AccountService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _refresh_account_last_active(account: Account) -> None:
|
||||
def _refresh_account_last_active(account: Account, session: scoped_session | Session) -> None:
|
||||
now = naive_utc_now()
|
||||
refresh_before = now - ACCOUNT_LAST_ACTIVE_REFRESH_INTERVAL
|
||||
|
||||
@ -246,12 +252,12 @@ class AccountService:
|
||||
if not AccountService._should_refresh_account_last_active(account.id):
|
||||
return
|
||||
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
update(Account)
|
||||
.where(Account.id == account.id, Account.last_active_at < refresh_before)
|
||||
.values(last_active_at=now, updated_at=func.current_timestamp())
|
||||
)
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def _store_refresh_token(refresh_token: str, account_id: str):
|
||||
@ -296,20 +302,20 @@ class AccountService:
|
||||
side-effects (current-tenant assignment, commit) are unwanted.
|
||||
|
||||
``session`` is injected by the caller so this service stays free
|
||||
of the Flask-scoped ``db.session`` import.
|
||||
of a Flask-scoped session import.
|
||||
"""
|
||||
return session.get(Account, account_id)
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> None | Account:
|
||||
account = db.session.get(Account, user_id)
|
||||
def load_user(user_id: str, session: scoped_session | Session) -> None | Account:
|
||||
account = session.get(Account, user_id)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if account.status == AccountStatus.BANNED:
|
||||
raise Unauthorized("Account is banned.")
|
||||
|
||||
current_tenant = db.session.scalar(
|
||||
current_tenant = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.current == True)
|
||||
.limit(1)
|
||||
@ -317,7 +323,7 @@ class AccountService:
|
||||
if current_tenant:
|
||||
account.set_tenant_id(current_tenant.tenant_id)
|
||||
else:
|
||||
available_ta = db.session.scalar(
|
||||
available_ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
@ -329,13 +335,13 @@ class AccountService:
|
||||
account.set_tenant_id(available_ta.tenant_id)
|
||||
available_ta.current = True
|
||||
available_ta.last_opened_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
AccountService._refresh_account_last_active(account)
|
||||
AccountService._refresh_account_last_active(account, session)
|
||||
# NOTE: make sure account is accessible outside of a db session
|
||||
# This ensures that it will work correctly after upgrading to Flask version 3.1.2
|
||||
db.session.refresh(account)
|
||||
db.session.close()
|
||||
session.refresh(account)
|
||||
session.close()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
@ -353,10 +359,12 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
|
||||
def authenticate(
|
||||
email: str, password: str, invite_token: str | None = None, *, session: scoped_session | Session
|
||||
) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
account = session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if not account:
|
||||
raise AccountPasswordError("Invalid email or password.")
|
||||
|
||||
@ -379,12 +387,14 @@ class AccountService:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_account_password(account, password, new_password):
|
||||
def update_account_password(
|
||||
account: Account, password: str, new_password: str, *, session: scoped_session | Session
|
||||
):
|
||||
"""update account password"""
|
||||
if account.password and not compare_password(password, account.password, account.password_salt):
|
||||
raise CurrentPasswordIncorrectError("Current password is incorrect.")
|
||||
@ -401,8 +411,8 @@ class AccountService:
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
session.add(account)
|
||||
session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
@ -414,6 +424,8 @@ class AccountService:
|
||||
interface_theme: str = "light",
|
||||
is_setup: bool | None = False,
|
||||
timezone: str | None = None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> Account:
|
||||
"""Create an account, preferring explicit user timezone over language-derived defaults."""
|
||||
if not FeatureService.get_system_features().is_allow_register and not is_setup:
|
||||
@ -459,13 +471,19 @@ class AccountService:
|
||||
timezone=resolved_timezone,
|
||||
)
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
session.add(account)
|
||||
session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_account_and_tenant(
|
||||
email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None
|
||||
email: str,
|
||||
name: str,
|
||||
interface_language: str,
|
||||
password: str | None = None,
|
||||
timezone: str | None = None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> Account:
|
||||
"""Create an account and owner workspace."""
|
||||
account = AccountService.create_account(
|
||||
@ -474,10 +492,11 @@ class AccountService:
|
||||
interface_language=interface_language,
|
||||
password=password,
|
||||
timezone=timezone,
|
||||
session=session,
|
||||
)
|
||||
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, session=session)
|
||||
except Exception:
|
||||
# Enterprise-only side-effect should run independently from personal workspace creation.
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
@ -537,11 +556,11 @@ class AccountService:
|
||||
delete_account_task.delay(account.id)
|
||||
|
||||
@staticmethod
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account):
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account, *, session: scoped_session | Session):
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
account_integrate: AccountIntegrate | None = db.session.scalar(
|
||||
account_integrate: AccountIntegrate | None = session.scalar(
|
||||
select(AccountIntegrate)
|
||||
.where(AccountIntegrate.account_id == account.id, AccountIntegrate.provider == provider)
|
||||
.limit(1)
|
||||
@ -557,62 +576,62 @@ class AccountService:
|
||||
account_integrate = AccountIntegrate(
|
||||
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
|
||||
)
|
||||
db.session.add(account_integrate)
|
||||
session.add(account_integrate)
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
logger.info("Account %s linked %s account %s.", account.id, provider, open_id)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to link %s account %s to Account %s", provider, open_id, account.id)
|
||||
raise LinkAccountIntegrateError("Failed to link account.") from e
|
||||
|
||||
@staticmethod
|
||||
def close_account(account: Account):
|
||||
def close_account(account: Account, *, session: scoped_session | Session):
|
||||
"""Close account"""
|
||||
account.status = AccountStatus.CLOSED
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_account(account, **kwargs):
|
||||
def update_account(account: Account, *, session: scoped_session | Session, **kwargs):
|
||||
"""Update account fields"""
|
||||
account = db.session.merge(account)
|
||||
account = session.merge(account)
|
||||
for field, value in kwargs.items():
|
||||
if hasattr(account, field):
|
||||
setattr(account, field, value)
|
||||
else:
|
||||
raise AttributeError(f"Invalid field: {field}")
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_account_email(account: Account, email: str) -> Account:
|
||||
def update_account_email(account: Account, email: str, session: scoped_session | Session) -> Account:
|
||||
"""Update account email"""
|
||||
account.email = email
|
||||
account_integrate = db.session.scalar(
|
||||
account_integrate = session.scalar(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id).limit(1)
|
||||
)
|
||||
if account_integrate:
|
||||
db.session.delete(account_integrate)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
session.delete(account_integrate)
|
||||
session.add(account)
|
||||
session.commit()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_login_info(account: Account, *, ip_address: str):
|
||||
def update_login_info(account: Account, session: scoped_session | Session, *, ip_address: str):
|
||||
"""Update last login time and ip"""
|
||||
account.last_login_at = naive_utc_now()
|
||||
account.last_login_ip = ip_address
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
|
||||
def login(account: Account, *, session: scoped_session | Session, ip_address: str | None = None) -> TokenPair:
|
||||
if ip_address:
|
||||
AccountService.update_login_info(account=account, ip_address=ip_address)
|
||||
AccountService.update_login_info(account=account, session=session, ip_address=ip_address)
|
||||
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
access_token = AccountService.get_account_jwt_token(account=account)
|
||||
refresh_token = _generate_refresh_token()
|
||||
@ -629,13 +648,13 @@ class AccountService:
|
||||
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
||||
|
||||
@staticmethod
|
||||
def refresh_token(refresh_token: str) -> TokenPair:
|
||||
def refresh_token(refresh_token: str, *, session: scoped_session | Session) -> TokenPair:
|
||||
# Verify the refresh token
|
||||
account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
|
||||
if not account_id:
|
||||
raise ValueError("Invalid refresh token")
|
||||
|
||||
account = AccountService.load_user(account_id.decode("utf-8"))
|
||||
account = AccountService.load_user(account_id.decode("utf-8"), session)
|
||||
if not account:
|
||||
raise ValueError("Invalid account")
|
||||
|
||||
@ -650,8 +669,8 @@ class AccountService:
|
||||
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token)
|
||||
|
||||
@staticmethod
|
||||
def load_logged_in_account(*, account_id: str):
|
||||
return AccountService.load_user(account_id)
|
||||
def load_logged_in_account(*, account_id: str, session: scoped_session | Session):
|
||||
return AccountService.load_user(account_id, session)
|
||||
|
||||
@classmethod
|
||||
def send_reset_password_email(
|
||||
@ -981,19 +1000,18 @@ class AccountService:
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def get_account_by_email_with_case_fallback(email: str) -> Account | None:
|
||||
def get_account_by_email_with_case_fallback(session: Session | scoped_session, email: str) -> Account | None:
|
||||
"""
|
||||
Retrieve an account by email and fall back to the lowercase email if the original lookup fails.
|
||||
|
||||
This keeps backward compatibility for older records that stored uppercase emails while the
|
||||
rest of the system gradually normalizes new inputs.
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
account = session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
|
||||
if account or email == email.lower():
|
||||
return account
|
||||
|
||||
return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
|
||||
return session.execute(select(Account).where(Account.email == email.lower())).scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
|
||||
@ -1004,7 +1022,7 @@ class AccountService:
|
||||
TokenManager.revoke_token(token, "email_code_login")
|
||||
|
||||
@classmethod
|
||||
def get_user_through_email(cls, email: str):
|
||||
def get_user_through_email(cls, email: str, *, session: scoped_session | Session):
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(email):
|
||||
raise AccountRegisterError(
|
||||
description=(
|
||||
@ -1013,7 +1031,7 @@ class AccountService:
|
||||
)
|
||||
)
|
||||
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
account = session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if not account:
|
||||
return None
|
||||
|
||||
@ -1212,13 +1230,19 @@ class AccountService:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def check_email_unique(email: str) -> bool:
|
||||
return db.session.scalar(select(Account).where(Account.email == email).limit(1)) is None
|
||||
def check_email_unique(email: str, *, session: scoped_session | Session) -> bool:
|
||||
return session.scalar(select(Account).where(Account.email == email).limit(1)) is None
|
||||
|
||||
|
||||
class TenantService:
|
||||
@staticmethod
|
||||
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
|
||||
def create_tenant(
|
||||
name: str,
|
||||
is_setup: bool | None = False,
|
||||
is_from_dashboard: bool | None = False,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> Tenant:
|
||||
"""Create tenant"""
|
||||
if (
|
||||
not FeatureService.get_system_features().is_allow_create_workspace
|
||||
@ -1230,8 +1254,8 @@ class TenantService:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
tenant = Tenant(name=name)
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
session.add(tenant)
|
||||
session.commit()
|
||||
|
||||
for category in TenantPluginAutoUpgradeStrategy.PluginCategory:
|
||||
plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy(
|
||||
@ -1243,11 +1267,11 @@ class TenantService:
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
db.session.add(plugin_upgrade_strategy)
|
||||
db.session.commit()
|
||||
session.add(plugin_upgrade_strategy)
|
||||
session.commit()
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
@ -1256,9 +1280,11 @@ class TenantService:
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
|
||||
def create_owner_tenant_if_not_exist(
|
||||
account: Account, name: str | None = None, is_setup: bool | None = False, *, session: scoped_session | Session
|
||||
):
|
||||
"""Check if user have a workspace or not"""
|
||||
available_ta = db.session.scalar(
|
||||
available_ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id)
|
||||
.order_by(TenantAccountJoin.id.asc())
|
||||
@ -1277,10 +1303,10 @@ class TenantService:
|
||||
raise WorkspacesLimitExceededError()
|
||||
|
||||
if name:
|
||||
tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
|
||||
tenant = TenantService.create_tenant(name=name, is_setup=is_setup, session=session)
|
||||
else:
|
||||
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
|
||||
tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup, session=session)
|
||||
TenantService.create_tenant_member(tenant, account, session, role="owner")
|
||||
if dify_config.RBAC_ENABLED:
|
||||
owner_role_id = AccountService._resolve_legacy_role_id(str(tenant.id), account.id, TenantAccountRole.OWNER)
|
||||
RBACService.MemberRoles.replace(
|
||||
@ -1290,16 +1316,16 @@ class TenantService:
|
||||
role_ids=[owner_role_id],
|
||||
)
|
||||
account.current_tenant = tenant
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
@staticmethod
|
||||
def create_tenant_member(
|
||||
tenant: Tenant, account: Account, session: scoped_session, role: str = "normal"
|
||||
tenant: Tenant, account: Account, session: scoped_session | Session, role: str = "normal"
|
||||
) -> TenantAccountJoin:
|
||||
"""Create tenant member"""
|
||||
if role == TenantAccountRole.OWNER:
|
||||
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
|
||||
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER], session=session):
|
||||
logger.error("Tenant %s has already an owner.", tenant.id)
|
||||
raise Exception("Tenant already has an owner.")
|
||||
|
||||
@ -1320,10 +1346,10 @@ class TenantService:
|
||||
return ta
|
||||
|
||||
@staticmethod
|
||||
def get_join_tenants(account: Account) -> list[Tenant]:
|
||||
def get_join_tenants(account: Account, *, session: scoped_session | Session) -> list[Tenant]:
|
||||
"""Get account join tenants"""
|
||||
return list(
|
||||
db.session.scalars(
|
||||
session.scalars(
|
||||
select(Tenant)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
|
||||
@ -1342,7 +1368,7 @@ class TenantService:
|
||||
membership + pick the default workspace.
|
||||
|
||||
``session`` is injected by the caller so this service stays free
|
||||
of the Flask-scoped ``db.session`` import.
|
||||
of a Flask-scoped session import.
|
||||
|
||||
No tenant-status filter: parity with the legacy controller query
|
||||
(the openapi identity endpoint listed all joined tenants).
|
||||
@ -1415,7 +1441,7 @@ class TenantService:
|
||||
bearers (no account) collapse to the non-member path. Mirrors the
|
||||
session-injection style of :meth:`account_belongs_to_tenant` rather
|
||||
than :meth:`get_user_role`, which loads full ``Account``/``Tenant``
|
||||
objects against the Flask-scoped ``db.session``.
|
||||
objects against the Flask-scoped session.
|
||||
"""
|
||||
if not account_id:
|
||||
return None
|
||||
@ -1481,13 +1507,13 @@ class TenantService:
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def get_current_tenant_by_account(account: Account):
|
||||
def get_current_tenant_by_account(account: Account, *, session: scoped_session | Session):
|
||||
"""Get tenant by account and add the role"""
|
||||
tenant = account.current_tenant
|
||||
if not tenant:
|
||||
raise TenantNotFoundError("Tenant not found.")
|
||||
|
||||
ta = db.session.scalar(
|
||||
ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
@ -1499,14 +1525,14 @@ class TenantService:
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def switch_tenant(account: Account, tenant_id: str | None = None):
|
||||
def switch_tenant(account: Account, tenant_id: str | None = None, *, session: scoped_session | Session):
|
||||
"""Switch the current workspace for the account"""
|
||||
|
||||
# Ensure tenant_id is provided
|
||||
if tenant_id is None:
|
||||
raise ValueError("Tenant ID must be provided.")
|
||||
|
||||
tenant_account_join = db.session.scalar(
|
||||
tenant_account_join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(
|
||||
@ -1520,7 +1546,7 @@ class TenantService:
|
||||
if not tenant_account_join:
|
||||
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
||||
else:
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
update(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id)
|
||||
.values(current=False)
|
||||
@ -1529,10 +1555,10 @@ class TenantService:
|
||||
tenant_account_join.last_opened_at = naive_utc_now()
|
||||
# Set the current tenant for the account
|
||||
account.set_tenant_id(tenant_account_join.tenant_id)
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_members(tenant: Tenant) -> list[Account]:
|
||||
def get_tenant_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]:
|
||||
"""Get tenant members"""
|
||||
stmt = (
|
||||
select(Account, TenantAccountJoin.role)
|
||||
@ -1544,14 +1570,14 @@ class TenantService:
|
||||
# Initialize an empty list to store the updated accounts
|
||||
updated_accounts = []
|
||||
|
||||
for account, role in db.session.execute(stmt):
|
||||
for account, role in session.execute(stmt):
|
||||
account.role = role
|
||||
updated_accounts.append(account)
|
||||
|
||||
return updated_accounts
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_operator_members(tenant: Tenant) -> list[Account]:
|
||||
def get_dataset_operator_members(tenant: Tenant, *, session: scoped_session | Session) -> list[Account]:
|
||||
"""Get dataset admin members"""
|
||||
stmt = (
|
||||
select(Account, TenantAccountJoin.role)
|
||||
@ -1564,20 +1590,20 @@ class TenantService:
|
||||
# Initialize an empty list to store the updated accounts
|
||||
updated_accounts = []
|
||||
|
||||
for account, role in db.session.execute(stmt):
|
||||
for account, role in session.execute(stmt):
|
||||
account.role = role
|
||||
updated_accounts.append(account)
|
||||
|
||||
return updated_accounts
|
||||
|
||||
@staticmethod
|
||||
def has_roles(tenant: Tenant, roles: list[TenantAccountRole]) -> bool:
|
||||
def has_roles(tenant: Tenant, roles: list[TenantAccountRole], *, session: scoped_session | Session) -> bool:
|
||||
"""Check if user has any of the given roles for a tenant"""
|
||||
if not all(isinstance(role, TenantAccountRole) for role in roles):
|
||||
raise ValueError("all roles must be TenantAccountRole")
|
||||
|
||||
return (
|
||||
db.session.scalar(
|
||||
session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
@ -1589,9 +1615,11 @@ class TenantService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
|
||||
def get_user_role(
|
||||
account: Account, tenant: Tenant, *, session: scoped_session | Session
|
||||
) -> TenantAccountRole | None:
|
||||
"""Get the role of the current account for a given tenant"""
|
||||
join = db.session.scalar(
|
||||
join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
@ -1599,12 +1627,14 @@ class TenantService:
|
||||
return TenantAccountRole(join.role) if join else None
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_count() -> int:
|
||||
def get_tenant_count(*, session: scoped_session | Session) -> int:
|
||||
"""Get tenant count"""
|
||||
return cast(int, db.session.scalar(select(func.count(Tenant.id))))
|
||||
return cast(int, session.scalar(select(func.count(Tenant.id))))
|
||||
|
||||
@staticmethod
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
|
||||
def check_member_permission(
|
||||
tenant: Tenant, operator: Account, member: Account | None, action: str, *, session: scoped_session | Session
|
||||
):
|
||||
"""Check member permission"""
|
||||
if action not in {"add", "remove", "update"}:
|
||||
raise InvalidActionError("Invalid action.")
|
||||
@ -1638,7 +1668,7 @@ class TenantService:
|
||||
"update": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
}
|
||||
|
||||
ta_operator = db.session.scalar(
|
||||
ta_operator = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == operator.id)
|
||||
.limit(1)
|
||||
@ -1648,7 +1678,7 @@ class TenantService:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
if action == "remove" and ta_operator.role == TenantAccountRole.ADMIN and member:
|
||||
ta_member = db.session.scalar(
|
||||
ta_member = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
|
||||
.limit(1)
|
||||
@ -1657,7 +1687,9 @@ class TenantService:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
|
||||
def remove_member_from_tenant(
|
||||
tenant: Tenant, account: Account, operator: Account, *, session: scoped_session | Session
|
||||
):
|
||||
"""Remove member from tenant.
|
||||
|
||||
Apps and datasets maintained by the removed member are reassigned to
|
||||
@ -1669,9 +1701,9 @@ class TenantService:
|
||||
if operator.id == account.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
TenantService.check_member_permission(tenant, operator, account, "remove")
|
||||
TenantService.check_member_permission(tenant, operator, account, "remove", session=session)
|
||||
|
||||
ta = db.session.scalar(
|
||||
ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
@ -1688,7 +1720,7 @@ class TenantService:
|
||||
if dify_config.RBAC_ENABLED:
|
||||
owner_id = AccountService.get_rbac_workspace_owner_account_id(str(tenant.id), str(operator.id))
|
||||
else:
|
||||
owner_id = db.session.scalar(
|
||||
owner_id = session.scalar(
|
||||
select(TenantAccountJoin.account_id)
|
||||
.where(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
@ -1699,7 +1731,7 @@ class TenantService:
|
||||
if owner_id is None:
|
||||
raise ValueError(f"Workspace owner not found for tenant {tenant.id}.")
|
||||
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
update(App)
|
||||
.where(
|
||||
App.tenant_id == tenant.id,
|
||||
@ -1707,7 +1739,7 @@ class TenantService:
|
||||
)
|
||||
.values(maintainer=owner_id)
|
||||
)
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
update(Dataset)
|
||||
.where(
|
||||
Dataset.tenant_id == tenant.id,
|
||||
@ -1715,23 +1747,23 @@ class TenantService:
|
||||
)
|
||||
.values(maintainer=owner_id)
|
||||
)
|
||||
db.session.delete(ta)
|
||||
session.delete(ta)
|
||||
|
||||
# Clean up orphaned pending accounts (invited but never activated)
|
||||
should_delete_account = False
|
||||
if account.status == AccountStatus.PENDING:
|
||||
# autoflush flushes ta deletion before this query, so 0 means no remaining joins
|
||||
remaining_joins = (
|
||||
db.session.scalar(
|
||||
session.scalar(
|
||||
select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.account_id == account_id)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if remaining_joins == 0:
|
||||
db.session.delete(account)
|
||||
session.delete(account)
|
||||
should_delete_account = True
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
if should_delete_account:
|
||||
logger.info(
|
||||
@ -1756,13 +1788,18 @@ class TenantService:
|
||||
account_id,
|
||||
)
|
||||
|
||||
if dify_config.RBAC_ENABLED:
|
||||
RBACService.MemberRoles.delete_rbac_bindings(tenant_id=tenant.id, account_id=account_id)
|
||||
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
|
||||
def update_member_role(
|
||||
tenant: Tenant, member: Account, new_role: str, operator: Account, *, session: scoped_session | Session
|
||||
):
|
||||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
TenantService.check_member_permission(tenant, operator, member, "update", session=session)
|
||||
new_tenant_role = TenantAccountRole(new_role)
|
||||
|
||||
target_member_join = db.session.scalar(
|
||||
target_member_join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == member.id)
|
||||
.limit(1)
|
||||
@ -1771,7 +1808,7 @@ class TenantService:
|
||||
if not target_member_join:
|
||||
raise MemberNotInTenantError("Member not in tenant.")
|
||||
|
||||
operator_role = TenantService.get_user_role(operator, tenant)
|
||||
operator_role = TenantService.get_user_role(operator, tenant, session=session)
|
||||
target_role = TenantAccountRole(target_member_join.role)
|
||||
if operator_role == TenantAccountRole.ADMIN and (TenantAccountRole.OWNER in {target_role, new_tenant_role}):
|
||||
raise NoPermissionError("No permission to update member.")
|
||||
@ -1781,7 +1818,7 @@ class TenantService:
|
||||
|
||||
if new_role == "owner":
|
||||
# Find the current owner and change their role to 'admin'
|
||||
current_owner_join = db.session.scalar(
|
||||
current_owner_join = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner")
|
||||
.limit(1)
|
||||
@ -1817,7 +1854,7 @@ class TenantService:
|
||||
)
|
||||
else:
|
||||
target_member_join.role = new_tenant_role
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_custom_config(tenant_id: str):
|
||||
@ -1826,13 +1863,13 @@ class TenantService:
|
||||
return tenant.custom_config_dict
|
||||
|
||||
@staticmethod
|
||||
def is_owner(account: Account, tenant: Tenant) -> bool:
|
||||
return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER
|
||||
def is_owner(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool:
|
||||
return TenantService.get_user_role(account, tenant, session=session) == TenantAccountRole.OWNER
|
||||
|
||||
@staticmethod
|
||||
def is_member(account: Account, tenant: Tenant) -> bool:
|
||||
def is_member(account: Account, tenant: Tenant, *, session: scoped_session | Session) -> bool:
|
||||
"""Check if the account is a member of the tenant"""
|
||||
return TenantService.get_user_role(account, tenant) is not None
|
||||
return TenantService.get_user_role(account, tenant, session=session) is not None
|
||||
|
||||
|
||||
class RegisterService:
|
||||
@ -1841,7 +1878,16 @@ class RegisterService:
|
||||
return f"member_invite:token:{token}"
|
||||
|
||||
@classmethod
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str, language: str | None):
|
||||
def setup(
|
||||
cls,
|
||||
email: str,
|
||||
name: str,
|
||||
password: str,
|
||||
ip_address: str,
|
||||
language: str | None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
):
|
||||
"""
|
||||
Setup dify
|
||||
|
||||
@ -1858,22 +1904,23 @@ class RegisterService:
|
||||
interface_language=get_valid_language(language),
|
||||
password=password,
|
||||
is_setup=True,
|
||||
session=session,
|
||||
)
|
||||
|
||||
account.last_login_ip = ip_address
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True, session=session)
|
||||
|
||||
dify_setup = DifySetup(version=dify_config.project.version)
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
session.add(dify_setup)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
db.session.execute(delete(DifySetup))
|
||||
db.session.execute(delete(TenantAccountJoin))
|
||||
db.session.execute(delete(Account))
|
||||
db.session.execute(delete(Tenant))
|
||||
db.session.commit()
|
||||
session.execute(delete(DifySetup))
|
||||
session.execute(delete(TenantAccountJoin))
|
||||
session.execute(delete(Account))
|
||||
session.execute(delete(Tenant))
|
||||
session.commit()
|
||||
|
||||
logger.exception("Setup account failed, email: %s, name: %s", email, name)
|
||||
raise ValueError(f"Setup failed: {e}")
|
||||
@ -1891,9 +1938,11 @@ class RegisterService:
|
||||
is_setup: bool | None = False,
|
||||
create_workspace_required: bool | None = True,
|
||||
timezone: str | None = None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> Account:
|
||||
"""Register account"""
|
||||
db.session.begin_nested()
|
||||
session.begin_nested()
|
||||
try:
|
||||
interface_language = get_valid_language(language)
|
||||
account = AccountService.create_account(
|
||||
@ -1903,12 +1952,13 @@ class RegisterService:
|
||||
password=password,
|
||||
is_setup=is_setup,
|
||||
timezone=timezone,
|
||||
session=session,
|
||||
)
|
||||
account.status = status or AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
if open_id is not None and provider is not None:
|
||||
AccountService.link_account_integrate(provider, open_id, account)
|
||||
AccountService.link_account_integrate(provider, open_id, account, session=session)
|
||||
|
||||
if (
|
||||
FeatureService.get_system_features().is_allow_create_workspace
|
||||
@ -1916,27 +1966,27 @@ class RegisterService:
|
||||
and FeatureService.get_system_features().license.workspaces.is_available()
|
||||
):
|
||||
try:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role="owner")
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace", session=session)
|
||||
TenantService.create_tenant_member(tenant, account, session, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
except Exception:
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
raise
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
_try_join_enterprise_default_workspace(str(account.id))
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
db.session.rollback()
|
||||
session.rollback()
|
||||
logger.exception("Register failed")
|
||||
raise AccountRegisterError("Workspace is not allowed to create.")
|
||||
except AccountRegisterError as are:
|
||||
db.session.rollback()
|
||||
session.rollback()
|
||||
logger.exception("Register failed")
|
||||
raise are
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
session.rollback()
|
||||
logger.exception("Register failed")
|
||||
raise AccountRegisterError(f"Registration failed: {e}") from e
|
||||
|
||||
@ -1944,7 +1994,14 @@ class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def invite_new_member(
|
||||
cls, tenant: Tenant, email: str, language: str | None, role: str = "normal", inviter: Account | None = None
|
||||
cls,
|
||||
tenant: Tenant,
|
||||
email: str,
|
||||
language: str | None,
|
||||
role: str = "normal",
|
||||
inviter: Account | None = None,
|
||||
*,
|
||||
session: scoped_session | Session,
|
||||
) -> str:
|
||||
if not inviter:
|
||||
raise ValueError("Inviter is required")
|
||||
@ -1958,11 +2015,11 @@ class RegisterService:
|
||||
|
||||
check_workspace_member_invite_permission(tenant.id)
|
||||
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
|
||||
requires_setup = False
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add", session=session)
|
||||
name = normalized_email.split("@")[0]
|
||||
|
||||
account = cls.register(
|
||||
@ -1971,13 +2028,14 @@ class RegisterService:
|
||||
language=language,
|
||||
status=AccountStatus.PENDING,
|
||||
is_setup=True,
|
||||
session=session,
|
||||
)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role)
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
TenantService.create_tenant_member(tenant, account, session, tenant_join_role)
|
||||
TenantService.switch_tenant(account, tenant.id, session=session)
|
||||
requires_setup = True
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = db.session.scalar(
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add", session=session)
|
||||
ta = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
@ -1985,7 +2043,7 @@ class RegisterService:
|
||||
requires_setup = account.status == AccountStatus.PENDING
|
||||
|
||||
if not ta and (account.status == AccountStatus.PENDING or dify_config.RBAC_ENABLED):
|
||||
TenantService.create_tenant_member(tenant, account, db.session, tenant_join_role)
|
||||
TenantService.create_tenant_member(tenant, account, session, tenant_join_role)
|
||||
|
||||
# Support resend invitation email when the account is pending status
|
||||
if account.status != AccountStatus.PENDING:
|
||||
@ -2054,20 +2112,20 @@ class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def get_invitation_if_token_valid(
|
||||
cls, workspace_id: str | None, email: str | None, token: str
|
||||
cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session
|
||||
) -> InvitationDetailDict | None:
|
||||
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
tenant = db.session.scalar(
|
||||
tenant = session.scalar(
|
||||
select(Tenant).where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
account = db.session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
|
||||
account = session.scalar(select(Account).where(Account.email == invitation_data["email"]).limit(1))
|
||||
if not account:
|
||||
return None
|
||||
|
||||
@ -2107,13 +2165,13 @@ class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def get_invitation_with_case_fallback(
|
||||
cls, workspace_id: str | None, email: str | None, token: str
|
||||
cls, workspace_id: str | None, email: str | None, token: str, *, session: scoped_session | Session
|
||||
) -> InvitationDetailDict | None:
|
||||
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token)
|
||||
invitation = cls.get_invitation_if_token_valid(workspace_id, email, token, session=session)
|
||||
if invitation or not email or email == email.lower():
|
||||
return invitation
|
||||
normalized_email = email.lower()
|
||||
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token)
|
||||
return cls.get_invitation_if_token_valid(workspace_id, normalized_email, token, session=session)
|
||||
|
||||
|
||||
def _generate_refresh_token(length: int = 64):
|
||||
|
||||
@ -4,15 +4,18 @@ from typing import Any
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
from models import Account
|
||||
from models.agent import (
|
||||
Agent,
|
||||
AgentConfigRevision,
|
||||
AgentConfigRevisionOperation,
|
||||
AgentConfigSnapshot,
|
||||
AgentDriveFile,
|
||||
AgentIconType,
|
||||
AgentKind,
|
||||
AgentScope,
|
||||
AgentSource,
|
||||
@ -20,9 +23,7 @@ from models.agent import (
|
||||
WorkflowAgentBindingType,
|
||||
WorkflowAgentNodeBinding,
|
||||
)
|
||||
from models.agent_config_entities import (
|
||||
DeclaredOutputConfig,
|
||||
)
|
||||
from models.agent_config_entities import DeclaredOutputConfig
|
||||
from models.agent_config_entities import (
|
||||
effective_declared_outputs as _effective_declared_outputs,
|
||||
)
|
||||
@ -32,8 +33,12 @@ from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.agent.errors import (
|
||||
AgentNameConflictError,
|
||||
AgentNotFoundError,
|
||||
AgentVersionConflictError,
|
||||
AgentVersionNotFoundError,
|
||||
InvalidComposerConfigError,
|
||||
)
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.app_service import AppService, CreateAppParams
|
||||
from services.entities.agent_entities import (
|
||||
AgentSoulConfig,
|
||||
ComposerCandidatesResponse,
|
||||
@ -172,6 +177,86 @@ class AgentComposerService:
|
||||
)
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def copy_workflow_composer_from_roster(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
node_id: str,
|
||||
account_id: str,
|
||||
source_agent_id: str,
|
||||
source_snapshot_id: str | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
workflow = cls._get_draft_workflow(tenant_id=tenant_id, app_id=app_id)
|
||||
binding = cls._require_binding(
|
||||
cls._get_workflow_binding(tenant_id=tenant_id, workflow_id=workflow.id, node_id=node_id)
|
||||
)
|
||||
|
||||
if binding.binding_type == WorkflowAgentBindingType.INLINE_AGENT and idempotency_key:
|
||||
agent = cls._get_agent_if_present(tenant_id=tenant_id, agent_id=binding.agent_id)
|
||||
version = cls._get_version_if_present(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent.id if agent else None,
|
||||
version_id=binding.current_snapshot_id,
|
||||
)
|
||||
return cls._serialize_workflow_state(binding=binding, agent=agent, version=version)
|
||||
|
||||
if binding.binding_type != WorkflowAgentBindingType.ROSTER_AGENT:
|
||||
raise InvalidComposerConfigError("Workflow agent node must be bound to a roster agent.")
|
||||
if binding.agent_id != source_agent_id:
|
||||
raise InvalidComposerConfigError("Source agent does not match the current workflow node binding.")
|
||||
|
||||
source_agent = cls._require_agent(tenant_id=tenant_id, agent_id=source_agent_id)
|
||||
if source_agent.scope != AgentScope.ROSTER or source_agent.status != AgentStatus.ACTIVE:
|
||||
raise InvalidComposerConfigError("Source agent must be an active roster agent.")
|
||||
source_version = cls._require_version(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=source_agent.id,
|
||||
version_id=source_agent.active_config_snapshot_id,
|
||||
)
|
||||
if source_snapshot_id and source_snapshot_id != source_version.id:
|
||||
raise AgentVersionConflictError()
|
||||
|
||||
agent_soul = AgentSoulConfig.model_validate(source_version.config_snapshot_dict)
|
||||
inline_agent = cls._create_workflow_only_agent(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow.id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
agent_soul=agent_soul,
|
||||
name=source_agent.name,
|
||||
description=source_agent.description,
|
||||
role=source_agent.role,
|
||||
icon_type=source_agent.icon_type,
|
||||
icon=source_agent.icon,
|
||||
icon_background=source_agent.icon_background,
|
||||
)
|
||||
cls._copy_agent_drive_rows(
|
||||
tenant_id=tenant_id,
|
||||
source_agent_id=source_agent.id,
|
||||
target_agent_id=inline_agent.id,
|
||||
account_id=account_id,
|
||||
agent_soul=agent_soul,
|
||||
node_job=WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict),
|
||||
)
|
||||
|
||||
binding.binding_type = WorkflowAgentBindingType.INLINE_AGENT
|
||||
binding.agent_id = inline_agent.id
|
||||
binding.current_snapshot_id = inline_agent.active_config_snapshot_id
|
||||
binding.updated_by = account_id
|
||||
db.session.flush()
|
||||
db.session.commit()
|
||||
|
||||
version = cls._require_version(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=inline_agent.id,
|
||||
version_id=inline_agent.active_config_snapshot_id,
|
||||
)
|
||||
return cls._serialize_workflow_state(binding=binding, agent=inline_agent, version=version)
|
||||
|
||||
@classmethod
|
||||
def load_agent_app_composer(cls, *, tenant_id: str, app_id: str) -> dict[str, Any]:
|
||||
agent = db.session.scalar(
|
||||
@ -849,6 +934,11 @@ class AgentComposerService:
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
name=agent_name,
|
||||
description=payload.description or "",
|
||||
role=payload.role or "",
|
||||
icon_type=payload.icon_type,
|
||||
icon=payload.icon,
|
||||
icon_background=payload.icon_background,
|
||||
agent_soul=payload.agent_soul,
|
||||
operation=AgentConfigRevisionOperation.SAVE_NEW_AGENT,
|
||||
version_note=payload.version_note,
|
||||
@ -894,10 +984,25 @@ class AgentComposerService:
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
name=agent_name,
|
||||
description=payload.description if payload.description is not None else source_agent.description,
|
||||
role=payload.role if payload.role is not None else source_agent.role,
|
||||
icon_type=payload.icon_type if payload.icon_type is not None else source_agent.icon_type,
|
||||
icon=payload.icon if payload.icon is not None else source_agent.icon,
|
||||
icon_background=payload.icon_background
|
||||
if payload.icon_background is not None
|
||||
else source_agent.icon_background,
|
||||
agent_soul=agent_soul,
|
||||
operation=AgentConfigRevisionOperation.SAVE_TO_ROSTER,
|
||||
version_note=payload.version_note,
|
||||
)
|
||||
cls._copy_agent_drive_rows(
|
||||
tenant_id=tenant_id,
|
||||
source_agent_id=source_agent.id,
|
||||
target_agent_id=roster_agent.id,
|
||||
account_id=account_id,
|
||||
agent_soul=agent_soul,
|
||||
node_job=payload.node_job or WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict),
|
||||
)
|
||||
binding.binding_type = WorkflowAgentBindingType.ROSTER_AGENT
|
||||
binding.agent_id = roster_agent.id
|
||||
binding.current_snapshot_id = roster_agent.active_config_snapshot_id
|
||||
@ -916,11 +1021,21 @@ class AgentComposerService:
|
||||
node_id: str,
|
||||
account_id: str,
|
||||
agent_soul: AgentSoulConfig,
|
||||
name: str | None = None,
|
||||
description: str = "",
|
||||
role: str = "",
|
||||
icon_type: Any | None = None,
|
||||
icon: str | None = None,
|
||||
icon_background: str | None = None,
|
||||
) -> Agent:
|
||||
agent = Agent(
|
||||
tenant_id=tenant_id,
|
||||
name=f"Workflow Agent {node_id}",
|
||||
description="",
|
||||
name=name or f"Workflow Agent {node_id}",
|
||||
description=description,
|
||||
role=role,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
agent_kind=AgentKind.DIFY_AGENT,
|
||||
scope=AgentScope.WORKFLOW_ONLY,
|
||||
source=AgentSource.WORKFLOW,
|
||||
@ -945,6 +1060,98 @@ class AgentComposerService:
|
||||
agent.active_config_has_model = agent_soul_has_model(agent_soul)
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
def _copy_agent_drive_rows(
|
||||
cls,
|
||||
*,
|
||||
tenant_id: str,
|
||||
source_agent_id: str,
|
||||
target_agent_id: str,
|
||||
account_id: str,
|
||||
agent_soul: AgentSoulConfig,
|
||||
node_job: WorkflowNodeJobConfig | None = None,
|
||||
) -> None:
|
||||
exact_keys, prefixes = cls._drive_copy_scopes_from_agent_configs(agent_soul=agent_soul, node_job=node_job)
|
||||
predicates: list[ColumnElement[bool]] = []
|
||||
if exact_keys:
|
||||
predicates.append(AgentDriveFile.key.in_(sorted(exact_keys)))
|
||||
predicates.extend(AgentDriveFile.key.startswith(prefix) for prefix in sorted(prefixes))
|
||||
if not predicates:
|
||||
return
|
||||
|
||||
source_rows = list(
|
||||
db.session.scalars(
|
||||
select(AgentDriveFile).where(
|
||||
AgentDriveFile.tenant_id == tenant_id,
|
||||
AgentDriveFile.agent_id == source_agent_id,
|
||||
or_(*predicates),
|
||||
)
|
||||
).all()
|
||||
)
|
||||
if not source_rows:
|
||||
return
|
||||
|
||||
existing_target_keys = set(
|
||||
db.session.scalars(
|
||||
select(AgentDriveFile.key).where(
|
||||
AgentDriveFile.tenant_id == tenant_id,
|
||||
AgentDriveFile.agent_id == target_agent_id,
|
||||
AgentDriveFile.key.in_([row.key for row in source_rows]),
|
||||
)
|
||||
).all()
|
||||
)
|
||||
for row in source_rows:
|
||||
if row.key in existing_target_keys:
|
||||
continue
|
||||
db.session.add(
|
||||
AgentDriveFile(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=target_agent_id,
|
||||
key=row.key,
|
||||
file_kind=row.file_kind,
|
||||
file_id=row.file_id,
|
||||
value_owned_by_drive=row.value_owned_by_drive,
|
||||
is_skill=row.is_skill,
|
||||
skill_metadata=row.skill_metadata,
|
||||
size=row.size,
|
||||
hash=row.hash,
|
||||
mime_type=row.mime_type,
|
||||
created_by=account_id,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _drive_copy_scopes_from_agent_configs(
|
||||
*, agent_soul: AgentSoulConfig, node_job: WorkflowNodeJobConfig | None = None
|
||||
) -> tuple[set[str], set[str]]:
|
||||
from services.agent.prompt_mentions import MentionKind, parse_prompt_mentions
|
||||
from services.agent_drive_service import decode_drive_mention_ref
|
||||
|
||||
exact_keys: set[str] = set()
|
||||
prefixes: set[str] = set()
|
||||
|
||||
for mention in parse_prompt_mentions(agent_soul.prompt.system_prompt):
|
||||
if mention.kind not in {MentionKind.SKILL, MentionKind.FILE}:
|
||||
continue
|
||||
drive_key = decode_drive_mention_ref(mention.ref_id)
|
||||
if not drive_key:
|
||||
continue
|
||||
if mention.kind == MentionKind.SKILL and "/" in drive_key:
|
||||
prefixes.add(f"{drive_key.rsplit('/', 1)[0]}/")
|
||||
else:
|
||||
exact_keys.add(drive_key)
|
||||
|
||||
if node_job is not None:
|
||||
for file_ref in node_job.metadata.file_refs or []:
|
||||
if file_ref.drive_key:
|
||||
exact_keys.add(file_ref.drive_key)
|
||||
for output in node_job.declared_outputs:
|
||||
benchmark_ref = output.check.benchmark_file_ref if output.check and output.check.enabled else None
|
||||
if benchmark_ref and benchmark_ref.drive_key:
|
||||
exact_keys.add(benchmark_ref.drive_key)
|
||||
|
||||
return exact_keys, prefixes
|
||||
|
||||
@classmethod
|
||||
def _create_roster_agent_for_composer(
|
||||
cls,
|
||||
@ -955,27 +1162,42 @@ class AgentComposerService:
|
||||
agent_soul: AgentSoulConfig,
|
||||
operation: AgentConfigRevisionOperation,
|
||||
version_note: str | None,
|
||||
description: str = "",
|
||||
role: str = "",
|
||||
icon_type: AgentIconType | None = None,
|
||||
icon: str | None = None,
|
||||
icon_background: str | None = None,
|
||||
) -> Agent:
|
||||
agent = Agent(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description="",
|
||||
agent_kind=AgentKind.DIFY_AGENT,
|
||||
scope=AgentScope.ROSTER,
|
||||
source=AgentSource.WORKFLOW,
|
||||
status=AgentStatus.ACTIVE,
|
||||
created_by=account_id,
|
||||
updated_by=account_id,
|
||||
)
|
||||
db.session.add(agent)
|
||||
account = cls._require_account(account_id=account_id)
|
||||
try:
|
||||
db.session.flush()
|
||||
app = AppService().create_app(
|
||||
tenant_id,
|
||||
CreateAppParams(
|
||||
name=name,
|
||||
description=description,
|
||||
mode="agent",
|
||||
agent_role=role,
|
||||
icon_type=icon_type.value if isinstance(icon_type, AgentIconType) else icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
),
|
||||
account,
|
||||
)
|
||||
except IntegrityError as exc:
|
||||
db.session.rollback()
|
||||
raise AgentNameConflictError() from exc
|
||||
version = cls._create_config_version(
|
||||
|
||||
agent = AgentRosterService(db.session).get_app_backing_agent(tenant_id=tenant_id, app_id=app.id)
|
||||
if agent is None:
|
||||
raise AgentNotFoundError()
|
||||
|
||||
current_snapshot = cls._require_version(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent.id,
|
||||
version_id=agent.active_config_snapshot_id,
|
||||
)
|
||||
version = cls._update_current_version(
|
||||
current_snapshot=current_snapshot,
|
||||
account_id=account_id,
|
||||
agent_soul=agent_soul,
|
||||
operation=operation,
|
||||
@ -983,6 +1205,7 @@ class AgentComposerService:
|
||||
)
|
||||
agent.active_config_snapshot_id = version.id
|
||||
agent.active_config_has_model = agent_soul_has_model(agent_soul)
|
||||
agent.updated_by = account_id
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
@ -1111,6 +1334,13 @@ class AgentComposerService:
|
||||
raise AgentNotFoundError()
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
def _require_account(cls, *, account_id: str) -> Account:
|
||||
account = db.session.get(Account, account_id)
|
||||
if not account:
|
||||
raise ValueError("Account not found")
|
||||
return account
|
||||
|
||||
@classmethod
|
||||
def _get_agent_if_present(cls, *, tenant_id: str, agent_id: str | None) -> Agent | None:
|
||||
if not agent_id:
|
||||
|
||||
@ -17,6 +17,10 @@ class AgentArchivedError(Conflict):
|
||||
description = "Archived agent cannot be modified."
|
||||
|
||||
|
||||
class AgentVersionConflictError(Conflict):
|
||||
description = "Agent config version changed. Please reload and try again."
|
||||
|
||||
|
||||
class AgentSoulLockedError(BadRequest):
|
||||
description = "Agent Soul is locked for this workflow node."
|
||||
|
||||
|
||||
@ -837,6 +837,7 @@ class AgentRosterService:
|
||||
if agent.source == AgentSource.AGENT_APP:
|
||||
return {
|
||||
AgentConfigRevisionOperation.SAVE_NEW_VERSION,
|
||||
AgentConfigRevisionOperation.SAVE_TO_ROSTER,
|
||||
AgentConfigRevisionOperation.RESTORE_VERSION,
|
||||
}
|
||||
return {
|
||||
|
||||
@ -69,7 +69,12 @@ class AgentAppFeatureConfigService:
|
||||
|
||||
@classmethod
|
||||
def update_features(
|
||||
cls, *, app_model: App, account: Account, config: dict[str, Any], session: scoped_session
|
||||
cls,
|
||||
*,
|
||||
app_model: App,
|
||||
account: Account,
|
||||
config: dict[str, Any],
|
||||
session: scoped_session,
|
||||
) -> AppModelConfig:
|
||||
"""Persist the presentation features as a new app_model_config version.
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import TypedDict
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import delete, or_, select, update
|
||||
from sqlalchemy.orm import scoped_session
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -300,17 +301,19 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str):
|
||||
def update_app_annotation_directly(
|
||||
cls, args: UpdateAnnotationArgs, app_id: str, annotation_id: str, session: scoped_session
|
||||
):
|
||||
# get app info
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
app = session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = db.session.get(MessageAnnotation, annotation_id)
|
||||
annotation = session.get(MessageAnnotation, annotation_id)
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
@ -326,9 +329,9 @@ class AppAnnotationService:
|
||||
annotation.content = answer
|
||||
annotation.question = question
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
app_annotation_setting = db.session.scalar(
|
||||
app_annotation_setting = session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
|
||||
)
|
||||
|
||||
@ -344,33 +347,33 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str, session: scoped_session):
|
||||
# get app info
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app = db.session.scalar(
|
||||
app = session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = db.session.get(MessageAnnotation, annotation_id)
|
||||
annotation = session.get(MessageAnnotation, annotation_id)
|
||||
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
db.session.delete(annotation)
|
||||
session.delete(annotation)
|
||||
|
||||
annotation_hit_histories = db.session.scalars(
|
||||
annotation_hit_histories = session.scalars(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
).all()
|
||||
if annotation_hit_histories:
|
||||
for annotation_hit_history in annotation_hit_histories:
|
||||
db.session.delete(annotation_hit_history)
|
||||
session.delete(annotation_hit_history)
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
# if annotation reply is enabled , delete annotation index
|
||||
app_annotation_setting = db.session.scalar(
|
||||
app_annotation_setting = session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
|
||||
)
|
||||
|
||||
|
||||
@ -5,11 +5,11 @@ from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from constants import AUDIO_EXTENSIONS
|
||||
from core.model_manager import ModelManager
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.enums import MessageStatus
|
||||
from models.model import App, AppMode, Message
|
||||
@ -77,6 +77,8 @@ class AudioService:
|
||||
def transcript_tts(
|
||||
cls,
|
||||
app_model: App,
|
||||
*,
|
||||
session: Session | scoped_session,
|
||||
text: str | None = None,
|
||||
voice: str | None = None,
|
||||
end_user: str | None = None,
|
||||
@ -87,7 +89,7 @@ class AudioService:
|
||||
if voice is None:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if is_draft:
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model)
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app_model, session=session)
|
||||
else:
|
||||
workflow = app_model.workflow
|
||||
if (
|
||||
@ -132,7 +134,7 @@ class AudioService:
|
||||
uuid.UUID(message_id)
|
||||
except ValueError:
|
||||
return None
|
||||
message = db.session.get(Message, message_id)
|
||||
message = session.get(Message, message_id)
|
||||
if message is None:
|
||||
return None
|
||||
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
|
||||
|
||||
@ -2,17 +2,17 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
data_source_api_key_bindings = db.session.scalars(
|
||||
def get_provider_auth_list(session: Session, tenant_id: str):
|
||||
data_source_api_key_bindings = session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
)
|
||||
@ -20,7 +20,7 @@ class ApiKeyAuthService:
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict[str, Any]):
|
||||
def create_provider_auth(session: Session, tenant_id: str, args: dict[str, Any]):
|
||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
@ -31,12 +31,12 @@ class ApiKeyAuthService:
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
session.add(data_source_api_key_binding)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = db.session.scalar(
|
||||
def get_auth_credentials(session: Session, tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = session.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
@ -52,16 +52,16 @@ class ApiKeyAuthService:
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = db.session.scalar(
|
||||
def delete_provider_auth(session: Session, tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = session.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.id == binding_id,
|
||||
)
|
||||
)
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
session.delete(data_source_api_key_binding)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
|
||||
@ -185,8 +185,8 @@ class EnterpriseRequest(BaseRequest):
|
||||
|
||||
if (
|
||||
not cls.rbac_base_url.startswith("http")
|
||||
or not cls.rbac_base_url.startswith("https")
|
||||
or not cls.rbac_base_url
|
||||
and not cls.rbac_base_url.startswith("https")
|
||||
and not cls.rbac_base_url
|
||||
):
|
||||
raise ValueError("ENTERPRISE_RBAC_API_URL is required when RBAC_ENABLED=true")
|
||||
|
||||
|
||||
@ -379,6 +379,9 @@ _LEGACY_WORKSPACE_EDITOR_KEYS: list[str] = [
|
||||
"snippets.create_and_modify",
|
||||
"tool.manage",
|
||||
"snippets.create_and_modify",
|
||||
"billing.view",
|
||||
"billing.subscription.manage",
|
||||
"billing.manage",
|
||||
]
|
||||
|
||||
_LEGACY_WORKSPACE_NORMAL_KEYS: list[str] = [
|
||||
@ -386,6 +389,9 @@ _LEGACY_WORKSPACE_NORMAL_KEYS: list[str] = [
|
||||
"plugin.install",
|
||||
"credential.use",
|
||||
"app_library.access",
|
||||
"billing.view",
|
||||
"billing.subscription.manage",
|
||||
"billing.manage",
|
||||
]
|
||||
|
||||
_LEGACY_WORKSPACE_DATASET_OPERATOR_KEYS: list[str] = [
|
||||
@ -834,6 +840,7 @@ class RBACService:
|
||||
options: ListOption | None = None,
|
||||
) -> Paginated[RBACRole]:
|
||||
params = (options or ListOption()).to_params({"include_owner": include_owner})
|
||||
params["dataset_operator_enabled"] = dify_config.DATASET_OPERATOR_ENABLED
|
||||
data = _inner_call(
|
||||
"GET",
|
||||
f"{_INNER_PREFIX}/roles",
|
||||
@ -1678,6 +1685,17 @@ class RBACService:
|
||||
)
|
||||
return MemberRolesResponse.model_validate(data or {})
|
||||
|
||||
@staticmethod
|
||||
def delete_rbac_bindings(tenant_id: str, account_id: str):
|
||||
data = _inner_call(
|
||||
"DELETE",
|
||||
f"{_INNER_PREFIX}/members/rbac-bindings",
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
params={"account_id": account_id},
|
||||
)
|
||||
return data
|
||||
|
||||
class CheckAccess:
|
||||
"""Call the ``/inner/api/rbac/check-access`` endpoint."""
|
||||
|
||||
|
||||
@ -42,6 +42,11 @@ class ComposerSavePayload(BaseModel):
|
||||
idempotency_key: str | None = None
|
||||
client_revision_id: str | None = None
|
||||
new_agent_name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
role: str | None = Field(default=None, max_length=255)
|
||||
icon_type: AgentIconType | None = None
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
icon_background: str | None = Field(default=None, max_length=255)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_variant_sections(self) -> "ComposerSavePayload":
|
||||
@ -58,6 +63,12 @@ class ComposerSavePayload(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class WorkflowComposerCopyFromRosterPayload(BaseModel):
|
||||
source_agent_id: str = Field(min_length=1, max_length=255)
|
||||
source_snapshot_id: str | None = Field(default=None, max_length=255)
|
||||
idempotency_key: str | None = Field(default=None, max_length=255)
|
||||
|
||||
|
||||
class RosterAgentCreatePayload(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=255)
|
||||
mode: Literal["agent"] = "agent"
|
||||
|
||||
@ -181,6 +181,7 @@ class SystemFeatureModel(FeatureResponseModel):
|
||||
enable_creators_platform: bool = False
|
||||
enable_trial_app: bool = False
|
||||
enable_explore_banner: bool = False
|
||||
enable_learn_app: bool = True
|
||||
rbac_enabled: bool = False
|
||||
|
||||
|
||||
@ -282,6 +283,7 @@ class FeatureService:
|
||||
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
||||
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
|
||||
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
|
||||
system_features.enable_learn_app = dify_config.ENABLE_LEARN_APP
|
||||
|
||||
@classmethod
|
||||
def _fulfill_trial_models_from_env(cls) -> list[str]:
|
||||
|
||||
@ -14,8 +14,9 @@ from models.model import Account, App, Conversation, Message, MessageFeedback
|
||||
class FeedbackService:
|
||||
@staticmethod
|
||||
def export_feedbacks(
|
||||
session: Session,
|
||||
app_id: str,
|
||||
*,
|
||||
session: Session,
|
||||
from_source: str | None = None,
|
||||
rating: str | None = None,
|
||||
has_comment: bool | None = None,
|
||||
@ -28,6 +29,7 @@ class FeedbackService:
|
||||
|
||||
Args:
|
||||
app_id: Application ID
|
||||
session: Database session used to run the export query
|
||||
from_source: Filter by feedback source ('user' or 'admin')
|
||||
rating: Filter by rating ('like' or 'dislike')
|
||||
has_comment: Only include feedback with comments
|
||||
|
||||
@ -13,11 +13,11 @@ of a separate validation error.
|
||||
"""
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes.llm.entities import ModelConfig
|
||||
from models.dataset import Dataset
|
||||
@ -38,7 +38,11 @@ from services.errors.knowledge_retrieval import (
|
||||
class InnerKnowledgeRetrievalService:
|
||||
"""Validate inner caller scope and delegate to workflow dataset retrieval."""
|
||||
|
||||
def retrieve(self, request: InnerKnowledgeRetrieveRequest) -> InnerKnowledgeRetrieveResponse:
|
||||
def retrieve(
|
||||
self,
|
||||
request: InnerKnowledgeRetrieveRequest,
|
||||
session: scoped_session,
|
||||
) -> InnerKnowledgeRetrieveResponse:
|
||||
"""Run tenant-scoped retrieval for a trusted internal caller.
|
||||
|
||||
This method only rejects caller app existence/tenant mismatches and
|
||||
@ -56,8 +60,8 @@ class InnerKnowledgeRetrievalService:
|
||||
InnerKnowledgeRetrieveDatasetTenantMismatchError:
|
||||
At least one requested dataset is outside the caller tenant.
|
||||
"""
|
||||
self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id)
|
||||
self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids)
|
||||
self._validate_caller_app(tenant_id=request.caller.tenant_id, app_id=request.caller.app_id, session=session)
|
||||
self._validate_datasets(tenant_id=request.caller.tenant_id, dataset_ids=request.dataset_ids, session=session)
|
||||
|
||||
rag = DatasetRetrieval()
|
||||
results = rag.knowledge_retrieval(request=self._to_rag_request(request))
|
||||
@ -66,8 +70,8 @@ class InnerKnowledgeRetrievalService:
|
||||
usage=InnerKnowledgeRetrieveUsage.model_validate(jsonable_encoder(rag.llm_usage)),
|
||||
)
|
||||
|
||||
def _validate_caller_app(self, *, tenant_id: str, app_id: str) -> None:
|
||||
app = db.session.scalar(select(App).where(App.id == app_id).limit(1))
|
||||
def _validate_caller_app(self, *, tenant_id: str, app_id: str, session: scoped_session) -> None:
|
||||
app = session.scalar(select(App).where(App.id == app_id).limit(1))
|
||||
if app is None:
|
||||
raise InnerKnowledgeRetrieveAppNotFoundError(f"App '{app_id}' not found")
|
||||
if app.tenant_id != tenant_id:
|
||||
@ -75,8 +79,8 @@ class InnerKnowledgeRetrievalService:
|
||||
f"App '{app_id}' does not belong to tenant '{tenant_id}'"
|
||||
)
|
||||
|
||||
def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str]) -> None:
|
||||
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
def _validate_datasets(self, *, tenant_id: str, dataset_ids: list[str], session: scoped_session) -> None:
|
||||
datasets = session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
|
||||
found_ids = {dataset.id for dataset in datasets}
|
||||
missing_ids = sorted(set(dataset_ids) - found_ids)
|
||||
|
||||
@ -2,9 +2,9 @@ import copy
|
||||
import logging
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import resolve_account_fallback
|
||||
@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
|
||||
class MetadataService:
|
||||
@staticmethod
|
||||
def create_metadata(
|
||||
session: Session,
|
||||
dataset_id: str,
|
||||
metadata_args: MetadataArgs,
|
||||
current_user: Account | None = None, # TODO: the service_api is not migrated yet
|
||||
@ -33,7 +34,7 @@ class MetadataService:
|
||||
raise ValueError("Metadata name cannot exceed 255 characters.")
|
||||
current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id)
|
||||
# check if metadata name already exists
|
||||
if db.session.scalar(
|
||||
if session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(
|
||||
DatasetMetadata.tenant_id == current_tenant_id,
|
||||
@ -53,12 +54,13 @@ class MetadataService:
|
||||
name=metadata_args.name,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(metadata)
|
||||
db.session.commit()
|
||||
session.add(metadata)
|
||||
session.commit()
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def update_metadata_name(
|
||||
session: Session,
|
||||
dataset_id: str,
|
||||
metadata_id: str,
|
||||
name: str,
|
||||
@ -72,7 +74,7 @@ class MetadataService:
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
# check if metadata name already exists
|
||||
current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id)
|
||||
if db.session.scalar(
|
||||
if session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(
|
||||
DatasetMetadata.tenant_id == current_tenant_id,
|
||||
@ -87,7 +89,7 @@ class MetadataService:
|
||||
raise ValueError("Metadata name already exists in Built-in fields.")
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = db.session.scalar(
|
||||
metadata = session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
|
||||
.limit(1)
|
||||
@ -100,7 +102,7 @@ class MetadataService:
|
||||
metadata.updated_at = naive_utc_now()
|
||||
|
||||
# update related documents
|
||||
dataset_metadata_bindings = db.session.scalars(
|
||||
dataset_metadata_bindings = session.scalars(
|
||||
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
|
||||
).all()
|
||||
if dataset_metadata_bindings:
|
||||
@ -114,8 +116,8 @@ class MetadataService:
|
||||
value = doc_metadata.pop(old_name, None)
|
||||
doc_metadata[name] = value
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
return metadata
|
||||
except Exception:
|
||||
logger.exception("Update metadata name failed")
|
||||
@ -124,21 +126,21 @@ class MetadataService:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def delete_metadata(dataset_id: str, metadata_id: str):
|
||||
def delete_metadata(session: Session, dataset_id: str, metadata_id: str):
|
||||
lock_key = f"dataset_metadata_lock_{dataset_id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
metadata = db.session.scalar(
|
||||
metadata = session.scalar(
|
||||
select(DatasetMetadata)
|
||||
.where(DatasetMetadata.id == metadata_id, DatasetMetadata.dataset_id == dataset_id)
|
||||
.limit(1)
|
||||
)
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata not found.")
|
||||
db.session.delete(metadata)
|
||||
session.delete(metadata)
|
||||
|
||||
# deal related documents
|
||||
dataset_metadata_bindings = db.session.scalars(
|
||||
dataset_metadata_bindings = session.scalars(
|
||||
select(DatasetMetadataBinding).where(DatasetMetadataBinding.metadata_id == metadata_id)
|
||||
).all()
|
||||
if dataset_metadata_bindings:
|
||||
@ -151,8 +153,8 @@ class MetadataService:
|
||||
doc_metadata = copy.deepcopy(document.doc_metadata)
|
||||
doc_metadata.pop(metadata.name, None)
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
session.add(document)
|
||||
session.commit()
|
||||
return metadata
|
||||
except Exception:
|
||||
logger.exception("Delete metadata failed")
|
||||
@ -170,13 +172,13 @@ class MetadataService:
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def enable_built_in_field(dataset: Dataset):
|
||||
def enable_built_in_field(session: Session, dataset: Dataset):
|
||||
if dataset.built_in_field_enabled:
|
||||
return
|
||||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
db.session.add(dataset)
|
||||
session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
if documents:
|
||||
for document in documents:
|
||||
@ -190,22 +192,22 @@ class MetadataService:
|
||||
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
session.add(document)
|
||||
dataset.built_in_field_enabled = True
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Enable built-in field failed")
|
||||
finally:
|
||||
redis_client.delete(lock_key)
|
||||
|
||||
@staticmethod
|
||||
def disable_built_in_field(dataset: Dataset):
|
||||
def disable_built_in_field(session: Session, dataset: Dataset):
|
||||
if not dataset.built_in_field_enabled:
|
||||
return
|
||||
lock_key = f"dataset_metadata_lock_{dataset.id}"
|
||||
try:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset.id, None)
|
||||
db.session.add(dataset)
|
||||
session.add(dataset)
|
||||
documents = DocumentService.get_working_documents_by_dataset_id(dataset.id)
|
||||
document_ids = []
|
||||
if documents:
|
||||
@ -220,10 +222,10 @@ class MetadataService:
|
||||
doc_metadata.pop(BuiltInField.last_update_date, None)
|
||||
doc_metadata.pop(BuiltInField.source, None)
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
session.add(document)
|
||||
document_ids.append(document.id)
|
||||
dataset.built_in_field_enabled = False
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Disable built-in field failed")
|
||||
finally:
|
||||
@ -231,6 +233,7 @@ class MetadataService:
|
||||
|
||||
@staticmethod
|
||||
def update_documents_metadata(
|
||||
session: Session,
|
||||
dataset: Dataset,
|
||||
metadata_args: MetadataOperationData,
|
||||
current_user: Account | None = None, # TODO: the service_api is not migrated yet
|
||||
@ -259,11 +262,11 @@ class MetadataService:
|
||||
doc_metadata[BuiltInField.last_update_date] = document.last_update_date.timestamp()
|
||||
doc_metadata[BuiltInField.source] = MetadataDataSource[document.data_source_type]
|
||||
document.doc_metadata = doc_metadata
|
||||
db.session.add(document)
|
||||
session.add(document)
|
||||
|
||||
# deal metadata binding (in the same transaction as the doc_metadata update)
|
||||
if not operation.partial_update:
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
delete(DatasetMetadataBinding).where(
|
||||
DatasetMetadataBinding.document_id == operation.document_id
|
||||
)
|
||||
@ -272,7 +275,7 @@ class MetadataService:
|
||||
for metadata_value in operation.metadata_list:
|
||||
# check if binding already exists
|
||||
if operation.partial_update:
|
||||
existing_binding = db.session.scalar(
|
||||
existing_binding = session.scalar(
|
||||
select(DatasetMetadataBinding)
|
||||
.where(
|
||||
DatasetMetadataBinding.document_id == operation.document_id,
|
||||
@ -290,10 +293,10 @@ class MetadataService:
|
||||
metadata_id=metadata_value.id,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(dataset_metadata_binding)
|
||||
db.session.commit()
|
||||
session.add(dataset_metadata_binding)
|
||||
session.commit()
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
session.rollback()
|
||||
logger.exception("Update documents metadata failed")
|
||||
raise
|
||||
finally:
|
||||
@ -313,14 +316,14 @@ class MetadataService:
|
||||
redis_client.set(lock_key, 1, ex=3600)
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_metadatas(dataset: Dataset):
|
||||
def get_dataset_metadatas(session: Session, dataset: Dataset):
|
||||
return {
|
||||
"doc_metadata": [
|
||||
{
|
||||
"id": item.get("id"),
|
||||
"name": item.get("name"),
|
||||
"type": item.get("type"),
|
||||
"count": db.session.scalar(
|
||||
"count": session.scalar(
|
||||
select(func.count(DatasetMetadataBinding.id)).where(
|
||||
DatasetMetadataBinding.metadata_id == item.get("id"),
|
||||
DatasetMetadataBinding.dataset_id == dataset.id,
|
||||
|
||||
@ -91,4 +91,4 @@ class OAuthServerService:
|
||||
|
||||
user_id_str = user_account_id.decode("utf-8")
|
||||
|
||||
return AccountService.load_user(user_id_str)
|
||||
return AccountService.load_user(user_id_str, db.session)
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import Any, override
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
|
||||
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
||||
@ -25,6 +26,11 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
result = self.fetch_recommended_apps_from_builtin(language)
|
||||
return result
|
||||
|
||||
@override
|
||||
def get_learn_dify_apps(self, language: str):
|
||||
result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language)
|
||||
return result
|
||||
|
||||
@override
|
||||
def get_recommend_app_detail(self, app_id: str):
|
||||
result = self.fetch_recommended_app_detail_from_builtin(app_id)
|
||||
|
||||
@ -49,6 +49,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
result = self.fetch_recommended_apps_from_db(language)
|
||||
return result
|
||||
|
||||
@override
|
||||
def get_learn_dify_apps(self, language: str) -> RecommendedAppsResultDict:
|
||||
result = self.fetch_learn_dify_apps_from_db(language)
|
||||
return result
|
||||
|
||||
@override
|
||||
def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None:
|
||||
result = self.fetch_recommended_app_detail_from_db(app_id)
|
||||
|
||||
@ -6,6 +6,8 @@ class RecommendAppRetrievalBase(Protocol):
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> Any: ...
|
||||
|
||||
def get_learn_dify_apps(self, language: str) -> Any: ...
|
||||
|
||||
def get_recommend_app_detail(self, app_id: str) -> Any: ...
|
||||
|
||||
def get_type(self) -> str: ...
|
||||
|
||||
@ -2,15 +2,28 @@ import logging
|
||||
from typing import Any, override
|
||||
|
||||
import httpx
|
||||
from flask import has_request_context, request
|
||||
|
||||
from configs import dify_config
|
||||
from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval
|
||||
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
|
||||
from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
|
||||
from services.recommend_app.recommend_app_type import RecommendAppType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _current_origin_headers() -> dict[str, str]:
|
||||
origin = request.headers.get("Origin") if has_request_context() else None
|
||||
if origin:
|
||||
return {"Origin": origin}
|
||||
|
||||
console_web_url = getattr(dify_config, "CONSOLE_WEB_URL", "")
|
||||
if not isinstance(console_web_url, str) or not console_web_url:
|
||||
return {}
|
||||
return {"Origin": console_web_url}
|
||||
|
||||
|
||||
class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
"""
|
||||
Retrieval recommended app from dify official.
|
||||
@ -37,6 +50,15 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language)
|
||||
return result
|
||||
|
||||
@override
|
||||
def get_learn_dify_apps(self, language: str):
|
||||
try:
|
||||
result = self.fetch_learn_dify_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
logger.warning("fetch learn dify apps from dify official failed: %s, switch to database.", e)
|
||||
result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language)
|
||||
return result
|
||||
|
||||
@override
|
||||
def get_type(self) -> str:
|
||||
return RecommendAppType.REMOTE
|
||||
@ -50,7 +72,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/apps/{app_id}"
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
data: dict[str, Any] = response.json()
|
||||
@ -65,9 +87,25 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/apps?language={language}"
|
||||
response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
|
||||
|
||||
result: dict[str, Any] = response.json()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def fetch_learn_dify_apps_from_dify_official(cls, language: str):
|
||||
"""
|
||||
Fetch Learn Dify apps from dify official.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f"{domain}/apps/learn-dify?language={language}"
|
||||
response = httpx.get(url, headers=_current_origin_headers(), timeout=httpx.Timeout(10.0, connect=3.0))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"fetch learn dify apps failed, status code: {response.status_code}")
|
||||
|
||||
result: dict[str, Any] = response.json()
|
||||
return result
|
||||
|
||||
@ -6,7 +6,6 @@ from sqlalchemy.orm import scoped_session
|
||||
from configs import dify_config
|
||||
from models.model import AccountTrialAppRecord, TrialApp
|
||||
from services.feature_service import FeatureService
|
||||
from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval
|
||||
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
||||
|
||||
|
||||
@ -38,11 +37,13 @@ class RecommendedAppService:
|
||||
@classmethod
|
||||
def get_learn_dify_apps(cls, session: scoped_session, language: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get database-backed recommended apps marked as Learn Dify.
|
||||
Get recommended apps marked for the Learn Dify section.
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
result = DatabaseRecommendAppRetrieval.fetch_learn_dify_apps_from_db(language)
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
|
||||
result = retrieval_instance.get_learn_dify_apps(language)
|
||||
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
for app in result["recommended_apps"]:
|
||||
|
||||
@ -35,7 +35,7 @@ class WebAppAuthService:
|
||||
@staticmethod
|
||||
def authenticate(email: str, password: str) -> Account:
|
||||
"""authenticate account with email and password"""
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
|
||||
@ -55,7 +55,7 @@ class WebAppAuthService:
|
||||
|
||||
@classmethod
|
||||
def get_user_through_email(cls, email: str):
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
account = AccountService.get_account_by_email_with_case_fallback(db.session, email)
|
||||
if not account:
|
||||
return None
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
@ -142,7 +142,7 @@ class WorkflowService:
|
||||
return db.session.execute(stmt).scalar_one()
|
||||
|
||||
def get_draft_workflow(
|
||||
self, app_model: App, workflow_id: str | None = None, session: Session | None = None
|
||||
self, app_model: App, workflow_id: str | None = None, session: Session | scoped_session | None = None
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Get draft workflow
|
||||
@ -169,7 +169,7 @@ class WorkflowService:
|
||||
return workflow
|
||||
|
||||
def get_published_workflow_by_id(
|
||||
self, app_model: App, workflow_id: str, session: Session | None = None
|
||||
self, app_model: App, workflow_id: str, session: Session | scoped_session | None = None
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
fetch published workflow by workflow_id
|
||||
|
||||
@ -36,7 +36,9 @@ class WorkspaceService:
|
||||
feature = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
can_replace_logo = feature.can_replace_logo
|
||||
|
||||
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
|
||||
if can_replace_logo and TenantService.has_roles(
|
||||
tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], session=db.session
|
||||
):
|
||||
base_url = dify_config.FILES_URL
|
||||
replace_webapp_logo = (
|
||||
f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"
|
||||
|
||||
@ -84,6 +84,7 @@ def setup_account(request) -> Generator[Account, None, None]:
|
||||
password=secrets.token_hex(16),
|
||||
ip_address="localhost",
|
||||
language="en-US",
|
||||
session=db.session,
|
||||
)
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Controller integration tests for API key data source auth routes."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
@ -85,7 +85,7 @@ def test_create_binding_successful(
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
create_auth.assert_called_once_with(tenant_id, payload)
|
||||
create_auth.assert_called_once_with(ANY, tenant_id, payload)
|
||||
|
||||
|
||||
def test_create_binding_failure(
|
||||
|
||||
@ -270,10 +270,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||
result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -165,10 +165,7 @@ def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com")
|
||||
result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Mixed@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
|
||||
@ -494,10 +494,7 @@ class TestAccountGeneration:
|
||||
second_result.scalar_one_or_none.return_value = expected_account
|
||||
mock_session.execute.side_effect = [first_result, second_result]
|
||||
|
||||
with patch("services.account_service.session_factory") as mock_factory:
|
||||
mock_factory.create_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_factory.create_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com")
|
||||
result = AccountService.get_account_by_email_with_case_fallback(mock_session, "Case@Test.com")
|
||||
|
||||
assert result is expected_account
|
||||
assert mock_session.execute.call_count == 2
|
||||
@ -551,6 +548,7 @@ class TestAccountGeneration:
|
||||
provider="github",
|
||||
language="en-US",
|
||||
timezone=None,
|
||||
session=ANY,
|
||||
)
|
||||
else:
|
||||
mock_register_service.register.assert_not_called()
|
||||
@ -584,6 +582,7 @@ class TestAccountGeneration:
|
||||
provider="github",
|
||||
language="en-US",
|
||||
timezone=None,
|
||||
session=ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@ -615,6 +614,7 @@ class TestAccountGeneration:
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
session=ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@ -646,6 +646,7 @@ class TestAccountGeneration:
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone=None,
|
||||
session=ANY,
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@ -676,7 +677,7 @@ class TestAccountGeneration:
|
||||
|
||||
assert result == mock_account
|
||||
assert oauth_new_user is False
|
||||
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
|
||||
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace", session=ANY)
|
||||
mock_tenant_service.create_tenant_member.assert_called_once_with(
|
||||
mock_new_tenant, mock_account, ANY, role="owner"
|
||||
)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Literal
|
||||
from unittest.mock import patch
|
||||
@ -12,7 +12,6 @@ from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx
|
||||
from models import Account, Tenant
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -46,20 +45,25 @@ def make_account(db_session_with_containers: Session) -> Callable[..., Account]:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
if with_owner_tenant:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
account, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
return account
|
||||
|
||||
return _make
|
||||
|
||||
|
||||
def add_tenant_for_account(account: Account, *, role: str = "normal", name: str = "Second WS") -> Tenant:
|
||||
def add_tenant_for_account(
|
||||
account: Account, *, session: Session, role: str = "normal", name: str = "Second WS"
|
||||
) -> Tenant:
|
||||
"""Create an additional tenant and join ``account`` to it (real service calls)."""
|
||||
with patch("services.account_service.FeatureService") as mock_feature_service:
|
||||
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
|
||||
tenant = TenantService.create_tenant(name=name)
|
||||
TenantService.create_tenant_member(tenant, account, db.session, role=role)
|
||||
tenant = TenantService.create_tenant(name=name, session=session)
|
||||
TenantService.create_tenant_member(tenant, account, session, role=role)
|
||||
return tenant
|
||||
|
||||
|
||||
@ -93,7 +97,7 @@ def account_auth_context(
|
||||
*,
|
||||
token_id: uuid.UUID,
|
||||
client_id: str = "integration-cli",
|
||||
) -> Iterator[AuthContext]:
|
||||
) -> Generator[AuthContext]:
|
||||
"""Publish an account ``AuthContext`` for handlers that read ``get_auth_ctx()``.
|
||||
|
||||
The auth pipeline normally sets this ContextVar; the integration suite
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Callable
|
||||
from inspect import unwrap
|
||||
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.openapi.account import AccountApi
|
||||
from models import Account
|
||||
@ -34,11 +35,13 @@ class TestAccountInfo:
|
||||
# the only workspace the account belongs to.
|
||||
assert result.default_workspace_id == owner_tenant.id
|
||||
|
||||
def test_lists_all_joined_workspaces(self, app: Flask, make_account: Callable[..., Account]) -> None:
|
||||
def test_lists_all_joined_workspaces(
|
||||
self, app: Flask, db_session_with_containers: Session, make_account: Callable[..., Account]
|
||||
) -> None:
|
||||
account = make_account()
|
||||
owner_tenant = account.current_tenant
|
||||
assert owner_tenant is not None
|
||||
second = add_tenant_for_account(account, role="normal", name="Second WS")
|
||||
second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS")
|
||||
|
||||
api = AccountApi()
|
||||
with app.test_request_context("/openapi/v1/account"):
|
||||
|
||||
@ -81,8 +81,9 @@ def _app_and_account(db_session: Session, *, mode: str = "chat") -> tuple[App, A
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session)
|
||||
tenant = account.current_tenant
|
||||
assert tenant is not None
|
||||
app_args = CreateAppParams(
|
||||
|
||||
@ -41,7 +41,7 @@ class TestWorkspacesList:
|
||||
account = make_account()
|
||||
owner_tenant = account.current_tenant
|
||||
assert owner_tenant is not None
|
||||
second = add_tenant_for_account(account, role="normal", name="Second WS")
|
||||
second = add_tenant_for_account(account, session=db_session_with_containers, role="normal", name="Second WS")
|
||||
|
||||
api = WorkspacesApi()
|
||||
with app.test_request_context("/openapi/v1/workspaces"):
|
||||
@ -90,7 +90,9 @@ class TestWorkspaceSwitch:
|
||||
account = make_account()
|
||||
owner_tenant = account.current_tenant
|
||||
assert owner_tenant is not None
|
||||
target = add_tenant_for_account(account, role="normal", name="Switch Target")
|
||||
target = add_tenant_for_account(
|
||||
account, session=db_session_with_containers, role="normal", name="Switch Target"
|
||||
)
|
||||
|
||||
api = WorkspaceSwitchApi()
|
||||
with app.test_request_context(f"/openapi/v1/workspaces/{target.id}/switch", method="POST"):
|
||||
|
||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -57,7 +57,7 @@ class TestForgotPasswordSendEmailApi:
|
||||
response = ForgotPasswordSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-123"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com")
|
||||
mock_get_account.assert_called_once_with(ANY, "User@Example.com")
|
||||
mock_send_mail.assert_called_once_with(account=mock_account, email="user@example.com", language="zh-Hans")
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_rate_limit.assert_called_once_with("127.0.0.1")
|
||||
@ -177,7 +177,7 @@ class TestForgotPasswordResetApi:
|
||||
response = ForgotPasswordResetApi().post()
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_get_account.assert_called_once_with("User@Example.com")
|
||||
mock_get_account.assert_called_once_with(ANY, "User@Example.com")
|
||||
mock_update_account.assert_called_once()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
|
||||
|
||||
@ -27,8 +27,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset
|
||||
@ -88,8 +89,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -141,8 +143,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -194,8 +197,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -257,8 +261,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -291,8 +296,11 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
account1, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
|
||||
tenant1 = account1.current_tenant
|
||||
|
||||
account2 = AccountService.create_account(
|
||||
@ -300,8 +308,11 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(
|
||||
account2, name=fake.company(), session=db_session_with_containers
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
|
||||
tenant2 = account2.current_tenant
|
||||
|
||||
# Create dataset for tenant1
|
||||
@ -367,8 +378,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Don't create any datasets
|
||||
@ -391,8 +403,9 @@ class TestGetAvailableDatasetsIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create multiple datasets
|
||||
@ -452,8 +465,9 @@ class TestKnowledgeRetrievalIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
@ -520,8 +534,9 @@ class TestKnowledgeRetrievalIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Create dataset but no documents
|
||||
@ -568,8 +583,9 @@ class TestKnowledgeRetrievalIntegration:
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=generate_valid_password(fake),
|
||||
session=db_session_with_containers,
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company(), session=db_session_with_containers)
|
||||
tenant = account.current_tenant
|
||||
|
||||
dataset = Dataset(
|
||||
|
||||
@ -51,7 +51,7 @@ class TestApiKeyAuthService:
|
||||
self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
|
||||
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
|
||||
|
||||
assert len(result) >= 1
|
||||
tenant_results = [r for r in result if r.tenant_id == tenant_id]
|
||||
@ -61,7 +61,7 @@ class TestApiKeyAuthService:
|
||||
def test_get_provider_auth_list_empty(
|
||||
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id
|
||||
):
|
||||
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
|
||||
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
|
||||
|
||||
tenant_results = [r for r in result if r.tenant_id == tenant_id]
|
||||
assert tenant_results == []
|
||||
@ -74,7 +74,7 @@ class TestApiKeyAuthService:
|
||||
)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
|
||||
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
|
||||
|
||||
tenant_results = [r for r in result if r.tenant_id == tenant_id]
|
||||
assert tenant_results == []
|
||||
@ -95,7 +95,7 @@ class TestApiKeyAuthService:
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123"
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
|
||||
|
||||
mock_factory.assert_called_once()
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
@ -118,7 +118,7 @@ class TestApiKeyAuthService:
|
||||
mock_auth_instance.validate_credentials.return_value = False
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all()
|
||||
@ -142,7 +142,7 @@ class TestApiKeyAuthService:
|
||||
|
||||
original_key = mock_args["credentials"]["config"]["api_key"]
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
|
||||
|
||||
assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123"
|
||||
assert mock_args["credentials"]["config"]["api_key"] != original_key
|
||||
@ -166,14 +166,14 @@ class TestApiKeyAuthService:
|
||||
)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
|
||||
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
|
||||
|
||||
assert result == mock_credentials
|
||||
|
||||
def test_get_auth_credentials_not_found(
|
||||
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, category, provider
|
||||
):
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
|
||||
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
|
||||
|
||||
assert result is None
|
||||
|
||||
@ -190,7 +190,7 @@ class TestApiKeyAuthService:
|
||||
)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
|
||||
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
|
||||
|
||||
assert result == special_credentials
|
||||
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
|
||||
@ -204,7 +204,7 @@ class TestApiKeyAuthService:
|
||||
binding_id = binding.id
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id)
|
||||
ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, binding_id)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first()
|
||||
@ -214,7 +214,7 @@ class TestApiKeyAuthService:
|
||||
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id
|
||||
):
|
||||
# Should not raise when binding not found
|
||||
ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4()))
|
||||
ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, str(uuid4()))
|
||||
|
||||
def test_validate_api_key_auth_args_success(self, mock_args):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
|
||||
@ -288,16 +288,16 @@ class TestApiKeyAuthService:
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
mock_encrypter.encrypt_token.return_value = "encrypted_key"
|
||||
|
||||
with patch("services.auth.api_key_auth_service.db.session") as mock_session:
|
||||
mock_session.commit.side_effect = Exception("Database error")
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
mock_session = MagicMock()
|
||||
mock_session.commit.side_effect = Exception("Database error")
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
ApiKeyAuthService.create_provider_auth(mock_session, tenant_id, mock_args)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args):
|
||||
mock_factory.side_effect = Exception("Factory error")
|
||||
with pytest.raises(Exception, match="Factory error"):
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
@ -307,7 +307,7 @@ class TestApiKeyAuthService:
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
|
||||
with pytest.raises(Exception, match="Encryption error"):
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args)
|
||||
|
||||
def test_validate_api_key_auth_args_none_input(self):
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user