merge main

This commit is contained in:
Joel 2026-05-12 18:23:24 +08:00
commit 8b664680aa
627 changed files with 9900 additions and 8344 deletions

View File

@ -63,7 +63,7 @@ pnpm analyze-component <path> --json
```typescript
// ❌ Before: Complex state logic in component
const Configuration: FC = () => {
function Configuration() {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
@ -85,7 +85,7 @@ export const useModelConfig = (appId: string) => {
}
// Component becomes cleaner
const Configuration: FC = () => {
function Configuration() {
const { modelConfig, setModelConfig } = useModelConfig(appId)
return <div>...</div>
}
@ -189,8 +189,6 @@ const Template = useMemo(() => {
**Dify Convention**:
- This skill is for component decomposition, not query/mutation design.
- When refactoring data fetching, follow `web/AGENTS.md`.
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
- Do not introduce deprecated `useInvalid` / `useReset`.
- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state.

View File

@ -60,8 +60,10 @@ const Template = useMemo(() => {
**After** (complexity: ~3):
```typescript
import type { ComponentType } from 'react'
// Define lookup table outside component
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, ComponentType<TemplateProps>>> = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,

View File

@ -65,10 +65,10 @@ interface ConfigurationHeaderProps {
onPublish: () => void
}
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
function ConfigurationHeader({
isAdvancedMode,
onPublish,
}) => {
}: ConfigurationHeaderProps) {
const { t } = useTranslation()
return (
@ -136,7 +136,7 @@ const AppInfo = () => {
}
// ✅ After: Separate view components
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
function AppInfoExpanded({ appDetail, onAction }: AppInfoViewProps) {
return (
<div className="expanded">
{/* Clean, focused expanded view */}
@ -144,7 +144,7 @@ const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
)
}
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
function AppInfoCollapsed({ appDetail, onAction }: AppInfoViewProps) {
return (
<div className="collapsed">
{/* Clean, focused collapsed view */}
@ -203,12 +203,12 @@ interface AppInfoModalsProps {
onSuccess: () => void
}
const AppInfoModals: FC<AppInfoModalsProps> = ({
function AppInfoModals({
appDetail,
activeModal,
onClose,
onSuccess,
}) => {
}: AppInfoModalsProps) {
const handleEdit = async (data) => { /* logic */ }
const handleDuplicate = async (data) => { /* logic */ }
const handleDelete = async () => { /* logic */ }
@ -296,7 +296,7 @@ interface OperationItemProps {
onAction: (id: string) => void
}
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
function OperationItem({ operation, onAction }: OperationItemProps) {
return (
<div className="operation-item">
<span className="icon">{operation.icon}</span>
@ -435,7 +435,7 @@ interface ChildProps {
onSubmit: () => void
}
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
function Child({ value, onChange, onSubmit }: ChildProps) {
return (
<div>
<input value={value} onChange={e => onChange(e.target.value)} />

View File

@ -112,13 +112,13 @@ export const useModelConfig = ({
```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
function Configuration() {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
// ... lots of related state and effects
}
// After: Clean component
const Configuration: FC = () => {
function Configuration() {
const {
modelConfig,
setModelConfig,
@ -159,8 +159,6 @@ const Configuration: FC = () => {
When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns.
- Follow `web/AGENTS.md` first.
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
- Do not introduce deprecated `useInvalid` / `useReset`.
- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks.

View File

@ -23,7 +23,7 @@ Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
5. Re-check official Playwright or Cucumber docs with the available documentation tools before introducing a new framework pattern.
## Local Rules

View File

@ -9,18 +9,18 @@ Category: Performance
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
## Complex prop memoization
## Complex prop stability
IsUrgent: True
IsUrgent: False
Category: Performance
### Description
Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
Only require stable object, array, or map props when there is a clear reason: the child is memoized, the value participates in effect/query dependencies, the value is part of a stable-reference API contract, or profiling/local behavior shows avoidable re-renders. Do not request `useMemo` for every inline object by default; `how-to-write-component` treats memoization as a targeted optimization.
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
Wrong:
Risky:
```tsx
<HeavyComp
@ -31,7 +31,7 @@ Wrong:
/>
```
Right:
Better when stable identity matters:
```tsx
const config = useMemo(() => ({

View File

@ -1,46 +0,0 @@
---
name: frontend-query-mutation
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
---
# Frontend Query & Mutation
## Intent
- Keep contract as the single source of truth in `web/contract/*`.
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
- Keep abstractions minimal to preserve TypeScript inference.
## Workflow
1. Identify the change surface.
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
- Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
- Read both references when a task spans contract shape and runtime behavior.
2. Implement the smallest abstraction that fits the task.
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
- Extract a small shared query helper only when multiple call sites share the same extra options.
- Create or keep feature hooks only for real orchestration or shared domain behavior.
- When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
3. Preserve Dify conventions.
- Keep contract inputs in `{ params, query?, body? }` shape.
- Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
## Files Commonly Touched
- `web/contract/console/*.ts`
- `web/contract/marketplace.ts`
- `web/contract/router.ts`
- `web/service/client.ts`
- legacy `web/service/use-*.ts` files when migrating wrappers away
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
## References
- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference.
- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules.
Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs.

View File

@ -1,4 +0,0 @@
interface:
display_name: "Frontend Query & Mutation"
short_description: "Dify TanStack Query, oRPC, and default option patterns"
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."

View File

@ -1,129 +0,0 @@
# Contract Patterns
## Table of Contents
- Intent
- Minimal structure
- Core workflow
- Query usage decision rule
- Mutation usage decision rule
- Thin hook decision rule
- Anti-patterns
- Contract rules
- Type export
## Intent
- Keep contract as the single source of truth in `web/contract/*`.
- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
- Keep abstractions minimal and preserve TypeScript inference.
## Minimal Structure
```text
web/contract/
├── base.ts
├── router.ts
├── marketplace.ts
└── console/
├── billing.ts
└── ...other domains
web/service/client.ts
```
## Core Workflow
1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`.
- Use `base.route({...}).output(type<...>())` as the baseline.
- Add `.input(type<...>())` only when the request has `params`, `query`, or `body`.
- For `GET` without input, omit `.input(...)`; do not use `.input(type<unknown>())`.
2. Register contract in `web/contract/router.ts`.
- Import directly from domain files and nest by API prefix.
3. Consume from UI call sites via oRPC query utilities.
```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
staleTime: 5 * 60 * 1000,
throwOnError: true,
select: invoice => invoice.url,
}))
```
## Query Usage Decision Rule
1. Default to direct `*.queryOptions(...)` usage at the call site.
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
3. Create or keep feature hooks only for orchestration.
- Combine multiple queries or mutations.
- Share domain-level derived state or invalidation helpers.
- Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
4. Treat `web/service/use-{domain}.ts` as legacy.
- Do not create new thin service wrappers for oRPC contracts.
- When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.
```typescript
const invoicesBaseQueryOptions = () =>
consoleQuery.billing.invoices.queryOptions({ retry: false })
const invoiceQuery = useQuery({
...invoicesBaseQueryOptions(),
throwOnError: true,
})
```
## Mutation Usage Decision Rule
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
```typescript
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
```
## Thin Hook Decision Rule
Remove thin hooks when they only rename a single oRPC query or mutation helper.
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.
Use:
```typescript
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
```
Keep:
```typescript
const applyTagBindingsMutation = useApplyTagBindingsMutation()
```
`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.
## Anti-Patterns
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
## Contract Rules
- Input structure: always use `{ params, query?, body? }`.
- No-input `GET`: omit `.input(...)`; do not use `.input(type<unknown>())`.
- Path params: use `{paramName}` in the path and match it in the `params` object.
- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`.
- No barrel files: import directly from specific files.
- Types: import from `@/types/` and use the `type<T>()` helper.
- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools.
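A minimal definition-side sketch for steps 1 and 2 and the rules above, using a hypothetical `billing` invoice endpoint (the field names and the `base` import path are assumptions, not the real Dify contract):
```typescript
// web/contract/console/billing.ts (sketch)
import { type } from '@orpc/contract'
import { base } from '../base' // assumed shared builder from web/contract/base.ts
import type { Invoice } from '@/types/billing' // hypothetical type module

// GET without input: omit .input(...) entirely.
export const invoices = base
  .route({ method: 'GET', path: '/billing/invoices' })
  .output(type<Invoice[]>())

// Path param {invoiceId} in the path, matched by the params object.
export const invoice = base
  .route({ method: 'GET', path: '/billing/invoices/{invoiceId}' })
  .input(type<{ params: { invoiceId: string } }>())
  .output(type<Invoice>())

// web/contract/router.ts (sketch): nest by API prefix, import directly, no barrel files.
// import { invoice, invoices } from './console/billing'
// export const consoleRouterContract = { billing: { invoices, invoice } }
```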
## Type Export
```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```

View File

@ -1,172 +0,0 @@
# Runtime Rules
## Table of Contents
- Conditional queries
- oRPC default options
- Cache invalidation
- Key API guide
- `mutate` vs `mutateAsync`
- Legacy migration
## Conditional Queries
Prefer contract-shaped `queryOptions(...)`.
When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions.
Use `enabled` only for extra business gating after the input itself is already valid.
```typescript
import { skipToken, useQuery } from '@tanstack/react-query'
// Disable the query by skipping input construction.
function useAccessMode(appId: string | undefined) {
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
input: appId
? { params: { appId } }
: skipToken,
}))
}
// Avoid runtime-only guards that bypass type checking.
function useBadAccessMode(appId: string | undefined) {
return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
input: { params: { appId: appId! } },
enabled: !!appId,
}))
}
```
## oRPC Default Options
Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.
Place defaults at the query utility creation point in `web/service/client.ts`:
```typescript
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
path: ['console'],
experimental_defaults: {
tags: {
create: {
mutationOptions: {
onSuccess: (tag, _variables, _result, context) => {
context.client.setQueryData(
consoleQuery.tags.list.queryKey({
input: {
query: {
type: tag.type,
},
},
}),
(oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
)
},
},
},
},
},
})
```
Rules:
- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.
## Cache Invalidation
Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
Use:
- `.key()` for namespace or prefix invalidation
- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData`
- `queryClient.invalidateQueries(...)` in mutation `onSuccess`
Do not use deprecated `useInvalid` from `use-base.ts`.
```typescript
// Feature orchestration owns cache invalidation only when defaults are not enough.
export const useUpdateAccessMode = () => {
const queryClient = useQueryClient()
return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
})
},
}))
}
// Component only adds UI behavior.
updateAccessMode({ appId, mode }, {
onSuccess: () => toast.success('...'),
})
// Avoid putting invalidation knowledge in the component.
mutate({ appId, mode }, {
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
})
},
})
```
## Key API Guide
- `.key(...)`
- Use for partial matching operations.
- Prefer it for invalidation, refetch, and cancel patterns.
- Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
- `.queryKey(...)`
- Use for a specific query's full key.
- Prefer it for exact cache addressing and direct reads or writes.
- `.mutationKey(...)`
- Use for a specific mutation's full key.
- Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping.
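A short sketch contrasting the three key builders; the `tags` endpoints mirror the earlier examples, and the helper hook itself is illustrative:
```typescript
import { useQueryClient } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'

export function useTagCacheHelpers() {
  const queryClient = useQueryClient()
  return {
    // .key(): prefix match — invalidate everything under the tags namespace.
    invalidateTags: () =>
      queryClient.invalidateQueries({ queryKey: consoleQuery.tags.key() }),
    // .queryKey(): a specific query's full key — exact cache reads/writes.
    readAppTagList: () =>
      queryClient.getQueryData(
        consoleQuery.tags.list.queryKey({ input: { query: { type: 'app' } } }),
      ),
    // .mutationKey(): mutation-status filtering and devtools grouping.
    isCreatingTag: () =>
      queryClient.isMutating({ mutationKey: consoleQuery.tags.create.mutationKey() }) > 0,
  }
}
```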
## `mutate` vs `mutateAsync`
Prefer `mutate` by default.
Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies.
Rules:
- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`.
- Every `await mutateAsync(...)` must be wrapped in `try/catch`.
- Do not use `mutateAsync` when callbacks already express the flow clearly.
```typescript
// Default case.
mutation.mutate(data, {
onSuccess: result => router.push(result.url),
})
// Promise semantics are required.
try {
const order = await createOrder.mutateAsync(orderData)
await confirmPayment.mutateAsync({ orderId: order.id, token })
router.push(`/orders/${order.id}`)
}
catch (error) {
toast.error(error instanceof Error ? error.message : 'Unknown error')
}
```
## Legacy Migration
When touching old code, migrate it toward these rules:
| Old pattern | New pattern |
|---|---|
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
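A before/after sketch of the last two rows, using a hypothetical `apps.rename` endpoint (the endpoint and field names are assumptions):
```typescript
import { useMutation } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'

// Before (legacy): imperative fetch, manual invalidation in the component,
// and an unguarded `await mutateAsync(...)` that can reject unhandled.

// After: mutate(...) with callbacks; shared invalidation is assumed to live
// in oRPC experimental_defaults, so the call site only adds UI feedback.
export function useRenameApp(appId: string) {
  const renameApp = useMutation(consoleQuery.apps.rename.mutationOptions())
  return (name: string) =>
    renameApp.mutate(
      { params: { appId }, body: { name } },
      { onError: error => console.error(error) },
    )
}
```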

View File

@ -5,7 +5,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com
# Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
This skill enables Codex to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`).
@ -24,35 +24,27 @@ Apply this skill when the user:
**Do NOT apply** when:
- User is asking about backend/API tests (Python/pytest)
- User is asking about E2E tests (Playwright/Cypress)
- User is asking about E2E tests (Cucumber + Playwright under `e2e/`)
- User is only asking conceptual questions without code context
## Quick Reference
### Tech Stack
| Tool | Version | Purpose |
|------|---------|---------|
| Vitest | 4.0.16 | Test runner |
| React Testing Library | 16.0 | Component testing |
| jsdom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |
### Key Commands
Run these commands from `web/`. From the repository root, prefix them with `pnpm -C web`.
```bash
# Run all tests
pnpm test
# Watch mode
pnpm test:watch
pnpm test --watch
# Run specific file
pnpm test path/to/file.spec.tsx
# Generate coverage report
pnpm test:coverage
pnpm test --coverage
# Analyze component complexity
pnpm analyze-component <path>
@ -228,7 +220,10 @@ Every test should clearly separate:
### 2. Black-Box Testing
- Test observable behavior, not implementation details
- Use semantic queries (getByRole, getByLabelText)
- Use semantic queries (`getByRole` with accessible `name`, `getByLabelText`, `getByPlaceholderText`, `getByText`, and scoped `within(...)`)
- Treat `getByTestId` as a last resort. If a control cannot be found by role/name, label, landmark, or dialog scope, fix the component accessibility first instead of adding or relying on `data-testid`.
- Remove production `data-testid` attributes when semantic selectors can cover the behavior. Keep them only for non-visual mocked boundaries, editor/browser shims such as Monaco, canvas/chart output, or third-party widgets with no accessible DOM in the test environment.
- Do not assert decorative icons by test id. Assert the named control that contains them, or mark decorative icons `aria-hidden`.
- Avoid testing internal state directly
- **Prefer pattern matching over hardcoded strings** in assertions:
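A minimal sketch of these query preferences against a hypothetical `RenameAppDialog` component (the component and copy are assumptions):
```typescript
import { render, screen, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import RenameAppDialog from './rename-app-dialog' // hypothetical component

it('renames the app from the dialog', async () => {
  render(<RenameAppDialog defaultName="My App" />)

  // Scope queries to the dialog landmark instead of reaching for test ids.
  const dialog = screen.getByRole('dialog', { name: /rename app/i })
  await userEvent.clear(within(dialog).getByLabelText(/name/i))
  await userEvent.type(within(dialog).getByLabelText(/name/i), 'Support Bot')
  await userEvent.click(within(dialog).getByRole('button', { name: /save/i }))

  // Pattern matching instead of a hardcoded full string.
  expect(await screen.findByText(/renamed/i)).toBeInTheDocument()
})
```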

View File

@ -56,7 +56,7 @@ See [Zustand Store Testing](#zustand-store-testing) section for full details.
| Location | Purpose |
|----------|---------|
| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) |
| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `zustand`, clipboard, FloatingPortal, Monaco, `localStorage`) |
| `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
@ -216,28 +216,21 @@ describe('Component', () => {
})
```
### 5. HTTP Mocking with Nock
### 5. HTTP and `fetch` Mocking
```typescript
import nock from 'nock'
const GITHUB_HOST = 'https://api.github.com'
const GITHUB_PATH = '/repos/owner/repo'
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
return nock(GITHUB_HOST)
.get(GITHUB_PATH)
.delay(delayMs)
.reply(status, body)
}
describe('GithubComponent', () => {
afterEach(() => {
nock.cleanAll()
beforeEach(() => {
vi.clearAllMocks()
})
it('should display repo info', async () => {
mockGithubApi(200, { name: 'dify', stars: 1000 })
vi.mocked(globalThis.fetch).mockResolvedValueOnce(
new Response(JSON.stringify({ name: 'dify', stars: 1000 }), {
status: 200,
headers: { 'Content-Type': 'application/json' },
}),
)
render(<GithubComponent />)
@ -247,7 +240,12 @@ describe('GithubComponent', () => {
})
it('should handle API error', async () => {
mockGithubApi(500, { message: 'Server error' })
vi.mocked(globalThis.fetch).mockResolvedValueOnce(
new Response(JSON.stringify({ message: 'Server error' }), {
status: 500,
headers: { 'Content-Type': 'application/json' },
}),
)
render(<GithubComponent />)
@ -258,6 +256,8 @@ describe('GithubComponent', () => {
})
```
Prefer mocking `@/service/*` modules or spying on `global.fetch` / `ky` clients with deterministic responses. Do not introduce an HTTP interception dependency such as `nock` or MSW unless it is already declared in the workspace or adding it is part of the task.
### 6. Context Providers
```typescript
@ -332,7 +332,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
1. Don't leave HTTP mocks or service mock state leaking between tests
1. Don't use `any` types in mocks without necessity
### Mock Decision Tree

View File

@ -227,12 +227,12 @@ Failing tests compound:
**Fix failures immediately before proceeding.**
## Integration with Claude's Todo Feature
## Integration with Codex's Todo Feature
When using Claude for multi-file testing:
When using Codex for multi-file testing:
1. **Ask Claude to create a todo list** before starting
1. **Request one file at a time** or ensure Claude processes incrementally
1. **Create a todo list** before starting
1. **Process one file at a time**
1. **Verify each test passes** before asking for the next
1. **Mark todos complete** as you progress

View File

@ -0,0 +1,71 @@
---
name: how-to-write-component
description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling.
---
# How To Write A Component
Use this as the decision guide for React/TypeScript component structure. Existing code is reference material, not automatic precedent; when it conflicts with these rules, adapt the approach instead of reproducing the violation.
## Core Defaults
- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature.
- Use Tailwind CSS v4.1+ rules via the `tailwind-css-rules` skill. Prefer v4 utilities, `gap`, `text-size/line-height`, `min-h-dvh`, and avoid deprecated utilities and `@apply`.
## Ownership
- Put local state, queries, mutations, handlers, and derived UI data in the lowest component that uses them. Extract a purpose-built owner component only when the logic has no natural home.
- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data. Do not hoist a query only because it is duplicated; TanStack Query handles deduplication and cache sharing.
- Hoist state, queries, or callbacks to a parent only when the parent consumes the data, coordinates shared loading/error/empty UI, needs one consistent snapshot, or owns a workflow spanning children.
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
- Prefer uncontrolled DOM state and CSS variables before adding controlled props.
## Components, Props, And Types
- Type component signatures directly; do not use `FC` or `React.FC`.
- Prefer `function` for top-level components and module helpers. Use arrow functions for local callbacks, handlers, and lambda-style APIs.
- Prefer named exports. Use default exports only where the framework requires them, such as Next.js route files.
- Type simple one-off props inline. Use a named `Props` type only when reused, exported, complex, or clearer.
- Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers beside the component that needs them.
- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially IDs like `appInstanceId`. Normalize framework or route params at the boundary.
- Keep fallback and invariant checks at the lowest component that already handles that state; callers should pass raw values through instead of duplicating checks.
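A compact sketch of these defaults; the component and prop names are illustrative:
```typescript
// Named export, `function` declaration, no React.FC, inline one-off props.
export function AppTitle({ appInstanceId, compact = false }: {
  appInstanceId: string
  compact?: boolean
}) {
  // Local handler stays an arrow function.
  const handleCopy = () => navigator.clipboard.writeText(appInstanceId)
  return (
    <button type="button" onClick={handleCopy}>
      {compact ? appInstanceId.slice(0, 8) : appInstanceId}
    </button>
  )
}
```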
## Queries And Mutations
- Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape.
- Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`.
- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it.
- Keep feature hooks for real orchestration, workflow state, or shared domain behavior.
- For missing required query input, use `input: skipToken`; use `enabled` only for extra business gating after the input is valid.
- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows.
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules.
- Do not use deprecated `useInvalid` or `useReset`.
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`.
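A call-site sketch of direct consumption with `skipToken` gating; the `apps.detail` endpoint and its fields are assumptions:
```typescript
import { skipToken, useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'

export function AppTitleCard({ appId }: { appId?: string }) {
  // Direct consumption, no pass-through hook; skipToken gates missing input.
  const appQuery = useQuery(consoleQuery.apps.detail.queryOptions({
    input: appId ? { params: { appId } } : skipToken,
  }))
  if (!appQuery.data) return null
  return <h2>{appQuery.data.name}</h2>
}
```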
## Component Boundaries
- Use the first level below a page or tab to organize independent page sections when it adds real structure. This layer is layout/semantic first, not automatically the data owner.
- Split deeper components by the data and state each layer actually needs. Each component should access only necessary data, and ownership should stay at the lowest consumer.
- Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary.
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
- Avoid shallow wrappers and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary.
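A sketch of the behavior-versus-layout split for a row action; the names and markup are illustrative:
```typescript
import { useState } from 'react'

// The dropdown owns its trigger, open state, and menu content...
export function RowActions({ onDelete }: { onDelete: () => void }) {
  const [open, setOpen] = useState(false)
  return (
    <div className="relative">
      <button type="button" onClick={() => setOpen(o => !o)}>...</button>
      {open && (
        <ul role="menu" className="absolute right-0">
          <li role="menuitem" onClick={onDelete}>Delete</li>
        </ul>
      )}
    </div>
  )
}

// ...while the caller owns placement within the row layout.
export function AppRow({ name, onDelete }: { name: string, onDelete: () => void }) {
  return (
    <div className="flex items-center justify-between gap-2">
      <span>{name}</span>
      <RowActions onDelete={onDelete} />
    </div>
  )
}
```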
## You Might Not Need An Effect
- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration.
- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive.
- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known.
- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render.
- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary.
- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components.
- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow.
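Two of these rules as a sketch: derive during render instead of mirroring state, and reset with `key` instead of an Effect:
```typescript
import { useState } from 'react'

function SignupForm() {
  const [firstName, setFirstName] = useState('')
  const [lastName, setLastName] = useState('')
  // Derived during render — no useEffect + setFullName mirror state.
  const fullName = `${firstName} ${lastName}`.trim()
  return (
    <form>
      <input value={firstName} onChange={e => setFirstName(e.target.value)} />
      <input value={lastName} onChange={e => setLastName(e.target.value)} />
      <p>{fullName}</p>
    </form>
  )
}

// Reset all form state when the user changes via `key`, not an Effect.
export function ProfilePage({ userId }: { userId: string }) {
  return <SignupForm key={userId} />
}
```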
## Navigation And Performance
- Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission.
- Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason.
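A sketch of the navigation split, assuming Next.js routing and a hypothetical `apps.duplicate` mutation:
```typescript
import Link from 'next/link'
import { useRouter } from 'next/navigation'
import { useMutation } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'

export function AppCardFooter({ appId }: { appId: string }) {
  const router = useRouter()
  const duplicateApp = useMutation(consoleQuery.apps.duplicate.mutationOptions())
  return (
    <footer className="flex gap-2">
      {/* Normal navigation: a real link. */}
      <Link href={`/apps/${appId}`}>Open</Link>
      {/* Command-flow side effect: router API on mutation success. */}
      <button
        type="button"
        onClick={() => duplicateApp.mutate(
          { params: { appId } },
          // `copy.id` assumes the mutation returns the new app.
          { onSuccess: copy => router.push(`/apps/${copy.id}`) },
        )}
      >
        Duplicate
      </button>
    </footer>
  )
}
```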

View File

@ -0,0 +1,367 @@
---
name: tailwind-css-rules
description: Tailwind CSS v4.1+ rules and best practices. Use when writing, reviewing, refactoring, or upgrading Tailwind CSS classes and styles, especially v4 utility migrations, layout spacing, typography, responsive variants, dark mode, gradients, CSS variables, and component styling.
---
# Tailwind CSS Rules and Best Practices
## Core Principles
- **Always use Tailwind CSS v4.1+** - Ensure the codebase is using the latest version
- **Do not use deprecated or removed utilities** - ALWAYS use the replacement
- **Never use `@apply`** - Use CSS variables, the `--spacing()` function, or framework components instead
- **Check for redundant classes** - Remove any classes that aren't necessary
- **Group elements logically** to simplify responsive tweaks later
## Upgrading to Tailwind CSS v4
### Before Upgrading
- **Always read the upgrade documentation first** - Read https://tailwindcss.com/docs/upgrade-guide and https://tailwindcss.com/blog/tailwindcss-v4 before starting an upgrade.
- Ensure the git repository is in a clean state before starting
### Upgrade Process
1. Run the upgrade command: `npx @tailwindcss/upgrade@latest` for both major and minor updates
2. The tool will convert JavaScript config files to the new CSS format
3. Review all changes extensively to clean up any false positives
4. Test thoroughly across your application
## Breaking Changes Reference
### Removed Utilities (NEVER use these in v4)
| ❌ Deprecated | ✅ Replacement |
| ----------------------- | ------------------------------------------------- |
| `bg-opacity-*` | Use opacity modifiers like `bg-black/50` |
| `text-opacity-*` | Use opacity modifiers like `text-black/50` |
| `border-opacity-*` | Use opacity modifiers like `border-black/50` |
| `divide-opacity-*` | Use opacity modifiers like `divide-black/50` |
| `ring-opacity-*` | Use opacity modifiers like `ring-black/50` |
| `placeholder-opacity-*` | Use opacity modifiers like `placeholder-black/50` |
| `flex-shrink-*` | `shrink-*` |
| `flex-grow-*` | `grow-*` |
| `overflow-ellipsis` | `text-ellipsis` |
| `decoration-slice` | `box-decoration-slice` |
| `decoration-clone` | `box-decoration-clone` |
### Renamed Utilities
Use the v4 name when migrating code that still carries Tailwind v3 semantics. Do not blanket-replace existing v4 classes: classes such as `rounded-sm`, `shadow-sm`, `ring-1`, and `ring-2` are valid in this codebase when they intentionally represent the current design scale.
| ❌ v3 pattern | ✅ v4 pattern |
| ------------------- | -------------------------------------------------- |
| `bg-gradient-*` | `bg-linear-*` |
| old shadow scale | verify against the current Tailwind/design scale |
| old blur scale | verify against the current Tailwind/design scale |
| old radius scale | use the Dify radius token mapping when applicable |
| `outline-none` | `outline-hidden` |
| bare `ring` utility | use an explicit ring width such as `ring-1`/`ring-2`/`ring-3` |
For Figma radius tokens, follow `packages/dify-ui/AGENTS.md`. For example, `--radius/xs` maps to `rounded-sm`; do not rewrite it to `rounded-xs`.
## Layout and Spacing Rules
### Flexbox and Grid Spacing
#### Always use gap utilities for internal spacing
Gap provides consistent spacing without edge cases (no extra space on last items). It's cleaner and more maintainable than margins on children.
```html
<!-- ❌ Don't do this -->
<div class="flex">
<div class="mr-4">Item 1</div>
<div class="mr-4">Item 2</div>
<div>Item 3</div>
<!-- No margin on last -->
</div>
<!-- ✅ Do this instead -->
<div class="flex gap-4">
<div>Item 1</div>
<div>Item 2</div>
<div>Item 3</div>
</div>
```
#### Gap vs Space utilities
- **Never use `space-x-*` or `space-y-*` in flex/grid layouts** - always use gap
- Space utilities add margins to children and have issues with wrapped items
- Gap works correctly with flex-wrap and all flex directions
```html
<!-- ❌ Avoid space utilities in flex containers -->
<div class="flex flex-wrap space-x-4">
<!-- Space utilities break with wrapped items -->
</div>
<!-- ✅ Use gap for consistent spacing -->
<div class="flex flex-wrap gap-4">
<!-- Gap works perfectly with wrapping -->
</div>
```
### General Spacing Guidelines
- **Prefer top and left margins** over bottom and right margins (unless conditionally rendered)
- **Use padding on parent containers** instead of bottom margins on the last child
- **Always use `min-h-dvh` instead of `min-h-screen`** - `min-h-screen` is buggy on mobile Safari
- **Prefer `size-*` utilities** over separate `w-*` and `h-*` when setting equal dimensions
- For max-widths, prefer the container scale (e.g., `max-w-2xs` over `max-w-72`)
## Typography Rules
### Line Heights
- **Never use `leading-*` classes** - Always use line height modifiers with text size
- **Always use fixed line heights from the spacing scale** - Don't use named values
```html
<!-- ❌ Don't do this -->
<p class="text-base leading-7">Text with separate line height</p>
<p class="text-lg leading-relaxed">Text with named line height</p>
<!-- ✅ Do this instead -->
<p class="text-base/7">Text with line height modifier</p>
<p class="text-lg/8">Text with specific line height</p>
```
### Font Size Reference
Be precise with font sizes - know the actual pixel values:
- `text-xs` = 12px
- `text-sm` = 14px
- `text-base` = 16px
- `text-lg` = 18px
- `text-xl` = 20px
## Color and Opacity
### Opacity Modifiers
**Never use `bg-opacity-*`, `text-opacity-*`, etc.** - use the opacity modifier syntax:
```html
<!-- ❌ Don't do this -->
<div class="bg-red-500 bg-opacity-60">Old opacity syntax</div>
<!-- ✅ Do this instead -->
<div class="bg-red-500/60">Modern opacity syntax</div>
```
## Responsive Design
### Breakpoint Optimization
- **Check for redundant classes across breakpoints**
- **Only add breakpoint variants when values change**
```html
<!-- ❌ Redundant breakpoint classes -->
<div class="px-4 md:px-4 lg:px-4">
<!-- md:px-4 and lg:px-4 are redundant -->
</div>
<!-- ✅ Efficient breakpoint usage -->
<div class="px-4 lg:px-8">
<!-- Only specify when value changes -->
</div>
```
## Dark Mode
### Dark Mode Best Practices
- Use the plain `dark:` variant pattern
- Put light mode styles first, then dark mode styles
- Ensure `dark:` variant comes before other variants
```html
<!-- ✅ Correct dark mode pattern -->
<div class="bg-white text-black dark:bg-black dark:text-white">
<button class="hover:bg-gray-100 dark:hover:bg-gray-800">Click me</button>
</div>
```
## Gradient Utilities
- **ALWAYS Use `bg-linear-*` instead of `bg-gradient-*` utilities** - The gradient utilities were renamed in v4
- Use the new `bg-radial` or `bg-radial-[<position>]` to create radial gradients
- Use the new `bg-conic` or `bg-conic-*` to create conic gradients
```html
<!-- ✅ Use the new gradient utilities -->
<div class="h-14 bg-linear-to-br from-violet-500 to-fuchsia-500"></div>
<div
class="size-18 bg-radial-[at_50%_75%] from-sky-200 via-blue-400 to-indigo-900 to-90%"
></div>
<div
class="size-24 bg-conic-180 from-indigo-600 via-indigo-50 to-indigo-600"
></div>
<!-- ❌ Do not use bg-gradient-* utilities -->
<div class="h-14 bg-gradient-to-br from-violet-500 to-fuchsia-500"></div>
```
## Working with CSS Variables
### Accessing Theme Values
Tailwind CSS v4 exposes all theme values as CSS variables:
```css
/* Access colors, and other theme values */
.custom-element {
background: var(--color-red-500);
border-radius: var(--radius-lg);
}
```
### The `--spacing()` Function
Use the dedicated `--spacing()` function for spacing calculations:
```css
.custom-class {
margin-top: calc(100vh - --spacing(16));
}
```
### Extending theme values
Use CSS to extend theme values:
```css
@import "tailwindcss";
@theme {
--color-mint-500: oklch(0.72 0.11 178);
}
```
```html
<div class="bg-mint-500">
<!-- ... -->
</div>
```
## New v4 Features
### Container Queries
Use the `@container` class and size variants:
```html
<article class="@container">
<div class="flex flex-col @md:flex-row @lg:gap-8">
<img class="w-full @md:w-48" />
<div class="mt-4 @md:mt-0">
<!-- Content adapts to container size -->
</div>
</div>
</article>
```
### Container Query Units
Use container-based units like `cqw` for responsive sizing:
```html
<div class="@container">
<h1 class="text-[50cqw]">Responsive to container width</h1>
</div>
```
### Text Shadows (v4.1)
Use `text-shadow-*` utilities, from `text-shadow-2xs` to `text-shadow-lg`:
```html
<!-- ✅ Text shadow examples -->
<h1 class="text-shadow-lg">Large shadow</h1>
<p class="text-shadow-sm/50">Small shadow with opacity</p>
```
### Masking (v4.1)
Use the new composable mask utilities for image and gradient masks:
```html
<!-- ✅ Linear gradient masks on specific sides -->
<div class="mask-t-from-50%">Top fade</div>
<div class="mask-b-from-20% mask-b-to-80%">Bottom gradient</div>
<div class="mask-linear-from-white mask-linear-to-black/60">
Fade from white to black
</div>
<!-- ✅ Radial gradient masks -->
<div class="mask-radial-[100%_100%] mask-radial-from-75% mask-radial-at-left">
Radial mask
</div>
```
## Component Patterns
### Avoiding Utility Inheritance
Don't add utilities to parents that you override in children:
```html
<!-- ❌ Avoid this pattern -->
<div class="text-center">
<h1>Centered Heading</h1>
<div class="text-left">Left-aligned content</div>
</div>
<!-- ✅ Better approach -->
<div>
<h1 class="text-center">Centered Heading</h1>
<div>Left-aligned content</div>
</div>
```
### Component Extraction
- Extract repeated patterns into framework components, not CSS classes
- Keep utility classes in templates/JSX
- Use data attributes for complex state-based styling
## CSS Best Practices
### Nesting Guidelines
- Use nesting when styling both parent and children
- Avoid empty parent selectors
```css
/* ✅ Good nesting - parent has styles */
.card {
padding: --spacing(4);
> .card-title {
font-weight: bold;
}
}
/* ❌ Avoid empty parents */
ul {
> li {
/* Parent has no styles */
}
}
```
## Common Pitfalls to Avoid
1. **Using old opacity utilities** - Always use `/opacity` syntax like `bg-red-500/60`
2. **Redundant breakpoint classes** - Only specify changes
3. **Space utilities in flex/grid** - Always use gap
4. **Leading utilities** - Use line-height modifiers like `text-sm/6`
5. **Arbitrary values** - Always use Tailwind's predefined scale whenever possible (e.g., use `ml-4` over `ml-[16px]`)
6. **@apply directive** - Use components or CSS variables
7. **min-h-screen on mobile** - Use min-h-dvh
8. **Separate width/height** - Use size utilities when equal

View File

@ -9,6 +9,6 @@ jobs:
pull-requests: write
runs-on: depot-ubuntu-24.04
steps:
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
- uses: actions/labeler@f27b608878404679385c85cfa523b85ccb86e213 # v6.1.0
with:
sync-labels: true

View File

@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
uses: anthropics/claude-code-action@476e359e6203e73dad705c8b322e333fabbd7416 # v1.0.119
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -3,6 +3,10 @@ DOCKER_REGISTRY=langgenius
WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest
DOCKER_DIR=docker
DOCKER_MIDDLEWARE_ENV=$(DOCKER_DIR)/middleware.env
DOCKER_MIDDLEWARE_ENV_EXAMPLE=$(DOCKER_DIR)/envs/middleware.env.example
DOCKER_MIDDLEWARE_PROJECT=dify-middlewares-dev
# Default target - show help
.DEFAULT_GOAL := help
@ -17,8 +21,13 @@ dev-setup: prepare-docker prepare-web prepare-api
# Step 1: Prepare Docker middleware
prepare-docker:
@echo "🐳 Setting up Docker middleware..."
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
@if [ ! -f "$(DOCKER_MIDDLEWARE_ENV)" ]; then \
cp "$(DOCKER_MIDDLEWARE_ENV_EXAMPLE)" "$(DOCKER_MIDDLEWARE_ENV)"; \
echo "Docker middleware.env created"; \
else \
echo "Docker middleware.env already exists"; \
fi
@cd $(DOCKER_DIR) && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p $(DOCKER_MIDDLEWARE_PROJECT) up -d
@echo "✅ Docker middleware started"
# Step 2: Prepare web environment
@ -39,12 +48,18 @@ prepare-api:
# Clean dev environment
dev-clean:
@echo "⚠️ Stopping Docker containers..."
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
@if [ -f "$(DOCKER_MIDDLEWARE_ENV)" ]; then \
cd $(DOCKER_DIR) && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p $(DOCKER_MIDDLEWARE_PROJECT) down; \
else \
echo "Docker middleware.env does not exist, skipping compose down"; \
fi
@echo "🗑️ Removing volumes..."
@rm -rf docker/volumes/db
@rm -rf docker/volumes/mysql
@rm -rf docker/volumes/redis
@rm -rf docker/volumes/plugin_daemon
@rm -rf docker/volumes/weaviate
@rm -rf docker/volumes/sandbox/dependencies
@rm -rf api/storage
@echo "✅ Cleanup complete"
@ -132,7 +147,7 @@ help:
@echo " make prepare-docker - Set up Docker middleware"
@echo " make prepare-web - Set up web environment"
@echo " make prepare-api - Set up API environment"
@echo " make dev-clean - Stop Docker middleware containers"
@echo " make dev-clean - Stop Docker middleware containers and remove dev data"
@echo ""
@echo "Backend Code Quality:"
@echo " make format - Format code with ruff"

View File

@ -34,7 +34,7 @@ TRIGGER_URL=http://localhost:5001
FILES_ACCESS_TIMEOUT=300
# Collaboration mode toggle
ENABLE_COLLABORATION_MODE=false
ENABLE_COLLABORATION_MODE=true
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
@ -88,6 +88,10 @@ REDIS_HEALTH_CHECK_INTERVAL=30
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
# Ops trace retry configuration
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5
# Database configuration
DB_TYPE=postgresql
DB_USERNAME=postgres

View File

@ -181,7 +181,6 @@ def initialize_extensions(app: DifyApp):
ext_import_modules,
ext_orjson,
ext_forward_refs,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,
ext_database,
@ -189,6 +188,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_set_secretkey,
ext_logstore, # Initialize logstore after storage, before celery
ext_celery,
ext_login,

View File

@ -23,9 +23,9 @@ class SecurityConfig(BaseSettings):
"""
SECRET_KEY: str = Field(
description="Secret key for secure session cookie signing."
"Make sure you are changing this key for your deployment with a strong key."
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
description="Secret key for secure session cookie signing. "
"Leave empty to let Dify generate a persistent key in the storage directory, "
"or set a strong value via the `SECRET_KEY` environment variable.",
default="",
)
@ -1137,6 +1137,18 @@ class MultiModalTransferConfig(BaseSettings):
)
class OpsTraceConfig(BaseSettings):
OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES: PositiveInt = Field(
description="Maximum retry attempts for transient ops trace provider dispatch failures.",
default=60,
)
OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS: PositiveInt = Field(
description="Delay in seconds between transient ops trace provider dispatch retry attempts.",
default=5,
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
description="Interval in days for Celery Beat scheduler execution, default to 1 day",
@ -1298,7 +1310,7 @@ class PositionConfig(BaseSettings):
class CollaborationConfig(BaseSettings):
ENABLE_COLLABORATION_MODE: bool = Field(
description="Whether to enable collaboration mode features across the workspace",
default=False,
default=True,
)
@ -1417,6 +1429,7 @@ class FeatureConfig(
ModelLoadBalanceConfig,
ModerationConfig,
MultiModalTransferConfig,
OpsTraceConfig,
PositionConfig,
RagEtlConfig,
RepositoryConfig,

api/configs/secret_key.py Normal file
View File

@ -0,0 +1,38 @@
"""SECRET_KEY persistence helpers for runtime setup."""
from __future__ import annotations
import secrets
from extensions.ext_storage import storage
GENERATED_SECRET_KEY_FILENAME = ".dify_secret_key"
def resolve_secret_key(secret_key: str) -> str:
"""Return an explicit SECRET_KEY or a generated key persisted in storage."""
if secret_key:
return secret_key
return _load_or_create_secret_key()
def _load_or_create_secret_key() -> str:
try:
persisted_key = storage.load_once(GENERATED_SECRET_KEY_FILENAME).decode("utf-8").strip()
if persisted_key:
return persisted_key
except FileNotFoundError:
pass
generated_key = secrets.token_urlsafe(48)
try:
storage.save(GENERATED_SECRET_KEY_FILENAME, f"{generated_key}\n".encode())
except Exception as exc:
raise ValueError(
f"SECRET_KEY is not set and could not be generated at {GENERATED_SECRET_KEY_FILENAME}. "
"Set SECRET_KEY explicitly or make storage writable."
) from exc
return generated_key

View File

@ -39,7 +39,7 @@ from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from services.app_dsl_service import AppDslService
from services.app_service import AppService
from services.app_service import AppListParams, AppService, CreateAppParams
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import ImportMode, ImportStatus
from services.entities.knowledge_entities.knowledge_entities import (
@ -478,11 +478,18 @@ class AppListApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
args_dict = args.model_dump()
params = AppListParams(
page=args.page,
limit=args.limit,
mode=args.mode,
name=args.name,
tag_ids=args.tag_ids,
is_created_by_me=args.is_created_by_me,
)
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
if not app_pagination:
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
return empty.model_dump(mode="json"), 200
@ -546,9 +553,17 @@ class AppListApi(Resource):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
params = CreateAppParams(
name=args.name,
description=args.description,
mode=args.mode,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
)
app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
app = app_service.create_app(current_tenant_id, params, current_user)
app_detail = AppDetail.model_validate(app, from_attributes=True)
return app_detail.model_dump(mode="json"), 201

View File

@ -606,63 +606,63 @@ class DatasetIndexingEstimateApi(Resource):
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars(
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
).all()
match args["info_list"]["data_source_type"]:
case "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars(
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
).all()
if file_details is None:
raise NotFound("File not found.")
if file_details is None:
raise NotFound("File not found.")
if file_details:
for file_detail in file_details:
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
case "notion_import":
notion_info_list = args["info_list"]["notion_info_list"]
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
case "website_crawl":
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import":
notion_info_list = args["info_list"]["notion_info_list"]
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "website_crawl":
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
else:
raise ValueError("Data source type not support")
case _:
raise ValueError("Data source type not support")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(

View File

@ -369,28 +369,31 @@ class DatasetDocumentListApi(Resource):
else:
sort_logic = asc
if sort == "hit_count":
sub_query = (
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.where(DocumentSegment.dataset_id == str(dataset_id))
.group_by(DocumentSegment.document_id)
.subquery()
)
match sort:
case "hit_count":
sub_query = (
sa.select(
DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")
)
.where(DocumentSegment.dataset_id == str(dataset_id))
.group_by(DocumentSegment.document_id)
.subquery()
)
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
elif sort == "created_at":
query = query.order_by(
sort_logic(Document.created_at),
sort_logic(Document.position),
)
else:
query = query.order_by(
desc(Document.created_at),
desc(Document.position),
)
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
case "created_at":
query = query.order_by(
sort_logic(Document.created_at),
sort_logic(Document.position),
)
case _:
query = query.order_by(
desc(Document.created_at),
desc(Document.position),
)
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
View File
@ -106,7 +106,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
simple_account_model = get_or_create_model("TrialSimpleAccount", simple_account_fields)
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
View File
@ -136,7 +136,7 @@ def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[
if not dataset:
raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]:
if not dataset.indexing_technique and not args.get("indexing_technique"):
raise ValueError("indexing_technique is required.")
embedding_model_provider = payload.embedding_model_provider
View File
@ -32,7 +32,7 @@ from core.app.entities.task_entities import (
)
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@ -166,6 +166,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
extras = {
**extract_external_trace_id_from_args(args),
**extract_parent_trace_context_from_args(args),
}
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
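Both extractors return an empty mapping when their key is absent, so the dict-unpacking merge never has to special-case missing values. A minimal standalone sketch of that pattern (the two helpers below are simplified stand-ins, not the real implementations):

```python
from collections.abc import Mapping
from typing import Any


# Hypothetical stand-ins for the real extractors: each returns {} when its
# input is missing, so unpacking never injects a None-valued key.
def extract_a(args: Mapping[str, Any]) -> dict[str, Any]:
    value = args.get("external_trace_id")
    return {"external_trace_id": value} if value else {}


def extract_b(args: Mapping[str, Any]) -> dict[str, Any]:
    value = args.get("parent_trace_context")
    return {"parent_trace_context": value} if value else {}


args = {"external_trace_id": "trace-123"}
extras = {**extract_a(args), **extract_b(args)}
assert extras == {"external_trace_id": "trace-123"}  # absent keys never appear
```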
View File
@ -128,7 +128,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
@staticmethod
def _secret_key() -> bytes:
return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
return dify_config.SECRET_KEY.encode()
def _sign_query(self, *, payload: str) -> dict[str, str]:
timestamp = str(int(time.time()))
View File
@ -15,6 +15,7 @@ from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.helper.trace_id_helper import ParentTraceContext
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@ -403,8 +404,13 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
external_trace_id = None
parent_trace_context = None
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
extras = self._application_generate_entity.extras
external_trace_id = extras.get("external_trace_id")
parent_trace_context = extras.get("parent_trace_context")
if isinstance(parent_trace_context, ParentTraceContext):
parent_trace_context = parent_trace_context.model_dump(exclude_none=True)
trace_task = TraceTask(
TraceTaskName.WORKFLOW_TRACE,
@ -412,6 +418,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
conversation_id=conversation_id,
user_id=self._trace_manager.user_id,
external_trace_id=external_trace_id,
parent_trace_context=parent_trace_context,
)
self._trace_manager.add_trace_task(trace_task)
View File
@ -35,8 +35,11 @@ class DatasourceFileManager:
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
sign = hmac.new(
dify_config.SECRET_KEY.encode(),
data_to_sign.encode(),
hashlib.sha256,
).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@ -47,8 +50,11 @@ class DatasourceFileManager:
verify signature
"""
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_sign = hmac.new(
dify_config.SECRET_KEY.encode(),
data_to_sign.encode(),
hashlib.sha256,
).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
View File
@ -3,6 +3,17 @@ import re
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, ConfigDict, StrictStr, ValidationError
class ParentTraceContext(BaseModel):
"""Typed parent trace context propagated from an outer workflow tool node."""
parent_workflow_run_id: StrictStr
parent_node_execution_id: StrictStr | None = None
model_config = ConfigDict(extra="forbid")
def is_valid_trace_id(trace_id: str) -> bool:
"""
@ -61,6 +72,30 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]):
return {}
def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, ParentTraceContext]:
"""
Extract 'parent_trace_context' from args.
Returns a dict suitable for use in extras when both parent identifiers exist.
Returns an empty dict if the context is missing or incomplete.
"""
parent_trace_context = args.get("parent_trace_context")
if isinstance(parent_trace_context, ParentTraceContext):
context = parent_trace_context
elif isinstance(parent_trace_context, Mapping):
try:
context = ParentTraceContext.model_validate(parent_trace_context)
except ValidationError:
return {}
else:
return {}
if context.parent_node_execution_id is None:
return {}
return {"parent_trace_context": context}
def get_trace_id_from_otel_context() -> str | None:
"""
Retrieve the current trace ID from the active OpenTelemetry trace context.
View File
@ -324,9 +324,10 @@ class IndexingRunner:
# one extract_setting is one source document
for extract_setting in extract_settings:
# extract
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)
processing_rule = {
"mode": tmp_processing_rule["mode"],
"rules": tmp_processing_rule.get("rules"),
}
# Extract document content
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
# Cleaning and segmentation
@ -334,7 +335,7 @@ class IndexingRunner:
text_docs,
current_user=None,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
process_rule=processing_rule,
tenant_id=tenant_id,
doc_language=doc_language,
preview=True,
View File
@ -5,6 +5,8 @@ from typing import Any, Union
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
from core.helper.trace_id_helper import ParentTraceContext
class BaseTraceInfo(BaseModel):
message_id: str | None = None
@ -51,8 +53,8 @@ class BaseTraceInfo(BaseModel):
def resolved_parent_context(self) -> tuple[str | None, str | None]:
"""Resolve cross-workflow parent linking from metadata.
Extracts typed parent IDs from the untyped ``parent_trace_context``
metadata dict (set by tool_node when invoking nested workflows).
Extracts typed parent IDs from the ``parent_trace_context`` metadata
payload (set by tool_node when invoking nested workflows).
Returns:
(trace_correlation_override, parent_span_id_source) where
@ -60,13 +62,18 @@ class BaseTraceInfo(BaseModel):
parent_span_id_source is the outer node_execution_id.
"""
parent_ctx = self.metadata.get("parent_trace_context")
if not isinstance(parent_ctx, dict):
if isinstance(parent_ctx, ParentTraceContext):
context = parent_ctx
elif isinstance(parent_ctx, Mapping):
try:
context = ParentTraceContext.model_validate(parent_ctx)
except ValueError:
return None, None
else:
return None, None
trace_override = parent_ctx.get("parent_workflow_run_id")
parent_span = parent_ctx.get("parent_node_execution_id")
return (
trace_override if isinstance(trace_override, str) else None,
parent_span if isinstance(parent_span, str) else None,
context.parent_workflow_run_id,
context.parent_node_execution_id,
)
@field_serializer("start_time", "end_time")
View File
@ -0,0 +1,22 @@
"""Core exceptions shared by ops trace dispatchers and trace providers.
Provider packages may raise these types to request generic task behavior, but
generic Celery tasks should not import provider-specific exception classes.
"""
class RetryableTraceDispatchError(RuntimeError):
"""Base class for transient trace dispatch failures that Celery may retry."""
class PendingTraceParentContextError(RetryableTraceDispatchError):
"""Raised when a nested trace arrives before its parent span context is available."""
parent_node_execution_id: str
def __init__(self, parent_node_execution_id: str) -> None:
self.parent_node_execution_id = parent_node_execution_id
super().__init__(
"Pending trace parent context for parent_node_execution_id="
f"{parent_node_execution_id}. Retry after the parent span context is published."
)
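Generic dispatch code only needs to catch the shared base class to retry transient failures, without importing provider-specific exceptions. A minimal sketch of that behavior (the retry loop below is illustrative; the real code delegates retries to Celery):

```python
import time
from collections.abc import Callable


class RetryableTraceDispatchError(RuntimeError):
    """Mirror of the shared base class, redefined here so the sketch runs standalone."""


def dispatch_with_retry(dispatch: Callable[[], None], *, max_attempts: int = 3, backoff_seconds: float = 0.1) -> None:
    # Generic task logic only needs the shared base class; it never has to
    # know which provider raised the transient failure.
    for attempt in range(1, max_attempts + 1):
        try:
            dispatch()
            return
        except RetryableTraceDispatchError:
            if attempt == max_attempts:
                raise
            time.sleep(backoff_seconds * attempt)


calls = {"n": 0}

def flaky_dispatch() -> None:
    calls["n"] += 1
    if calls["n"] < 2:  # parent span context not published yet on first try
        raise RetryableTraceDispatchError("pending parent context")

dispatch_with_retry(flaky_dispatch)
assert calls["n"] == 2
```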
View File
@ -16,6 +16,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.helper.trace_id_helper import ParentTraceContext
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
BaseTracingConfig,
@ -52,6 +53,17 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _dump_parent_trace_context(parent_trace_context: Any) -> dict[str, str] | None:
if isinstance(parent_trace_context, ParentTraceContext):
return parent_trace_context.model_dump(exclude_none=True)
if isinstance(parent_trace_context, dict):
try:
return ParentTraceContext.model_validate(parent_trace_context).model_dump(exclude_none=True)
except ValueError:
return None
return None
class _AppTracingConfig(TypedDict, total=False):
enabled: bool
tracing_provider: str | None
@ -857,8 +869,9 @@ class TraceTask:
}
parent_trace_context = self.kwargs.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context)
if dumped_parent_trace_context:
metadata["parent_trace_context"] = dumped_parent_trace_context
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
@ -1371,13 +1384,14 @@ class TraceTask:
}
parent_trace_context = node_data.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context)
if dumped_parent_trace_context:
metadata["parent_trace_context"] = dumped_parent_trace_context
message_id: str | None = None
conversation_id = node_data.get("conversation_id")
workflow_execution_id = node_data.get("workflow_execution_id")
if conversation_id and workflow_execution_id and not parent_trace_context:
if conversation_id and workflow_execution_id and not dumped_parent_trace_context:
with Session(db.engine) as session:
msg_id = session.scalar(
select(Message.id).where(
View File
@ -123,12 +123,15 @@ class SimplePromptTransform(PromptTransform):
for v in special_variable_keys:
# support #context#, #query# and #histories#
if v == "#context#":
variables["#context#"] = context or ""
elif v == "#query#":
variables["#query#"] = query or ""
elif v == "#histories#":
variables["#histories#"] = histories or ""
match v:
case "#context#":
variables["#context#"] = context or ""
case "#query#":
variables["#query#"] = query or ""
case "#histories#":
variables["#histories#"] = histories or ""
case _:
pass
prompt_template = prompt_template_config["prompt_template"]
if not isinstance(prompt_template, PromptTemplateParser):
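The explicit `case _: pass` arm documents that unknown keys are deliberately ignored. A compact illustration of the same dispatch shape (simplified signature, for illustration only):

```python
def fill_special_variable(key: str, context: str | None, query: str | None) -> str | None:
    match key:
        case "#context#":
            return context or ""
        case "#query#":
            return query or ""
        case _:
            return None  # unknown keys are deliberately ignored


assert fill_special_variable("#context#", None, "hi") == ""
assert fill_special_variable("#unknown#", "ctx", "hi") is None
```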
View File
@ -245,6 +245,7 @@ class Jieba(BaseKeyword):
segment = pre_segment_data["segment"]
if pre_segment_data["keywords"]:
segment.keywords = pre_segment_data["keywords"]
assert segment.index_node_id
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
@ -253,6 +254,7 @@ class Jieba(BaseKeyword):
keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
segment.keywords = list(keywords)
assert segment.index_node_id
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, list(keywords)
)
View File
@ -1,5 +1,6 @@
import concurrent.futures
import logging
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, NotRequired, TypedDict
@ -526,7 +527,7 @@ class RetrievalService:
index_node_ids = [i for i in index_node_ids if i]
segment_ids: list[str] = []
index_node_segments: list[DocumentSegment] = []
index_node_segments: Sequence[DocumentSegment] = []
segments: list[DocumentSegment] = []
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
@ -568,8 +569,9 @@ class RetrievalService:
DocumentSegment.status == "completed",
DocumentSegment.index_node_id.in_(index_node_ids),
)
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
index_node_segments = session.execute(document_segment_stmt).scalars().all()
for index_node_segment in index_node_segments:
assert index_node_segment.index_node_id
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:
View File
@ -50,6 +50,7 @@ class DatasetDocumentStore:
output = {}
for document_segment in document_segments:
assert document_segment.index_node_id
doc_id = document_segment.index_node_id
output[doc_id] = Document(
page_content=document_segment.content,
@ -103,7 +104,7 @@ class DatasetDocumentStore:
if not segment_document:
max_position += 1
assert self._document_id
segment_document = DocumentSegment(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
View File
@ -84,7 +84,7 @@ class IndexProcessor:
select(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
indexing_start_at = time.perf_counter()
# delete from vector index
View File
@ -29,6 +29,7 @@ from libs import helper
from models import Account
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import ProcessRuleMode
from services.account_service import AccountService
from services.summary_index_service import SummaryIndexService
@ -325,7 +326,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
# update document parent mode
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode="hierarchical",
mode=ProcessRuleMode.HIERARCHICAL,
rules=json.dumps(
{
"parent_mode": parent_childs.parent_mode,
View File
@ -8,6 +8,10 @@ import urllib.parse
from configs import dify_config
def _secret_key() -> bytes:
return dify_config.SECRET_KEY.encode()
def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
"""
sign file to get a temporary url for plugin access
@ -19,8 +23,7 @@ def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True)
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@ -39,8 +42,7 @@ def sign_upload_file_preview_url(upload_file_id: str, extension: str) -> str:
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@ -51,8 +53,7 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
@ -71,8 +72,7 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str,
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
query = urllib.parse.urlencode(
{
@ -92,8 +92,7 @@ def verify_plugin_file_signature(
"""Verify the signature used by the plugin-facing file upload endpoint."""
data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
if sign != recalculated_encoded_sign:
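All of these helpers share one scheme: an HMAC-SHA256 over a pipe-delimited payload of file id, timestamp, and nonce, recomputed and compared on verification. A self-contained sketch of the round trip (a hard-coded demo key stands in for `SECRET_KEY`):

```python
import base64
import hashlib
import hmac
import os
import time

SECRET_KEY = b"demo-secret"  # stand-in for dify_config.SECRET_KEY


def sign(file_id: str) -> tuple[str, str, str]:
    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
    digest = hmac.new(SECRET_KEY, data_to_sign.encode(), hashlib.sha256).digest()
    return timestamp, nonce, base64.urlsafe_b64encode(digest).decode()


def verify(file_id: str, timestamp: str, nonce: str, signature: str) -> bool:
    data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
    recalculated = hmac.new(SECRET_KEY, data_to_sign.encode(), hashlib.sha256).digest()
    expected = base64.urlsafe_b64encode(recalculated).decode()
    # compare_digest avoids leaking timing information during comparison
    return hmac.compare_digest(expected, signature)


ts, nonce, sig = sign("file-123")
assert verify("file-123", ts, nonce, sig)
assert not verify("file-456", ts, nonce, sig)
```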
View File
@ -51,8 +51,11 @@ class ToolFileManager:
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
sign = hmac.new(
dify_config.SECRET_KEY.encode(),
data_to_sign.encode(),
hashlib.sha256,
).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@ -63,8 +66,11 @@ class ToolFileManager:
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_sign = hmac.new(
dify_config.SECRET_KEY.encode(),
data_to_sign.encode(),
hashlib.sha256,
).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
View File
@ -9,6 +9,7 @@ from sqlalchemy import select
from core.app.file_access import DatabaseFileAccessController
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import ParentTraceContext, extract_parent_trace_context_from_args
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
@ -36,6 +37,8 @@ class WorkflowTool(Tool):
Workflow tool.
"""
_parent_trace_context: ParentTraceContext | None
def __init__(
self,
workflow_app_id: str,
@ -54,6 +57,7 @@ class WorkflowTool(Tool):
self.workflow_call_depth = workflow_call_depth
self.label = label
self._latest_usage = LLMUsage.empty_usage()
self._parent_trace_context = None
super().__init__(entity=entity, runtime=runtime)
@ -94,11 +98,17 @@ class WorkflowTool(Tool):
self._latest_usage = LLMUsage.empty_usage()
generator_args: dict[str, Any] = {"inputs": tool_parameters, "files": files}
if self._parent_trace_context:
generator_args.update(
extract_parent_trace_context_from_args({"parent_trace_context": self._parent_trace_context})
)
result = generator.generate(
app_model=app,
workflow=workflow,
user=user,
args={"inputs": tool_parameters, "files": files},
args=generator_args,
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,
@ -194,7 +204,7 @@ class WorkflowTool(Tool):
:return: the new tool
"""
return self.__class__(
forked = self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
workflow_app_id=self.workflow_app_id,
@ -204,6 +214,24 @@ class WorkflowTool(Tool):
version=self.version,
label=self.label,
)
forked._parent_trace_context = self._parent_trace_context.model_copy() if self._parent_trace_context else None
return forked
def set_parent_trace_context(
self,
*,
parent_workflow_run_id: str,
parent_node_execution_id: str,
) -> None:
"""Attach outer workflow trace context without exposing it as tool input."""
self._parent_trace_context = ParentTraceContext(
parent_workflow_run_id=parent_workflow_run_id,
parent_node_execution_id=parent_node_execution_id,
)
def clear_parent_trace_context(self) -> None:
"""Remove parent trace context before invoking this tool outside a nested workflow."""
self._parent_trace_context = None
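The fork copies the context with `model_copy()` so mutations on the forked tool cannot leak back into the original. A quick pydantic illustration of that isolation (simplified model, for demonstration only):

```python
from pydantic import BaseModel


class Ctx(BaseModel):
    parent_workflow_run_id: str
    parent_node_execution_id: str | None = None


original = Ctx(parent_workflow_run_id="run-1")
fork = original.model_copy()
fork.parent_node_execution_id = "node-9"
assert original.parent_node_execution_id is None  # original is unaffected
```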
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
"""
View File
@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.file_access import DatabaseFileAccessController
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.helper.trace_id_helper import ParentTraceContext
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
@ -358,6 +359,7 @@ class _WorkflowToolRuntimeBinding:
tool: Tool
conversation_id: str | None = None
parent_trace_context: ParentTraceContext | None = None
class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
@ -378,6 +380,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
node_id: str,
node_data: ToolNodeData,
variable_pool,
node_execution_id: str | None = None,
) -> ToolRuntimeHandle:
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(
@ -397,7 +400,25 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
conversation_id = (
None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID)
)
return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id))
parent_trace_context: ParentTraceContext | None = None
if self._is_workflow_tool_provider(node_data):
outer_workflow_run_id = (
None
if variable_pool is None
else get_system_text(variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID)
)
if isinstance(outer_workflow_run_id, str) and isinstance(node_execution_id, str):
parent_trace_context = ParentTraceContext(
parent_workflow_run_id=outer_workflow_run_id,
parent_node_execution_id=node_execution_id,
)
return ToolRuntimeHandle(
raw=_WorkflowToolRuntimeBinding(
tool=tool_runtime,
conversation_id=conversation_id,
parent_trace_context=parent_trace_context,
)
)
def get_runtime_parameters(
self,
@ -421,6 +442,13 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
runtime_binding = self._binding_from_handle(tool_runtime)
tool = runtime_binding.tool
callback = DifyWorkflowCallbackHandler()
if runtime_binding.parent_trace_context and hasattr(tool, "set_parent_trace_context"):
tool.set_parent_trace_context(
parent_workflow_run_id=runtime_binding.parent_trace_context.parent_workflow_run_id,
parent_node_execution_id=runtime_binding.parent_trace_context.parent_node_execution_id,
)
elif hasattr(tool, "clear_parent_trace_context"):
tool.clear_parent_trace_context()
try:
messages = ToolEngine.generic_invoke(
@ -513,6 +541,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
credential_id=node_data.credential_id,
)
@staticmethod
def _is_workflow_tool_provider(node_data: ToolNodeData) -> bool:
return node_data.provider_type.value == CoreToolProviderType.WORKFLOW.value
def _adapt_messages(
self,
messages: Generator[CoreToolInvokeMessage, None, None],
View File
@ -5,6 +5,7 @@ import threading
from flask import Response
from configs import dify_config
from controllers.console.admin import admin_required
from dify_app import DifyApp
@ -25,6 +26,7 @@ def init_app(app: DifyApp):
)
@app.route("/threads")
@admin_required
def threads(): # pyright: ignore[reportUnusedFunction]
num_threads = threading.active_count()
threads = threading.enumerate()
@ -50,6 +52,7 @@ def init_app(app: DifyApp):
}
@app.route("/db-pool-stat")
@admin_required
def pool_stat(): # pyright: ignore[reportUnusedFunction]
from extensions.ext_database import db
View File
@ -1,6 +1,13 @@
from configs import dify_config
from configs.secret_key import resolve_secret_key
from dify_app import DifyApp
def init_app(app: DifyApp):
app.secret_key = dify_config.SECRET_KEY
def init_app(app: DifyApp) -> None:
"""Resolve SECRET_KEY after config loading and before session/login setup."""
secret_key = dify_config.SECRET_KEY
if not secret_key:
secret_key = resolve_secret_key(secret_key)
dify_config.SECRET_KEY = secret_key
app.config["SECRET_KEY"] = secret_key
app.secret_key = secret_key
View File
@ -1,19 +1,22 @@
"""Workflow comment models."""
from __future__ import annotations
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import Index, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from models.base import TypeBase
from .account import Account
from .base import Base, gen_uuidv7_string
from .base import gen_uuidv7_string
from .engine import db
from .types import StringUUID
class WorkflowComment(Base):
class WorkflowComment(TypeBase):
"""Workflow comment model for canvas commenting functionality.
Comments are associated with apps rather than specific workflow versions,
@ -42,27 +45,33 @@ class WorkflowComment(Base):
Index("workflow_comments_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position_x: Mapped[float] = mapped_column(sa.Float)
position_y: Mapped[float] = mapped_column(sa.Float)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime, default=None)
resolved_by: Mapped[str | None] = mapped_column(StringUUID, default=None)
resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
# Relationships
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
replies: Mapped[list[WorkflowCommentReply]] = relationship(
lambda: WorkflowCommentReply, back_populates="comment", cascade="all, delete-orphan", init=False
)
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
mentions: Mapped[list[WorkflowCommentMention]] = relationship(
lambda: WorkflowCommentMention, back_populates="comment", cascade="all, delete-orphan", init=False
)
@property
@ -131,7 +140,7 @@ class WorkflowComment(Base):
return participants
class WorkflowCommentReply(Base):
class WorkflowCommentReply(TypeBase):
"""Workflow comment reply model.
Attributes:
@ -149,18 +158,24 @@ class WorkflowCommentReply(Base):
Index("comment_replies_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
comment_id: Mapped[str] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
content: Mapped[str] = mapped_column(sa.Text, nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
sa.DateTime,
nullable=False,
server_default=func.current_timestamp(),
onupdate=func.current_timestamp(),
init=False,
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="replies", init=False)
@property
def created_by_account(self):
@ -174,7 +189,7 @@ class WorkflowCommentReply(Base):
self._created_by_account_cache = account
class WorkflowCommentMention(Base):
class WorkflowCommentMention(TypeBase):
"""Workflow comment mention model.
Mentions are only for internal accounts since end users
@ -194,18 +209,18 @@ class WorkflowCommentMention(Base):
Index("comment_mentions_user_idx", "mentioned_user_id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
comment_id: Mapped[str] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
)
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
reply_id: Mapped[str | None] = mapped_column(
StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True, default=None
)
# Relationships
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="mentions", init=False)
reply: Mapped[WorkflowCommentReply | None] = relationship(lambda: WorkflowCommentReply, init=False)
@property
def mentioned_user_account(self):
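The `Base` → `TypeBase` migration above follows SQLAlchemy 2.0's dataclass-mapped conventions: server- or factory-generated columns take `init=False`, and optional columns get an explicit `default=None` so the generated `__init__` stays valid. A minimal sketch under those assumptions (generic model, not one from this codebase):

```python
import uuid
from datetime import datetime

import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class TypeBaseSketch(MappedAsDataclass, DeclarativeBase):
    """Stand-in for the project's TypeBase declarative base."""


class CommentSketch(TypeBaseSketch):
    __tablename__ = "comment_sketch"

    # Generated values are excluded from __init__ via init=False.
    id: Mapped[str] = mapped_column(
        sa.String(36), primary_key=True, default_factory=lambda: str(uuid.uuid4()), init=False
    )
    content: Mapped[str] = mapped_column(sa.Text, nullable=False)
    # Optional columns need an explicit default so the dataclass __init__ stays valid.
    resolved_by: Mapped[str | None] = mapped_column(sa.String(36), default=None)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )


comment = CommentSketch(content="hello")  # id/created_at are not constructor arguments
assert comment.content == "hello"
```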
View File
@ -8,10 +8,9 @@ import os
import pickle
import re
import time
from collections.abc import Sequence
from datetime import datetime
from json import JSONDecodeError
from typing import Any, TypedDict, cast
from typing import Any, ClassVar, TypedDict, cast
from uuid import uuid4
import sqlalchemy as sa
@ -441,23 +440,27 @@ class Dataset(Base):
return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"
class DatasetProcessRule(Base): # bug
class DatasetProcessRule(TypeBase):
__tablename__ = "dataset_process_rules"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
)
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id = mapped_column(StringUUID, nullable=False)
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
rules = mapped_column(LongText, nullable=True)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
mode: Mapped[ProcessRuleMode] = mapped_column(
EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'")
)
rules: Mapped[str | None] = mapped_column(LongText, nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
AUTOMATIC_RULES: AutomaticRulesConfig = {
AUTOMATIC_RULES: ClassVar[AutomaticRulesConfig] = {
"pre_processing_rules": [
{"id": "remove_extra_spaces", "enabled": True},
{"id": "remove_urls_emails", "enabled": False},
@ -827,7 +830,7 @@ class Document(Base):
)
class DocumentSegment(Base):
class DocumentSegment(TypeBase):
__tablename__ = "document_segments"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@ -840,35 +843,40 @@ class DocumentSegment(Base):
)
# initial fields
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
tenant_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
document_id = mapped_column(StringUUID, nullable=False)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
position: Mapped[int]
content = mapped_column(LongText, nullable=False)
answer = mapped_column(LongText, nullable=True)
content: Mapped[str] = mapped_column(LongText, nullable=False)
word_count: Mapped[int]
tokens: Mapped[int]
# indexing fields
keywords = mapped_column(sa.JSON, nullable=True)
index_node_id = mapped_column(String(255), nullable=True)
index_node_hash = mapped_column(String(255), nullable=True)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
# basic fields
# indexing fields
index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
keywords: Mapped[Any] = mapped_column(sa.JSON, nullable=True, default=None)
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
status: Mapped[SegmentStatus] = mapped_column(
EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"), default=SegmentStatus.WAITING
)
created_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
error = mapped_column(LongText, nullable=True)
stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
@property
def dataset(self):
@ -895,7 +903,7 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self) -> Sequence[Any]:
def child_chunks(self):
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -910,7 +918,7 @@ class DocumentSegment(Base):
return child_chunks or []
return []
def get_child_chunks(self) -> Sequence[Any]:
def get_child_chunks(self):
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@ -941,7 +949,7 @@ class DocumentSegment(Base):
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
@ -958,7 +966,7 @@ class DocumentSegment(Base):
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
@ -977,7 +985,7 @@ class DocumentSegment(Base):
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
@ -1015,7 +1023,7 @@ class DocumentSegment(Base):
nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
secret_key = dify_config.SECRET_KEY.encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
View File
@ -13786,6 +13786,14 @@ Tag type
| unit | string | | No |
| variable | string | | No |
#### TrialSimpleAccount
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| email | string | | No |
| id | string | | No |
| name | string | | No |
#### TrialSite
| Name | Type | Description | Required |
@ -13829,7 +13837,7 @@ Tag type
| ---- | ---- | ----------- | -------- |
| conversation_variables | [ [TrialConversationVariable](#trialconversationvariable) ] | | No |
| created_at | object | | No |
| created_by | [SimpleAccount](#simpleaccount) | | No |
| created_by | [TrialSimpleAccount](#trialsimpleaccount) | | No |
| environment_variables | [ object ] | | No |
| features | object | | No |
| graph | object | | No |
@ -13840,7 +13848,7 @@ Tag type
| rag_pipeline_variables | [ [TrialPipelineVariable](#trialpipelinevariable) ] | | No |
| tool_published | boolean | | No |
| updated_at | object | | No |
| updated_by | [SimpleAccount](#simpleaccount) | | No |
| updated_by | [TrialSimpleAccount](#trialsimpleaccount) | | No |
| version | string | | No |
#### TrialWorkflowPartial
View File
@ -1,9 +1,11 @@
import json
import logging
import os
import re
import traceback
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Any, Union, cast
from typing import Any, Protocol, Union, cast
from urllib.parse import urlparse
from openinference.semconv.trace import (
@ -19,7 +21,7 @@ from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.semconv.attributes import exception_attributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace import Span, Status, StatusCode, get_current_span, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
@ -36,16 +38,106 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.exceptions import PendingTraceParentContextError
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.enums import WorkflowNodeExecutionStatus
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
# This parent-span carrier store is intentionally Phoenix-local for the current
# nested workflow tracing feature. If other trace providers need the same
# cross-task parent restoration behavior, move the storage and retry signaling
# behind a core trace coordination interface instead of duplicating it here.
_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS = 300
_TRACEPARENT_PATTERN = re.compile(
r"^(?P<version>[0-9a-f]{2})-(?P<trace_id>[0-9a-f]{32})-(?P<span_id>[0-9a-f]{16})-(?P<flags>[0-9a-f]{2})$"
)
def _phoenix_parent_span_redis_key(parent_node_execution_id: str) -> str:
"""Build the Redis key that stores a restorable Phoenix parent span carrier."""
return f"trace:phoenix:parent_span:{parent_node_execution_id}"
def _publish_parent_span_context(parent_node_execution_id: str, carrier: Mapping[str, str]) -> None:
"""Persist a tracecontext carrier so nested workflow spans can restore the tool span parent."""
redis_client.setex(
_phoenix_parent_span_redis_key(parent_node_execution_id),
_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS,
safe_json_dumps(dict(carrier)),
)
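Publish/resolve is a TTL'd key-value handshake across Celery tasks. A standalone sketch with an in-memory dict standing in for Redis (the real code raises `PendingTraceParentContextError` instead of returning `None`):

```python
import json
import time

_STORE: dict[str, tuple[float, str]] = {}  # key -> (expires_at, payload)
_TTL_SECONDS = 300


def publish_carrier(parent_node_execution_id: str, carrier: dict[str, str]) -> None:
    key = f"trace:phoenix:parent_span:{parent_node_execution_id}"
    _STORE[key] = (time.time() + _TTL_SECONDS, json.dumps(carrier))


def resolve_carrier(parent_node_execution_id: str) -> dict[str, str] | None:
    key = f"trace:phoenix:parent_span:{parent_node_execution_id}"
    entry = _STORE.get(key)
    if entry is None or entry[0] < time.time():
        return None  # caller would retry once the parent span is published
    return json.loads(entry[1])


publish_carrier("node-1", {"traceparent": "00-" + "ab" * 16 + "-" + "cd" * 8 + "-01"})
assert resolve_carrier("node-1") is not None
assert resolve_carrier("node-2") is None
```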
def _resolve_published_parent_span_context(parent_node_execution_id: str) -> dict[str, str]:
"""Load a previously published tool-span carrier for nested workflow parenting."""
raw_carrier = redis_client.get(_phoenix_parent_span_redis_key(parent_node_execution_id))
if raw_carrier is None:
raise PendingTraceParentContextError(parent_node_execution_id)
if isinstance(raw_carrier, bytes):
raw_carrier = raw_carrier.decode("utf-8")
carrier = json.loads(raw_carrier)
if not isinstance(carrier, dict):
raise ValueError(
"Phoenix parent span context must be stored as a JSON object: "
f"parent_node_execution_id={parent_node_execution_id}"
)
normalized_carrier = {str(key): str(value) for key, value in carrier.items()}
if not normalized_carrier:
raise ValueError(
f"Phoenix parent span context payload is empty: parent_node_execution_id={parent_node_execution_id}"
)
traceparent = normalized_carrier.get("traceparent")
if not isinstance(traceparent, str):
raise ValueError(
"Phoenix parent span context payload is missing traceparent: "
f"parent_node_execution_id={parent_node_execution_id}"
)
traceparent_match = _TRACEPARENT_PATTERN.fullmatch(traceparent)
if traceparent_match is None:
raise ValueError(
"Phoenix parent span context payload has invalid traceparent format: "
f"parent_node_execution_id={parent_node_execution_id}"
)
if traceparent_match.group("version") == "ff":
raise ValueError(
"Phoenix parent span context payload has unsupported traceparent version: "
f"parent_node_execution_id={parent_node_execution_id}"
)
if traceparent_match.group("trace_id") == "0" * 32:
raise ValueError(
"Phoenix parent span context payload has zero trace_id in traceparent: "
f"parent_node_execution_id={parent_node_execution_id}"
)
if traceparent_match.group("span_id") == "0" * 16:
raise ValueError(
"Phoenix parent span context payload has zero span_id in traceparent: "
f"parent_node_execution_id={parent_node_execution_id}"
)
extracted_context = TraceContextTextMapPropagator().extract(carrier=normalized_carrier)
extracted_span_context = get_current_span(extracted_context).get_span_context()
if not extracted_span_context.is_valid or not extracted_span_context.is_remote:
raise ValueError(
"Phoenix parent span context payload could not be restored into a valid parent span: "
f"parent_node_execution_id={parent_node_execution_id}"
)
return normalized_carrier
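These guards track the W3C tracecontext rules: a `traceparent` value is `version-trace_id-span_id-flags`, version `ff` is reserved, and all-zero trace or span IDs mean no valid context. A standalone check with the same regex shape:

```python
import re

TRACEPARENT_PATTERN = re.compile(
    r"^(?P<version>[0-9a-f]{2})-(?P<trace_id>[0-9a-f]{32})-(?P<span_id>[0-9a-f]{16})-(?P<flags>[0-9a-f]{2})$"
)


def is_restorable_traceparent(value: str) -> bool:
    match = TRACEPARENT_PATTERN.fullmatch(value)
    if match is None:
        return False
    if match.group("version") == "ff":  # reserved, never valid
        return False
    if match.group("trace_id") == "0" * 32 or match.group("span_id") == "0" * 16:
        return False  # all-zero IDs mean "no trace context"
    return True


assert is_restorable_traceparent("00-" + "a1" * 16 + "-" + "b2" * 8 + "-01")
assert not is_restorable_traceparent("ff-" + "a1" * 16 + "-" + "b2" * 8 + "-01")
assert not is_restorable_traceparent("00-" + "0" * 32 + "-" + "b2" * 8 + "-01")
```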
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
@ -177,6 +269,246 @@ def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)
def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) -> str:
"""Resolve the workflow session ID for Phoenix workflow spans."""
if trace_info.conversation_id:
return trace_info.conversation_id
parent_workflow_run_id, _ = _resolve_workflow_parent_context(trace_info)
if parent_workflow_run_id:
return parent_workflow_run_id
return trace_info.workflow_run_id
def _resolve_workflow_parent_context(trace_info: BaseTraceInfo) -> tuple[str | None, str | None]:
"""Expose the typed parent context already resolved on the trace info."""
return trace_info.resolved_parent_context
def _resolve_workflow_root_trace_id(trace_info: WorkflowTraceInfo) -> str:
"""Resolve the canonical root trace ID for Phoenix workflow spans."""
trace_correlation_override, _ = _resolve_workflow_parent_context(trace_info)
return trace_correlation_override or trace_info.resolved_trace_id or trace_info.workflow_run_id
class _NodeExecutionIdentityLike(Protocol):
@property
def node_execution_id(self) -> str | None: ...
@property
def node_id(self) -> str: ...
@property
def predecessor_node_id(self) -> str | None: ...
class _NodeExecutionLike(_NodeExecutionIdentityLike, Protocol):
@property
def id(self) -> str: ...
@property
def node_type(self) -> str: ...
@property
def title(self) -> str | None: ...
@property
def inputs(self) -> Mapping[str, Any] | None: ...
@property
def process_data(self) -> Mapping[str, Any] | None: ...
@property
def outputs(self) -> Mapping[str, Any] | None: ...
@property
def status(self) -> WorkflowNodeExecutionStatus: ...
@property
def error(self) -> str | None: ...
@property
def elapsed_time(self) -> float | None: ...
@property
def metadata(self) -> Mapping[Any, Any] | None: ...
@property
def created_at(self) -> datetime | None: ...
_PHOENIX_STRUCTURED_NODE_TYPES = frozenset({"start", "end", "loop", "iteration"})
def _resolve_workflow_span_name(trace_info: WorkflowTraceInfo) -> str:
"""Resolve the Phoenix workflow span display name."""
workflow_run_id = trace_info.workflow_run_id.strip() if trace_info.workflow_run_id else ""
if workflow_run_id:
return f"{TraceTaskName.WORKFLOW_TRACE.value}_{workflow_run_id}"
return TraceTaskName.WORKFLOW_TRACE.value
def _build_node_title_by_id(trace_info: WorkflowTraceInfo) -> dict[str, str]:
"""Build an authoritative node-title index from the persisted workflow graph."""
workflow_data = trace_info.workflow_data
workflow_graph = getattr(workflow_data, "graph_dict", None)
if not isinstance(workflow_graph, Mapping):
workflow_graph = workflow_data.get("graph") if isinstance(workflow_data, Mapping) else None
if not isinstance(workflow_graph, Mapping):
return {}
graph_nodes = workflow_graph.get("nodes")
if not isinstance(graph_nodes, Sequence):
return {}
node_title_by_id: dict[str, str] = {}
for graph_node in graph_nodes:
if not isinstance(graph_node, Mapping):
continue
node_id = graph_node.get("id")
node_data = graph_node.get("data")
if not isinstance(node_id, str) or not isinstance(node_data, Mapping):
continue
node_title = node_data.get("title")
if isinstance(node_title, str) and node_title.strip():
node_title_by_id[node_id] = node_title.strip()
return node_title_by_id
def _resolve_workflow_node_span_name(
node_execution: _NodeExecutionLike,
node_title_by_id: Mapping[str, str] | None = None,
) -> str:
"""Resolve the Phoenix workflow node span display name."""
node_type = str(node_execution.node_type or "")
graph_node_title = None
if node_title_by_id is not None and isinstance(node_execution.node_id, str):
graph_node_title = node_title_by_id.get(node_execution.node_id)
node_title = graph_node_title or (node_execution.title.strip() if isinstance(node_execution.title, str) else "")
if node_title:
return f"{node_type}_{node_title}"
return node_type
def _get_node_execution_id(node_execution: _NodeExecutionIdentityLike) -> str:
"""Return the stable execution identifier for a workflow node execution."""
return str(getattr(node_execution, "id", None) or node_execution.node_execution_id)
def _build_execution_id_by_node_id(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]:
"""Index unique workflow graph node ids by execution id.
This Phoenix-local hierarchy reconstruction intentionally drops ambiguous
node ids instead of guessing based on repository order. That keeps parent
selection deterministic until upstream tracing exposes explicit parent span
data for repeated executions.
"""
execution_id_by_node_id: dict[str, str] = {}
ambiguous_node_ids: set[str] = set()
for node_execution in node_executions:
node_id = node_execution.node_id
if not isinstance(node_id, str):
continue
execution_id = _get_node_execution_id(node_execution)
if node_id in ambiguous_node_ids:
continue
existing_execution_id = execution_id_by_node_id.get(node_id)
if existing_execution_id is None:
execution_id_by_node_id[node_id] = execution_id
continue
if existing_execution_id != execution_id:
ambiguous_node_ids.add(node_id)
execution_id_by_node_id.pop(node_id, None)
return execution_id_by_node_id
def _build_graph_parent_index(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]:
"""Build an execution-id parent index from predecessor node ids."""
execution_id_by_node_id = _build_execution_id_by_node_id(node_executions)
graph_parent_index: dict[str, str] = {}
for node_execution in node_executions:
predecessor_node_id = node_execution.predecessor_node_id
if not isinstance(predecessor_node_id, str):
continue
predecessor_execution_id = execution_id_by_node_id.get(predecessor_node_id)
if predecessor_execution_id is not None:
execution_id = _get_node_execution_id(node_execution)
graph_parent_index[execution_id] = predecessor_execution_id
return graph_parent_index
def _resolve_structured_parent_execution_id(
node_execution: object, execution_id_by_node_id: Mapping[str, str]
) -> str | None:
"""Resolve Phoenix-local structured parents from loop/iteration node ids.
Any execution carrying ``iteration_id`` or ``loop_id`` belongs to an
enclosing structured node. When predecessor node ids are ambiguous because
the graph node repeats inside that structure, Phoenix can still keep the
child span under the enclosing loop/iteration span without relying on
execution-order heuristics.
"""
execution_metadata = getattr(node_execution, "execution_metadata_dict", None)
if not isinstance(execution_metadata, Mapping):
execution_metadata = getattr(node_execution, "metadata", None)
if not isinstance(execution_metadata, Mapping):
execution_metadata = {}
for enclosing_node_id in (
getattr(node_execution, "iteration_id", None),
getattr(node_execution, "loop_id", None),
execution_metadata.get("iteration_id"),
execution_metadata.get("loop_id"),
):
if not isinstance(enclosing_node_id, str):
continue
enclosing_execution_id = execution_id_by_node_id.get(enclosing_node_id)
if enclosing_execution_id is not None:
return enclosing_execution_id
return None
def _resolve_node_parent(
execution_id: str,
predecessor_execution_id: str | None,
structured_parent_execution_id: str | None,
span_by_execution_id: Mapping[str, Span],
graph_parent_index: Mapping[str, str],
workflow_span: Span,
) -> Span:
"""Resolve the parent span for a workflow node execution."""
if predecessor_execution_id is not None:
predecessor_span = span_by_execution_id.get(predecessor_execution_id)
if predecessor_span is not None:
return predecessor_span
graph_parent_execution_id = graph_parent_index.get(execution_id)
if graph_parent_execution_id is not None:
graph_parent_span = span_by_execution_id.get(graph_parent_execution_id)
if graph_parent_span is not None:
return graph_parent_span
if structured_parent_execution_id is not None:
structured_parent_span = span_by_execution_id.get(structured_parent_execution_id)
if structured_parent_span is not None:
return structured_parent_span
return workflow_span
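The indexing rule above is deterministic by construction: a graph node id that maps to more than one execution is dropped rather than guessed at. A small sketch with plain dataclasses standing in for repository objects:

```python
from dataclasses import dataclass


@dataclass
class NodeExecutionSketch:
    id: str
    node_id: str


def build_execution_id_by_node_id(executions: list[NodeExecutionSketch]) -> dict[str, str]:
    index: dict[str, str] = {}
    ambiguous: set[str] = set()
    for execution in executions:
        if execution.node_id in ambiguous:
            continue
        existing = index.get(execution.node_id)
        if existing is None:
            index[execution.node_id] = execution.id
        elif existing != execution.id:
            # Repeated node id (e.g. a node re-run inside a loop): drop it
            # instead of guessing which execution is the "real" parent.
            ambiguous.add(execution.node_id)
            index.pop(execution.node_id, None)
    return index


executions = [
    NodeExecutionSketch(id="exec-1", node_id="start"),
    NodeExecutionSketch(id="exec-2", node_id="llm"),
    NodeExecutionSketch(id="exec-3", node_id="llm"),  # same node executed twice
]
assert build_execution_id_by_node_id(executions) == {"start": "exec-1"}
```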
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
@ -189,6 +521,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.propagator = TraceContextTextMapPropagator()
self.dify_trace_ids: set[str] = set()
self.root_span_carriers: dict[str, dict[str, str]] = {}
self.carrier: dict[str, str] = {}
def trace(self, trace_info: BaseTraceInfo):
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
@ -235,13 +569,41 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
file_list=safe_json_dumps(file_list),
query=trace_info.query or "",
)
workflow_session_id = _resolve_workflow_session_id(trace_info)
parent_workflow_run_id, parent_node_execution_id = _resolve_workflow_parent_context(trace_info)
logger.info(
"[Arize/Phoenix] Workflow session resolution: workflow_run_id=%s conversation_id=%s "
"parent_workflow_run_id=%s parent_node_execution_id=%s resolved_session_id=%s",
trace_info.workflow_run_id,
trace_info.conversation_id,
parent_workflow_run_id,
parent_node_execution_id,
workflow_session_id,
)
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
self.ensure_root_span(dify_trace_id)
root_span_context = self.propagator.extract(carrier=self.carrier)
if parent_node_execution_id:
workflow_parent_carrier = _resolve_published_parent_span_context(parent_node_execution_id)
else:
root_trace_id = _resolve_workflow_root_trace_id(trace_info)
workflow_root_span_name: str | None = trace_info.workflow_run_id
if not isinstance(workflow_root_span_name, str) or not workflow_root_span_name.strip():
workflow_root_span_name = None
workflow_parent_carrier = self.ensure_root_span(
root_trace_id,
root_span_name=workflow_root_span_name,
root_span_attributes={
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
},
)
workflow_span_context = self.propagator.extract(carrier=workflow_parent_carrier)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
name=_resolve_workflow_span_name(trace_info),
attributes={
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
@ -249,10 +611,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
SpanAttributes.SESSION_ID: workflow_session_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=root_span_context,
context=workflow_span_context,
)
# Fetch all node executions for this workflow run from the repository
@ -276,16 +638,50 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
workflow_execution_id=trace_info.workflow_run_id
)
node_title_by_id = _build_node_title_by_id(trace_info)
execution_id_by_node_id = _build_execution_id_by_node_id(workflow_node_executions)
graph_parent_index = _build_graph_parent_index(workflow_node_executions)
node_execution_by_execution_id = {
_get_node_execution_id(node_execution): node_execution for node_execution in workflow_node_executions
}
span_by_execution_id: dict[str, Span] = {}
emitting_execution_ids: set[str] = set()
workflow_span_error: Exception | str | None = trace_info.error
try:
for node_execution in workflow_node_executions:
def emit_node_span(node_execution: _NodeExecutionLike) -> Span:
execution_id = _get_node_execution_id(node_execution)
existing_span = span_by_execution_id.get(execution_id)
if existing_span is not None:
return existing_span
graph_parent_execution_id = graph_parent_index.get(execution_id)
structured_parent_execution_id = _resolve_structured_parent_execution_id(
node_execution, execution_id_by_node_id
)
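# Emit parents before children: recurse into any known graph or structured
# parent that has no span yet. `emitting_execution_ids` is a re-entrancy
# guard so a malformed parent index (e.g. a cycle) cannot recurse forever;
# the guard entry is always discarded in the `finally` below.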
if execution_id not in emitting_execution_ids:
emitting_execution_ids.add(execution_id)
try:
for parent_execution_id in (graph_parent_execution_id, structured_parent_execution_id):
if parent_execution_id is None or parent_execution_id == execution_id:
continue
if parent_execution_id in span_by_execution_id:
continue
parent_node_execution = node_execution_by_execution_id.get(parent_execution_id)
if parent_node_execution is not None:
emit_node_span(parent_node_execution)
finally:
emitting_execution_ids.discard(execution_id)
tenant_id = trace_info.tenant_id # Use from trace_info instead
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
inputs_value = node_execution.inputs or {}
outputs_value = node_execution.outputs or {}
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
elapsed_time = node_execution.elapsed_time or 0
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = node_execution.process_data or {}
@ -324,9 +720,17 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
workflow_span_context = set_span_in_context(workflow_span)
parent_span = _resolve_node_parent(
execution_id=execution_id,
predecessor_execution_id=None,
structured_parent_execution_id=structured_parent_execution_id,
span_by_execution_id=span_by_execution_id,
graph_parent_index=graph_parent_index,
workflow_span=workflow_span,
)
workflow_span_context = set_span_in_context(parent_span)
node_span = self.tracer.start_span(
name=node_execution.node_type,
name=_resolve_workflow_node_span_name(node_execution, node_title_by_id),
attributes={
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
@ -334,13 +738,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
SpanAttributes.METADATA: safe_json_dumps(node_metadata),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
SpanAttributes.SESSION_ID: workflow_session_id or "",
},
start_time=datetime_to_nanos(created_at),
context=workflow_span_context,
)
span_by_execution_id[execution_id] = node_span
node_span_error: Exception | str | None = None
try:
if node_execution.node_type == "tool":
parent_span_carrier: dict[str, str] = {}
with use_span(node_span, end_on_exit=False):
self.propagator.inject(carrier=parent_span_carrier)
_publish_parent_span_context(execution_id, parent_span_carrier)
if node_execution.node_type == "llm":
llm_attributes: dict[str, Any] = {
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
@ -362,17 +773,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
)
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
except Exception as e:
node_span_error = e
raise
finally:
if node_execution.status == WorkflowNodeExecutionStatus.FAILED:
if node_span_error is not None:
set_span_status(node_span, node_span_error)
elif node_execution.status == WorkflowNodeExecutionStatus.FAILED:
set_span_status(node_span, node_execution.error)
else:
set_span_status(node_span)
node_span.end(end_time=datetime_to_nanos(finished_at))
return node_span
for node_execution in workflow_node_executions:
emit_node_span(node_execution)
except Exception as e:
workflow_span_error = e
raise
finally:
if trace_info.error:
set_span_status(workflow_span, trace_info.error)
else:
set_span_status(workflow_span)
set_span_status(workflow_span, workflow_span_error)
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
@ -735,22 +1155,39 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def ensure_root_span(self, dify_trace_id: str | None):
def ensure_root_span(
self,
dify_trace_id: str | None,
*,
root_span_name: str | None = None,
root_span_attributes: Mapping[str, AttributeValue] | None = None,
):
"""Ensure a unique root span exists for the given Dify trace ID."""
if str(dify_trace_id) not in self.dify_trace_ids:
self.carrier: dict[str, str] = {}
trace_key = str(dify_trace_id)
if trace_key not in self.dify_trace_ids:
carrier: dict[str, str] = {}
root_span = self.tracer.start_span(name="Dify")
root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
root_span.set_attribute("dify_project_name", str(self.project))
root_span.set_attribute("dify_trace_id", str(dify_trace_id))
span_name = root_span_name.strip() if isinstance(root_span_name, str) and root_span_name.strip() else "Dify"
root_span_attributes_dict: dict[str, AttributeValue] = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
"dify_project_name": str(self.project),
"dify_trace_id": trace_key,
}
if root_span_attributes:
root_span_attributes_dict.update(root_span_attributes)
root_span = self.tracer.start_span(name=span_name, attributes=root_span_attributes_dict)
with use_span(root_span, end_on_exit=False):
self.propagator.inject(carrier=self.carrier)
self.propagator.inject(carrier=carrier)
set_span_status(root_span)
root_span.end()
self.dify_trace_ids.add(str(dify_trace_id))
self.dify_trace_ids.add(trace_key)
self.root_span_carriers[trace_key] = carrier
self.carrier = self.root_span_carriers[trace_key]
return self.carrier
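# Illustrative call pattern (trace ids here are hypothetical): the first
# call for a Dify trace id creates and ends the root span, then caches its
# propagation carrier; repeat calls return the cached carrier, so sibling
# workflow spans attach under one shared root, e.g.
#
#   carrier = self.ensure_root_span("trace-123")   # creates the root span
#   same = self.ensure_root_span("trace-123")      # cache hit, no new span
#   assert carrier is same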
def api_check(self):
try:

View File

@ -1,36 +0,0 @@
from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from openinference.semconv.trace import OpenInferenceSpanKindValues
from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes
class TestGetNodeSpanKind:
"""Tests for _get_node_span_kind helper."""
def test_all_node_types_are_mapped_correctly(self):
"""Ensure every built-in node type is mapped to the correct span kind."""
# Mappings for node types that have a specialised span kind.
special_mappings = {
BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER,
BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL,
BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT,
}
# Test that every built-in node type is mapped to the correct span kind.
# Node types not in `special_mappings` should default to CHAIN.
for node_type in BUILT_IN_NODE_TYPES:
expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN)
actual_span_kind = _get_node_span_kind(node_type)
assert actual_span_kind == expected_span_kind, (
f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected."
)
def test_unknown_string_defaults_to_chain(self):
"""An unrecognised node type string should still return CHAIN."""
assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN
def test_stale_dataset_retrieval_not_in_mapping(self):
"""The old 'dataset_retrieval' string was never a valid NodeType value;
make sure it is not present in the mapping dictionary."""
assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND

View File

@ -65,35 +65,18 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
}
file_list = values.get("file_list", [])
if isinstance(v, str):
if field_name == "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
return {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
if field_name == "inputs":
data = {
"messages": v,
match field_name:
case "inputs":
return {
"messages": {
"role": "user",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
elif field_name == "outputs":
data = {
case "outputs":
return {
"choices": {
"role": "ai",
"content": v,
@ -101,6 +84,29 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
"file_list": file_list,
},
}
case _:
pass
elif isinstance(v, list):
data = {}
if len(v) > 0 and isinstance(v[0], dict):
# rename text to content
v = replace_text_with_content(data=v)
match field_name:
case "inputs":
data = {
"messages": v,
}
case "outputs":
data = {
"choices": {
"role": "ai",
"content": v,
"usage_metadata": usage_metadata,
"file_list": file_list,
},
}
case _:
pass
return data
else:
return {

View File

@ -64,7 +64,9 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
# trace_id must equal the root run's run_id (LangSmith protocol); external trace_id
# cannot be used here as it would cause HTTP 400.
trace_id = trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
message_dotted_order = (
@ -77,6 +79,8 @@ class LangSmithDataTrace(BaseTraceInstance):
)
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.trace_id:
metadata["external_trace_id"] = trace_info.trace_id
if trace_info.message_id:
message_run = LangSmithRunModel(

View File

@ -208,13 +208,17 @@ def test_workflow_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):
assert call_args[0].id == "msg-1"
assert call_args[0].name == TraceTaskName.MESSAGE_TRACE
# trace_id must equal root run's id (message_id), not the external trace_id "trace-1"
assert call_args[0].trace_id == "msg-1"
assert call_args[1].id == "run-1"
assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE
assert call_args[1].parent_run_id == "msg-1"
assert call_args[1].trace_id == "msg-1"
assert call_args[2].id == "node-llm"
assert call_args[2].run_type == LangSmithRunType.llm
assert call_args[2].trace_id == "msg-1"
assert call_args[3].id == "node-other"
assert call_args[3].run_type == LangSmithRunType.tool
@ -604,3 +608,83 @@ def test_get_project_url_error(trace_instance):
trace_instance.langsmith_client.get_run_url.side_effect = Exception("error")
with pytest.raises(ValueError, match="LangSmith get run url failed: error"):
trace_instance.get_project_url()
def _make_workflow_trace_info(
*, message_id: str | None, workflow_run_id: str, trace_id: str | None
) -> WorkflowTraceInfo:
workflow_data = MagicMock()
workflow_data.created_at = _dt()
workflow_data.finished_at = _dt() + timedelta(seconds=1)
return WorkflowTraceInfo(
tenant_id="tenant-1",
workflow_id="wf-1",
workflow_run_id=workflow_run_id,
workflow_run_inputs={},
workflow_run_outputs={},
workflow_run_status="succeeded",
workflow_run_version="1.0",
workflow_run_elapsed_time=1.0,
total_tokens=0,
file_list=[],
query="q",
message_id=message_id,
conversation_id="conv-1" if message_id else None,
start_time=_dt(),
end_time=_dt() + timedelta(seconds=1),
trace_id=trace_id,
metadata={"app_id": "app-1"},
workflow_app_log_id=None,
error=None,
workflow_data=workflow_data,
)
def _patch_workflow_trace_deps(monkeypatch, trace_instance):
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
repo = MagicMock()
repo.get_by_workflow_execution.return_value = []
factory = MagicMock()
factory.create_workflow_node_execution_repository.return_value = repo
monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", factory)
monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
trace_instance.add_run = MagicMock()
def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch):
"""Chatflow with external trace_id: LangSmith trace_id must be message_id, not external."""
trace_info = _make_workflow_trace_info(
message_id="msg-abc",
workflow_run_id="run-xyz",
trace_id="external-999",
)
_patch_workflow_trace_deps(monkeypatch, trace_instance)
trace_instance.workflow_trace(trace_info)
calls = [c[0][0] for c in trace_instance.add_run.call_args_list]
# message run (root) and workflow run (child) must both use message_id as trace_id
assert calls[0].id == "msg-abc"
assert calls[0].trace_id == "msg-abc"
assert calls[1].id == "run-xyz"
assert calls[1].trace_id == "msg-abc"
# external_trace_id preserved in metadata
assert trace_info.metadata.get("external_trace_id") == "external-999"
def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch):
"""Pure workflow (no message_id) with external trace_id: trace_id must be workflow_run_id."""
trace_info = _make_workflow_trace_info(
message_id=None,
workflow_run_id="run-xyz",
trace_id="external-999",
)
_patch_workflow_trace_deps(monkeypatch, trace_instance)
trace_instance.workflow_trace(trace_info)
calls = [c[0][0] for c in trace_instance.add_run.call_args_list]
# workflow run is the root; trace_id must equal its run_id
assert calls[0].id == "run-xyz"
assert calls[0].trace_id == "run-xyz"

View File

@ -81,14 +81,15 @@ class OpenSearchConfig(BaseModel):
pool_maxsize=20,
)
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")
match self.auth_method:
case AuthMethod.BASIC:
logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password)
elif self.auth_method == "aws_managed_iam":
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password)
case AuthMethod.AWS_MANAGED_IAM:
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth()
params["http_auth"] = self.create_aws_managed_iam_auth()
return params

View File

@ -1,12 +1,12 @@
[project]
name = "dify-api"
version = "1.14.0"
version = "1.14.1"
requires-python = "~=3.12.0"
dependencies = [
# Legacy: mature and widely deployed
"bleach>=6.3.0",
"boto3>=1.43.3",
"boto3>=1.43.6",
"celery>=5.6.3",
"croniter>=6.2.2",
"flask>=3.1.3,<4.0.0",
@ -14,8 +14,8 @@ dependencies = [
"gevent>=26.4.0",
"gevent-websocket>=0.10.1",
"gmpy2>=2.3.0",
"google-api-python-client>=2.195.0",
"gunicorn>=25.3.0",
"google-api-python-client>=2.196.0",
"gunicorn>=26.0.0",
"psycogreen>=1.0.2",
"psycopg2-binary>=2.9.12",
"python-socketio>=5.13.0",
@ -31,7 +31,7 @@ dependencies = [
"flask-migrate>=4.1.0,<5.0.0",
"flask-orjson>=2.0.0,<3.0.0",
"flask-restx>=1.3.2,<2.0.0",
"google-cloud-aiplatform>=1.149.0,<2.0.0",
"google-cloud-aiplatform>=1.151.0,<2.0.0",
"httpx[socks]>=0.28.1,<1.0.0",
"opentelemetry-distro>=0.62b1,<1.0.0",
"opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
@ -45,7 +45,7 @@ dependencies = [
# Emerging: newer and fast-moving, use compatible pins
"fastopenapi[flask]~=0.7.0",
"graphon~=0.3.0",
"graphon~=0.3.1",
"httpx-sse~=0.4.0",
"json-repair~=0.59.4",
]
@ -191,7 +191,7 @@ storage = [
"google-cloud-storage>=3.10.1",
"opendal>=0.46.0",
"oss2>=2.19.1",
"supabase>=2.29.0",
"supabase>=2.30.0",
"tos>=2.9.0",
]

View File

@ -1,9 +1,10 @@
import json
import logging
from typing import Any, TypedDict, cast
from typing import Any, Literal, TypedDict, cast
import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
from pydantic import BaseModel, Field
from sqlalchemy import select
from configs import dify_config
@ -31,39 +32,59 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
logger = logging.getLogger(__name__)
class AppListParams(BaseModel):
page: int = Field(default=1, ge=1)
limit: int = Field(default=20, ge=1, le=100)
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = "all"
name: str | None = None
tag_ids: list[str] | None = None
is_created_by_me: bool | None = None
class CreateAppParams(BaseModel):
name: str = Field(min_length=1)
description: str | None = None
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
api_rph: int = 0
api_rpm: int = 0
max_active_requests: int | None = None
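# Illustrative usage (hypothetical values): request payloads are validated
# before any service logic runs, e.g.
#
#   AppListParams()                            # page=1, limit=20, mode="all"
#   CreateAppParams(name="demo", mode="chat")  # api_rph/api_rpm default to 0
#   CreateAppParams(name="", mode="chat")      # ValidationError: name needs min_length=1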
class AppService:
def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict[str, Any]) -> Pagination | None:
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
"""
Get app list with pagination
:param user_id: user id
:param tenant_id: tenant id
:param args: request args
:param params: query parameters
:return:
"""
filters = [App.tenant_id == tenant_id, App.is_universal == False]
if args["mode"] == "workflow":
if params.mode == "workflow":
filters.append(App.mode == AppMode.WORKFLOW)
elif args["mode"] == "completion":
elif params.mode == "completion":
filters.append(App.mode == AppMode.COMPLETION)
elif args["mode"] == "chat":
elif params.mode == "chat":
filters.append(App.mode == AppMode.CHAT)
elif args["mode"] == "advanced-chat":
elif params.mode == "advanced-chat":
filters.append(App.mode == AppMode.ADVANCED_CHAT)
elif args["mode"] == "agent-chat":
elif params.mode == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT)
if args.get("is_created_by_me", False):
if params.is_created_by_me:
filters.append(App.created_by == user_id)
if args.get("name"):
if params.name:
from libs.helper import escape_like_pattern
name = args["name"][:30]
name = params.name[:30]
escaped_name = escape_like_pattern(name)
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
if params.tag_ids and len(params.tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids)
if target_ids and len(target_ids) > 0:
filters.append(App.id.in_(target_ids))
else:
@ -71,21 +92,21 @@ class AppService:
app_models = db.paginate(
sa.select(App).where(*filters).order_by(App.created_at.desc()),
page=args["page"],
per_page=args["limit"],
page=params.page,
per_page=params.limit,
error_out=False,
)
return app_models
def create_app(self, tenant_id: str, args: dict[str, Any], account: Account) -> App:
def create_app(self, tenant_id: str, params: CreateAppParams, account: Account) -> App:
"""
Create app
:param tenant_id: tenant id
:param args: request args
:param params: app creation parameters
:param account: Account instance
"""
app_mode = AppMode.value_of(args["mode"])
app_mode = AppMode.value_of(params.mode)
app_template = default_app_templates[app_mode]
# get model config
@ -143,15 +164,16 @@ class AppService:
default_model_config["model"] = json.dumps(default_model_dict)
app = App(**app_template["app"])
app.name = args["name"]
app.description = args.get("description", "")
app.mode = args["mode"]
app.icon_type = args.get("icon_type", "emoji")
app.icon = args["icon"]
app.icon_background = args["icon_background"]
app.name = params.name
app.description = params.description or ""
app.mode = app_mode
app.icon_type = IconType(params.icon_type) if params.icon_type else IconType.EMOJI
app.icon = params.icon
app.icon_background = params.icon_background
app.tenant_id = tenant_id
app.api_rph = args.get("api_rph", 0)
app.api_rpm = args.get("api_rpm", 0)
app.api_rph = params.api_rph
app.api_rpm = params.api_rpm
app.max_active_requests = params.max_active_requests
app.created_by = account.id
app.updated_by = account.id

View File

@ -7,9 +7,10 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, TypedDict, cast
from typing import Annotated, Any, Literal, TypedDict, cast
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from redis.exceptions import LockNotOwnedError
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
@ -108,7 +109,7 @@ logger = logging.getLogger(__name__)
class ProcessRulesDict(TypedDict):
mode: str
mode: ProcessRuleMode
rules: dict[str, Any]
@ -117,6 +118,86 @@ class AutoDisableLogsDict(TypedDict):
count: int
class _EstimatePreProcessingRule(BaseModel):
id: str = Field(min_length=1)
enabled: bool
@field_validator("id")
@classmethod
def _validate_id(cls, v: str) -> str:
if v not in DatasetProcessRule.PRE_PROCESSING_RULES:
raise ValueError("Process rule pre_processing_rules id is invalid")
return v
class _EstimateSegmentation(BaseModel):
separator: str = Field(min_length=1)
max_tokens: int = Field(gt=0)
class _EstimateRules(BaseModel):
pre_processing_rules: list[_EstimatePreProcessingRule]
segmentation: _EstimateSegmentation
@field_validator("pre_processing_rules")
@classmethod
def _deduplicate(cls, v: list[_EstimatePreProcessingRule]) -> list[_EstimatePreProcessingRule]:
seen: dict[str, _EstimatePreProcessingRule] = {}
for rule in v:
seen[rule.id] = rule
return list(seen.values())
class _SummaryIndexSettingDisabled(BaseModel):
enable: Literal[False] = False
class _SummaryIndexSettingEnabled(BaseModel):
enable: Literal[True]
model_name: str = Field(min_length=1)
model_provider_name: str = Field(min_length=1)
_SummaryIndexSetting = Annotated[
_SummaryIndexSettingDisabled | _SummaryIndexSettingEnabled,
Field(discriminator="enable"),
]
class _AutomaticProcessRule(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal[ProcessRuleMode.AUTOMATIC]
summary_index_setting: _SummaryIndexSetting | None = None
class _CustomProcessRule(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal[ProcessRuleMode.CUSTOM]
rules: _EstimateRules
summary_index_setting: _SummaryIndexSetting | None = None
class _HierarchicalProcessRule(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal[ProcessRuleMode.HIERARCHICAL]
rules: _EstimateRules
summary_index_setting: _SummaryIndexSetting | None = None
_EstimateProcessRule = Annotated[
_AutomaticProcessRule | _CustomProcessRule | _HierarchicalProcessRule,
Field(discriminator="mode"),
]
class _EstimateArgs(BaseModel):
info_list: dict[str, Any]
process_rule: _EstimateProcessRule
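# Illustrative payloads (values and rule ids are hypothetical, assuming
# ProcessRuleMode values are the lowercase strings used elsewhere in this
# module): `mode` is the discriminator, so "automatic" needs no rules while
# "custom" and "hierarchical" require them, e.g.
#
#   _EstimateArgs.model_validate({
#       "info_list": {"file_ids": ["f1"]},
#       "process_rule": {"mode": "automatic"},
#   })
#   _EstimateArgs.model_validate({
#       "info_list": {"file_ids": ["f1"]},
#       "process_rule": {
#           "mode": "custom",
#           "rules": {
#               "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
#               "segmentation": {"separator": "\n", "max_tokens": 500},
#           },
#       },
#   })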
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
@ -204,7 +285,7 @@ class DatasetService:
mode = dataset_process_rule.mode
rules = dataset_process_rule.rules_dict or {}
else:
mode = str(DocumentService.DEFAULT_RULES["mode"])
mode = ProcessRuleMode(DocumentService.DEFAULT_RULES["mode"])
rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {})
return {"mode": mode, "rules": rules}
@ -1984,7 +2065,7 @@ class DocumentService:
if process_rule.rules:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
mode=ProcessRuleMode(process_rule.mode),
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
@ -1995,7 +2076,7 @@ class DocumentService:
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
mode=ProcessRuleMode.AUTOMATIC,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
@ -2572,14 +2653,14 @@ class DocumentService:
if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
mode=ProcessRuleMode(process_rule.mode),
rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
created_by=account.id,
)
elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
mode=process_rule.mode,
mode=ProcessRuleMode.AUTOMATIC,
rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
created_by=account.id,
)
@ -2851,94 +2932,16 @@ class DocumentService:
@classmethod
def estimate_args_validate(cls, args: dict[str, Any]):
if "info_list" not in args or not args["info_list"]:
raise ValueError("Data source info is required")
if not isinstance(args["info_list"], dict):
raise ValueError("Data info is invalid")
if "process_rule" not in args or not args["process_rule"]:
raise ValueError("Process rule is required")
if not isinstance(args["process_rule"], dict):
raise ValueError("Process rule is invalid")
if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]:
raise ValueError("Process rule mode is required")
if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
raise ValueError("Process rule mode is invalid")
if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
args["process_rule"]["rules"] = {}
else:
if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
raise ValueError("Process rule rules is required")
if not isinstance(args["process_rule"]["rules"], dict):
raise ValueError("Process rule rules is invalid")
if (
"pre_processing_rules" not in args["process_rule"]["rules"]
or args["process_rule"]["rules"]["pre_processing_rules"] is None
):
raise ValueError("Process rule pre_processing_rules is required")
if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list):
raise ValueError("Process rule pre_processing_rules is invalid")
unique_pre_processing_rule_dicts = {}
for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]:
if "id" not in pre_processing_rule or not pre_processing_rule["id"]:
raise ValueError("Process rule pre_processing_rules id is required")
if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES:
raise ValueError("Process rule pre_processing_rules id is invalid")
if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None:
raise ValueError("Process rule pre_processing_rules enabled is required")
if not isinstance(pre_processing_rule["enabled"], bool):
raise ValueError("Process rule pre_processing_rules enabled is invalid")
unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule
args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values())
if (
"segmentation" not in args["process_rule"]["rules"]
or args["process_rule"]["rules"]["segmentation"] is None
):
raise ValueError("Process rule segmentation is required")
if not isinstance(args["process_rule"]["rules"]["segmentation"], dict):
raise ValueError("Process rule segmentation is invalid")
if (
"separator" not in args["process_rule"]["rules"]["segmentation"]
or not args["process_rule"]["rules"]["segmentation"]["separator"]
):
raise ValueError("Process rule segmentation separator is required")
if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str):
raise ValueError("Process rule segmentation separator is invalid")
if (
"max_tokens" not in args["process_rule"]["rules"]["segmentation"]
or not args["process_rule"]["rules"]["segmentation"]["max_tokens"]
):
raise ValueError("Process rule segmentation max_tokens is required")
if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
raise ValueError("Process rule segmentation max_tokens is invalid")
# valid summary index setting
summary_index_setting = args["process_rule"].get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable"):
if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
raise ValueError("Summary index model name is required")
if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
raise ValueError("Summary index model provider name is required")
try:
validated = _EstimateArgs.model_validate(args)
except ValidationError as e:
first = e.errors()[0]
original = first.get("ctx", {}).get("error")
raise ValueError(str(original) if isinstance(original, ValueError) else first["msg"]) from e
process_rule_dict = validated.process_rule.model_dump(exclude_none=True)
if validated.process_rule.mode == ProcessRuleMode.AUTOMATIC:
process_rule_dict["rules"] = {}
args["process_rule"] = process_rule_dict
@staticmethod
def batch_update_document_status(

View File

@ -166,7 +166,7 @@ class SystemFeatureModel(BaseModel):
enable_email_code_login: bool = False
enable_email_password_login: bool = True
enable_social_oauth_login: bool = False
enable_collaboration_mode: bool = False
enable_collaboration_mode: bool = True
is_allow_register: bool = False
is_allow_create_workspace: bool = False
is_email_setup: bool = False

View File

@ -111,6 +111,7 @@ class VectorService:
"dataset_id": segment.dataset_id,
},
)
assert segment.index_node_id
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
# update vector index
vector = Vector(dataset=dataset)
@ -138,6 +139,7 @@ class VectorService:
regenerate: bool = False,
):
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
assert segment.index_node_id
if regenerate:
# delete child chunks
index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)

View File

@ -1066,8 +1066,13 @@ class WorkflowService:
)
rendered_content = node.render_form_content_before_submission()
selected_action = next(
(user_action for user_action in node_data.user_actions if user_action.id == action),
None,
)
outputs: dict[str, Any] = dict(form_inputs)
outputs["__action_id"] = action
outputs["__action_value"] = selected_action.title if selected_action else ""
outputs["__rendered_content"] = node.render_form_content_with_outputs(
rendered_content, outputs, node_data.outputs_field_names()
)

View File

@ -50,7 +50,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
segment_ids = [segment.id for segment in segments]
# Collect image file IDs from segment content

View File

@ -19,6 +19,7 @@ from graphon.model_runtime.entities.model_entities import ModelType
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
from models.model import UploadFile
from services.vector_service import VectorService
@ -156,7 +157,7 @@ def batch_create_segment_to_index_task(
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
status=SegmentStatus.COMPLETED,
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == IndexStructureType.QA_INDEX:

View File

@ -53,7 +53,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
segment_contents = [segment.content for segment in segments]
except Exception:
logger.exception("Cleaned document when document deleted failed")

View File

@ -38,7 +38,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
for document_id in document_ids:
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
total_index_node_ids.extend([segment.index_node_id for segment in segments])
total_index_node_ids.extend([segment.index_node_id for segment in segments if segment.index_node_id])
# Wrap vector / keyword index cleanup in try/except so that a transient
# failure here (e.g. billing API hiccup propagated via FeatureService when

View File

@ -9,6 +9,7 @@ from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
from models.enums import SegmentStatus
logger = logging.getLogger(__name__)
@ -30,7 +31,7 @@ def disable_segment_from_index_task(segment_id: str):
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
if segment.status != "completed":
if segment.status != SegmentStatus.COMPLETED:
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
return
@ -59,6 +60,7 @@ def disable_segment_from_index_task(segment_id: str):
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
assert segment.index_node_id
index_processor.clean(dataset, [segment.index_node_id])
# Disable summary index for this segment

View File

@ -55,7 +55,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = session.scalars(

View File

@ -69,7 +69,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()

View File

@ -45,7 +45,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
clean_success = False
try:

View File

@ -137,7 +137,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

View File

@ -1,11 +1,31 @@
"""
Celery task for asynchronous ops trace dispatch.
Trace providers may report explicitly retryable dispatch failures through the
core retryable exception contract. The task preserves the payload file only
when Celery accepts the retry request; successful dispatches and terminal
failures clean up the stored payload.
One concrete producer today is Phoenix nested workflow tracing. The outer
workflow tool span publishes a restorable parent span context asynchronously,
while the nested workflow trace may be picked up by Celery first. In that
ordering window, the provider raises a retryable core exception instead of
dropping the trace or emitting it under the wrong parent. The task intentionally
does not know that the provider is Phoenix; it only honors the core retryable
dispatch contract.
"""
import json
import logging
from celery import shared_task
from celery.exceptions import Retry
from flask import current_app
from configs import dify_config
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
from core.ops.entities.trace_entity import trace_info_info_map
from core.ops.exceptions import RetryableTraceDispatchError
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -14,9 +34,17 @@ from models.workflow import WorkflowRun
logger = logging.getLogger(__name__)
_RETRYABLE_TRACE_DISPATCH_LIMIT = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES
_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS
@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):
@shared_task(
queue="ops_trace",
bind=True,
max_retries=_RETRYABLE_TRACE_DISPATCH_LIMIT,
default_retry_delay=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS,
)
def process_trace_tasks(self, file_info):
"""
Async process trace tasks
Usage: process_trace_tasks.delay(file_info)
@ -29,6 +57,7 @@ def process_trace_tasks(file_info):
file_data = json.loads(storage.load(file_path))
trace_info = file_data.get("trace_info")
trace_info_type = file_data.get("trace_info_type")
enterprise_trace_dispatched = bool(file_data.get("_enterprise_trace_dispatched"))
trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
if trace_info.get("message_data"):
@ -38,6 +67,8 @@ def process_trace_tasks(file_info):
if trace_info.get("documents"):
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
should_delete_file = True
try:
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
@ -45,30 +76,66 @@ def process_trace_tasks(file_info):
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
if is_ee_telemetry_enabled():
if is_ee_telemetry_enabled() and not enterprise_trace_dispatched:
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
try:
EnterpriseOtelTrace().trace(trace_info)
except Exception:
logger.exception("Enterprise trace failed for app_id: %s", app_id)
else:
file_data["_enterprise_trace_dispatched"] = True
enterprise_trace_dispatched = True
if trace_instance:
with current_app.app_context():
trace_instance.trace(trace_info)
logger.info("Processing trace tasks success, app_id: %s", app_id)
except RetryableTraceDispatchError as e:
# Retryable dispatch failures represent a transient provider-side
# ordering gap, not corrupt payload data. Keep the payload only after
# Celery accepts the retry request; otherwise this attempt becomes a
# terminal failure and the stored file is cleaned up in `finally`.
#
# Enterprise telemetry runs before provider dispatch. If it already ran
# and provider dispatch asks for a retry, persist that private flag so
# the next attempt does not emit the same enterprise trace twice.
if self.request.retries >= _RETRYABLE_TRACE_DISPATCH_LIMIT:
logger.exception("Retryable trace dispatch budget exhausted, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
else:
logger.warning(
"Retryable trace dispatch failure, scheduling retry %s/%s for app_id %s: %s",
self.request.retries + 1,
_RETRYABLE_TRACE_DISPATCH_LIMIT,
app_id,
e,
)
try:
if enterprise_trace_dispatched:
storage.save(file_path, json.dumps(file_data).encode("utf-8"))
raise self.retry(exc=e, countdown=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS)
except Retry:
should_delete_file = False
raise
except Exception:
logger.exception("Failed to schedule trace dispatch retry, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
except Exception as e:
logger.exception("Processing trace tasks failed, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
finally:
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)
if should_delete_file:
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)

View File

@ -61,7 +61,7 @@ def remove_document_from_index_task(document_id: str):
except Exception as e:
logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e))
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)

View File

@ -85,7 +85,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

View File

@ -70,7 +70,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

View File

@ -91,7 +91,11 @@ def init_llm_node(config: dict) -> LLMNode:
return node
def test_execute_llm():
def _mock_db_session_close(monkeypatch) -> None:
monkeypatch.setattr(db.session, "close", MagicMock())
def test_execute_llm(monkeypatch):
node = init_llm_node(
config={
"id": "llm",
@ -118,7 +122,7 @@ def test_execute_llm():
},
)
db.session.close = MagicMock()
_mock_db_session_close(monkeypatch)
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
@ -195,7 +199,7 @@ def test_execute_llm():
assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0
def test_execute_llm_with_jinja2():
def test_execute_llm_with_jinja2(monkeypatch):
"""
Test execute LLM node with jinja2
"""
@ -233,8 +237,7 @@ def test_execute_llm_with_jinja2():
},
)
# Mock db.session.close()
db.session.close = MagicMock()
_mock_db_session_close(monkeypatch)
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal

View File

@ -83,7 +83,11 @@ def init_parameter_extractor_node(config: dict, memory=None):
return node
def test_function_calling_parameter_extractor(setup_model_mock):
def _mock_db_session_close(monkeypatch) -> None:
monkeypatch.setattr(db.session, "close", MagicMock())
def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch):
"""
Test function calling for parameter extractor.
"""
@ -114,7 +118,7 @@ def test_function_calling_parameter_extractor(setup_model_mock):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
db.session.close = MagicMock()
_mock_db_session_close(monkeypatch)
result = node._run()
@ -124,7 +128,7 @@ def test_function_calling_parameter_extractor(setup_model_mock):
assert result.outputs.get("__reason") == None
def test_instructions(setup_model_mock):
def test_instructions(setup_model_mock, monkeypatch):
"""
Test chat parameter extractor.
"""
@ -155,7 +159,7 @@ def test_instructions(setup_model_mock):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
db.session.close = MagicMock()
_mock_db_session_close(monkeypatch)
result = node._run()
@ -174,7 +178,7 @@ def test_instructions(setup_model_mock):
assert "what's the weather in SF" in prompt.get("text")
def test_chat_parameter_extractor(setup_model_mock):
def test_chat_parameter_extractor(setup_model_mock, monkeypatch):
"""
Test chat parameter extractor.
"""
@ -205,7 +209,7 @@ def test_chat_parameter_extractor(setup_model_mock):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
db.session.close = MagicMock()
_mock_db_session_close(monkeypatch)
result = node._run()
@ -225,7 +229,7 @@ def test_chat_parameter_extractor(setup_model_mock):
assert '<structure>\n{"type": "object"' in prompt.get("text")
def test_completion_parameter_extractor(setup_model_mock):
def test_completion_parameter_extractor(setup_model_mock, monkeypatch):
"""
Test completion parameter extractor.
"""
@ -256,7 +260,7 @@ def test_completion_parameter_extractor(setup_model_mock):
mode="completion",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
db.session.close = MagicMock()
_mock_db_session_close(monkeypatch)
result = node._run()
@ -350,7 +354,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock):
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
"""
Test chat parameter extractor with memory.
"""
@ -382,7 +386,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
db.session.close = MagicMock()
_mock_db_session_close(monkeypatch)
result = node._run()

View File

@ -168,6 +168,7 @@ def test_node_variable_collection_get_success(
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123")
node_variable_id = node_variable.id
_create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other")
response = test_client_with_containers.get(
@ -178,7 +179,7 @@ def test_node_variable_collection_get_success(
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert [item["id"] for item in payload["items"]] == [node_variable.id]
assert [item["id"] for item in payload["items"]] == [node_variable_id]
def test_node_variable_collection_get_invalid_node_id(
@ -377,6 +378,7 @@ def test_system_variable_collection_get(
account, tenant = create_console_account_and_tenant(db_session_with_containers)
app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
variable = _create_system_variable(db_session_with_containers, app.id, account.id)
variable_id = variable.id
response = test_client_with_containers.get(
f"/console/api/apps/{app.id}/workflows/draft/system-variables",
@ -386,7 +388,7 @@ def test_system_variable_collection_get(
assert response.status_code == 200
payload = response.get_json()
assert payload is not None
assert [item["id"] for item in payload["items"]] == [variable.id]
assert [item["id"] for item in payload["items"]] == [variable_id]
def test_environment_variable_collection_get(

View File

@ -17,6 +17,8 @@ def test_get_oauth_url_successful(
test_client_with_containers: FlaskClient,
) -> None:
account, tenant = create_console_account_and_tenant(db_session_with_containers)
tenant_id = tenant.id
current_tenant_id = account.current_tenant_id
provider = MagicMock()
provider.get_authorization_url.return_value = "http://oauth.provider/auth"
@ -29,7 +31,7 @@ def test_get_oauth_url_successful(
headers=authenticate_console_client(test_client_with_containers, account),
)
assert tenant.id == account.current_tenant_id
assert tenant_id == current_tenant_id
assert response.status_code == 200
assert response.get_json() == {"data": "http://oauth.provider/auth"}
provider.get_authorization_url.assert_called_once()

View File

@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from sqlalchemy.orm import Session
from controllers.console.auth.error import (
EmailCodeError,
@ -20,13 +21,15 @@ from controllers.console.auth.forgot_password import (
ForgotPasswordSendEmailApi,
)
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from tests.test_containers_integration_tests.controllers.console.helpers import ensure_dify_setup
class TestForgotPasswordSendEmailApi:
"""Test cases for sending password reset emails."""
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session):
ensure_dify_setup(db_session_with_containers)
return flask_app_with_containers
@pytest.fixture
@ -139,7 +142,8 @@ class TestForgotPasswordCheckApi:
"""Test cases for verifying password reset codes."""
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session):
ensure_dify_setup(db_session_with_containers)
return flask_app_with_containers
@patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")
@ -322,7 +326,8 @@ class TestForgotPasswordResetApi:
"""Test cases for resetting password with verified token."""
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session):
ensure_dify_setup(db_session_with_containers)
return flask_app_with_containers
@pytest.fixture

View File

@ -11,7 +11,7 @@ from models.enums import ConversationFromSource, MessageFileBelongsTo
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -119,16 +119,16 @@ class TestAgentService:
tenant = account.current_tenant
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "agent-chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="agent-chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)

View File

@ -9,7 +9,7 @@ from models import Account
from models.enums import ConversationFromSource, InvokeFrom
from models.model import MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -86,16 +86,16 @@ class TestAnnotationService:
tenant = account.current_tenant
# Setup app creation arguments
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
# Create app
app_service = AppService()

View File

@ -37,7 +37,7 @@ from services.app_dsl_service import (
PendingData,
_check_version_compatibility,
)
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from tests.test_containers_integration_tests.helpers import generate_valid_password
_DEFAULT_TENANT_ID = "00000000-0000-0000-0000-000000000001"
@ -147,16 +147,16 @@ class TestAppDslService:
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
return app, account

View File

@ -1,4 +1,5 @@
import uuid
from typing import Literal
from unittest.mock import ANY, MagicMock, patch
import pytest
@ -133,7 +134,10 @@ class TestAppGenerateService:
}
def _create_test_app_and_account(
self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat"
self,
db_session_with_containers: Session,
mock_external_service_dependencies,
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = "chat",
):
"""
Helper method to create a test app and account for testing.
@ -165,20 +169,20 @@ class TestAppGenerateService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": mode,
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
"max_active_requests": 5,
}
from services.app_service import AppService, CreateAppParams
from services.app_service import AppService
# Create app with realistic data
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode=mode,
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
max_active_requests=5,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)

View File

@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from pydantic import ValidationError
from sqlalchemy.orm import Session
from constants.model_template import default_app_templates
@ -12,7 +13,7 @@ from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
# from services.app_service import AppService, AppListParams, CreateAppParams
class TestAppService:
@ -64,34 +65,34 @@ class TestAppService:
tenant = account.current_tenant
# Setup app creation arguments
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService, CreateAppParams
app_params = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
# Create app
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
app = app_service.create_app(tenant.id, app_params, account)
# Verify app was created correctly
assert app.name == app_args["name"]
assert app.description == app_args["description"]
assert app.mode == app_args["mode"]
assert app.icon_type == app_args["icon_type"]
assert app.icon == app_args["icon"]
assert app.icon_background == app_args["icon_background"]
assert app.name == app_params.name
assert app.description == app_params.description
assert app.mode == app_params.mode
assert app.icon_type == app_params.icon_type
assert app.icon == app_params.icon
assert app.icon_background == app_params.icon_background
assert app.tenant_id == tenant.id
assert app.api_rph == app_args["api_rph"]
assert app.api_rpm == app_args["api_rpm"]
assert app.api_rph == app_params.api_rph
assert app.api_rpm == app_params.api_rpm
assert app.created_by == account.id
assert app.updated_by == account.id
assert app.status == "normal"
@ -120,7 +121,7 @@ class TestAppService:
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_service = AppService()
@ -129,20 +130,20 @@ class TestAppService:
app_modes = [v.value for v in default_app_templates]
for mode in app_modes:
app_args = {
"name": f"{fake.company()} {mode}",
"description": f"Test app for {mode} mode",
"mode": mode,
"icon_type": "emoji",
"icon": "🚀",
"icon_background": "#4ECDC4",
}
app_params = CreateAppParams(
name=f"{fake.company()} {mode}",
description=f"Test app for {mode} mode",
mode=mode,
icon_type="emoji",
icon="🚀",
icon_background="#4ECDC4",
)
app = app_service.create_app(tenant.id, app_args, account)
app = app_service.create_app(tenant.id, app_params, account)
# Verify app mode was set correctly
assert app.mode == mode
assert app.name == app_args["name"]
assert app.name == app_params.name
assert app.tenant_id == tenant.id
assert app.created_by == account.id
@ -163,20 +164,20 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_params = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🎯",
icon_background="#45B7D1",
)
app_service = AppService()
created_app = app_service.create_app(tenant.id, app_args, account)
created_app = app_service.create_app(tenant.id, app_params, account)
# Get app using the service - needs current_user mock
mock_current_user = create_autospec(Account, instance=True)
@ -211,31 +212,27 @@ class TestAppService:
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppListParams, AppService, CreateAppParams
app_service = AppService()
# Create multiple apps
app_names = [fake.company() for _ in range(5)]
for name in app_names:
app_args = {
"name": name,
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "📱",
"icon_background": "#96CEB4",
}
app_service.create_app(tenant.id, app_args, account)
app_params = CreateAppParams(
name=name,
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="📱",
icon_background="#96CEB4",
)
app_service.create_app(tenant.id, app_params, account)
# Get paginated apps
args = {
"page": 1,
"limit": 10,
"mode": "chat",
}
params = AppListParams(page=1, limit=10, mode="chat")
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
# Verify pagination results
assert paginated_apps is not None
@ -267,60 +264,47 @@ class TestAppService:
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppListParams, AppService, CreateAppParams
app_service = AppService()
# Create apps with different modes
chat_app_args = {
"name": "Chat App",
"description": "A chat application",
"mode": "chat",
"icon_type": "emoji",
"icon": "💬",
"icon_background": "#FF6B6B",
}
completion_app_args = {
"name": "Completion App",
"description": "A completion application",
"mode": "completion",
"icon_type": "emoji",
"icon": "✍️",
"icon_background": "#4ECDC4",
}
chat_app_params = CreateAppParams(
name="Chat App",
description="A chat application",
mode="chat",
icon_type="emoji",
icon="💬",
icon_background="#FF6B6B",
)
completion_app_params = CreateAppParams(
name="Completion App",
description="A completion application",
mode="completion",
icon_type="emoji",
icon="✍️",
icon_background="#4ECDC4",
)
chat_app = app_service.create_app(tenant.id, chat_app_args, account)
completion_app = app_service.create_app(tenant.id, completion_app_args, account)
chat_app = app_service.create_app(tenant.id, chat_app_params, account)
completion_app = app_service.create_app(tenant.id, completion_app_params, account)
# Test filter by mode
chat_args = {
"page": 1,
"limit": 10,
"mode": "chat",
}
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args)
chat_apps = app_service.get_paginate_apps(account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"))
assert len(chat_apps.items) == 1
assert chat_apps.items[0].mode == "chat"
# Test filter by name
name_args = {
"page": 1,
"limit": 10,
"mode": "chat",
"name": "Chat",
}
filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args)
filtered_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat")
)
assert len(filtered_apps.items) == 1
assert "Chat" in filtered_apps.items[0].name
# Test filter by created_by_me
created_by_me_args = {
"page": 1,
"limit": 10,
"mode": "completion",
"is_created_by_me": True,
}
my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
my_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True)
)
assert len(my_apps.items) == 1
def test_get_paginate_apps_with_tag_filters(
@ -342,34 +326,29 @@ class TestAppService:
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppListParams, AppService, CreateAppParams
app_service = AppService()
# Create an app
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🏷️",
"icon_background": "#FFEAA7",
}
app = app_service.create_app(tenant.id, app_args, account)
app_params = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🏷️",
icon_background="#FFEAA7",
)
app = app_service.create_app(tenant.id, app_params, account)
# Mock TagService to return the app ID for tag filtering
with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
mock_tag_service.return_value = [app.id]
# Test with tag filter
args = {
"page": 1,
"limit": 10,
"mode": "chat",
"tag_ids": ["tag1", "tag2"],
}
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["tag1", "tag2"])
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
# Verify tag service was called
mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"])
@ -383,14 +362,9 @@ class TestAppService:
with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
mock_tag_service.return_value = []
args = {
"page": 1,
"limit": 10,
"mode": "chat",
"tag_ids": ["nonexistent_tag"],
}
params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["nonexistent_tag"])
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)
# Should return None when no apps match tag filter
assert paginated_apps is None
@ -412,20 +386,20 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_params = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🎯",
icon_background="#45B7D1",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
app = app_service.create_app(tenant.id, app_params, account)
# Store original values
original_name = app.name
@ -481,19 +455,19 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
},
CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🎯",
icon_background="#45B7D1",
),
account,
)
@ -533,19 +507,19 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
},
CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🎯",
icon_background="#45B7D1",
),
account,
)
@ -584,20 +558,20 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_params = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🎯",
icon_background="#45B7D1",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
app = app_service.create_app(tenant.id, app_params, account)
# Store original name
original_name = app.name
@ -637,20 +611,20 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🎯",
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_params = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🎯",
icon_background="#45B7D1",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
app = app_service.create_app(tenant.id, app_params, account)
# Store original values
original_icon = app.icon
@ -698,18 +672,17 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🌐",
"icon_background": "#74B9FF",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🌐",
icon_background="#74B9FF",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -758,18 +731,17 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🔌",
"icon_background": "#A29BFE",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🔌",
icon_background="#A29BFE",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -818,18 +790,17 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🔄",
"icon_background": "#FD79A8",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🔄",
icon_background="#FD79A8",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -869,18 +840,17 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🗑️",
"icon_background": "#E17055",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🗑️",
icon_background="#E17055",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -921,18 +891,17 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🧹",
"icon_background": "#00B894",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🧹",
icon_background="#00B894",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -981,18 +950,17 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "📊",
"icon_background": "#6C5CE7",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="📊",
icon_background="#6C5CE7",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -1020,18 +988,17 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🔗",
"icon_background": "#FDCB6E",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🔗",
icon_background="#FDCB6E",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -1060,18 +1027,17 @@ class TestAppService:
tenant = account.current_tenant
# Create app first
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🆔",
"icon_background": "#E84393",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🆔",
icon_background="#E84393",
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -1107,26 +1073,20 @@ class TestAppService:
password=generate_valid_password(fake),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Setup app creation arguments with invalid mode
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "invalid_mode", # Invalid mode
"icon_type": "emoji",
"icon": "",
"icon_background": "#D63031",
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import CreateAppParams
app_service = AppService()
# Attempt to create app with invalid mode
with pytest.raises(ValueError, match="invalid mode value"):
app_service.create_app(tenant.id, app_args, account)
# Attempt to create app with invalid mode - Pydantic will reject invalid literal
with pytest.raises(ValidationError):
CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="invalid_mode", # type: ignore[arg-type]
icon_type="emoji",
icon="",
icon_background="#D63031",
)
def test_get_apps_with_special_characters_in_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -1152,99 +1112,103 @@ class TestAppService:
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppListParams, AppService, CreateAppParams
app_service = AppService()
# Create apps with special characters in names
app_with_percent = app_service.create_app(
tenant.id,
{
"name": "App with 50% discount",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
CreateAppParams(
name="App with 50% discount",
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
),
account,
)
app_with_underscore = app_service.create_app(
tenant.id,
{
"name": "test_data_app",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
CreateAppParams(
name="test_data_app",
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
),
account,
)
app_with_backslash = app_service.create_app(
tenant.id,
{
"name": "path\\to\\app",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
CreateAppParams(
name="path\\to\\app",
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
),
account,
)
# Create app that should NOT match
app_no_match = app_service.create_app(
tenant.id,
{
"name": "100% different",
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
},
CreateAppParams(
name="100% different",
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
),
account,
)
# Test 1: Search with % character
args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "App with 50% discount"
# Test 2: Search with _ character
args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="test_data", mode="chat", page=1, limit=10)
)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "test_data_app"
# Test 3: Search with \ character
args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10)
)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert len(paginated_apps.items) == 1
assert paginated_apps.items[0].name == "path\\to\\app"
# Test 4: Search with % should NOT match 100% (verifies escaping works)
args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
paginated_apps = app_service.get_paginate_apps(
account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
)
assert paginated_apps is not None
assert paginated_apps.total == 1
assert all("50%" in app.name for app in paginated_apps.items)
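The pagination tests likewise trade the loose `args` dict for an `AppListParams` model. Its definition is also outside this diff; a plausible sketch, inferred only from the parameters these tests pass (`page`, `limit`, `mode`, plus optional `name`, `tag_ids`, and `is_created_by_me`):

```python
# Hypothetical sketch inferred from the call sites in this diff; the real
# AppListParams in services/app_service.py may differ.
from typing import Optional

from pydantic import BaseModel


class AppListParams(BaseModel):
    page: int = 1
    limit: int = 20
    mode: str = "all"
    name: Optional[str] = None
    tag_ids: Optional[list[str]] = None
    is_created_by_me: bool = False
```

The practical win over a dict is that the fields are typed and validated up front, and a static type checker can flag a misspelled or mistyped keyword at the call site instead of it being silently ignored at runtime.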

View File

@ -13,9 +13,9 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
from models.enums import DataSourceType, DocumentCreatedFrom, SegmentStatus
from services.dataset_service import SegmentService
@ -35,13 +35,13 @@ class SegmentServiceTestDataFactory:
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
if tenant is None:
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
tenant = Tenant(name=f"tenant-{uuid4()}", status=TenantStatus.NORMAL)
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
@ -103,7 +103,7 @@ class SegmentServiceTestDataFactory:
created_by: str,
position: int = 1,
content: str = "Test content",
status: str = "completed",
status: SegmentStatus = SegmentStatus.COMPLETED,
word_count: int = 10,
tokens: int = 15,
) -> DocumentSegment:
@ -203,7 +203,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=1,
status="completed",
status=SegmentStatus.COMPLETED,
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
@ -212,7 +212,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=2,
status="indexing",
status=SegmentStatus.INDEXING,
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
@ -221,7 +221,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=3,
status="waiting",
status=SegmentStatus.WAITING,
)
# Act
@ -257,7 +257,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=1,
status="completed",
status=SegmentStatus.COMPLETED,
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
@ -266,7 +266,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=2,
status="indexing",
status=SegmentStatus.INDEXING,
)
# Act
@ -415,7 +415,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=1,
status="completed",
status=SegmentStatus.COMPLETED,
content="This is important information",
)
SegmentServiceTestDataFactory.create_segment(
@ -425,7 +425,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=2,
status="indexing",
status=SegmentStatus.INDEXING,
content="This is also important",
)
SegmentServiceTestDataFactory.create_segment(
@ -435,7 +435,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=3,
status="completed",
status=SegmentStatus.COMPLETED,
content="This is irrelevant",
)
@ -477,7 +477,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=1,
status="completed",
status=SegmentStatus.COMPLETED,
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
@ -486,7 +486,7 @@ class TestSegmentServiceGetSegments:
document_id=document.id,
created_by=owner.id,
position=2,
status="waiting",
status=SegmentStatus.WAITING,
)
# Act
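This file swaps bare status strings for enum members (`AccountStatus.ACTIVE`, `TenantStatus.NORMAL`, `SegmentStatus.COMPLETED`, and so on), and the dataset-retrieval factory below does the same with `CreatorUserRole.ACCOUNT`, `ProcessRuleMode`, and `TagType.KNOWLEDGE`. For equality checks against string-typed columns to keep passing, these are presumably str-backed enums, roughly:

```python
# Hypothetical sketch of the enums this diff assumes; the real definitions
# live under models/ and models/enums.py. A str-backed enum member compares
# equal to its raw value, so rows stored as "completed" still match.
from enum import StrEnum  # Python 3.11+; a (str, Enum) mixin works on older versions


class SegmentStatus(StrEnum):
    WAITING = "waiting"
    INDEXING = "indexing"
    COMPLETED = "completed"
    ERROR = "error"


assert SegmentStatus.COMPLETED == "completed"
```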

View File

@ -16,6 +16,7 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import AccountStatus, CreatorUserRole, TenantStatus
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -25,7 +26,7 @@ from models.dataset import (
DatasetProcessRule,
DatasetQuery,
)
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode, TagType
from models.model import Tag, TagBinding
from services.dataset_service import DatasetService, DocumentService
@ -42,11 +43,11 @@ class DatasetRetrievalTestDataFactory:
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
tenant = Tenant(
name=f"tenant-{uuid4()}",
status="normal",
status=TenantStatus.NORMAL,
)
db_session_with_containers.add_all([account, tenant])
db_session_with_containers.flush()
@ -72,7 +73,7 @@ class DatasetRetrievalTestDataFactory:
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
status=AccountStatus.ACTIVE,
)
db_session_with_containers.add(account)
db_session_with_containers.flush()
@ -130,7 +131,7 @@ class DatasetRetrievalTestDataFactory:
@staticmethod
def create_process_rule(
db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict
db_session_with_containers: Session, dataset_id: str, created_by: str, mode: ProcessRuleMode, rules: dict
) -> DatasetProcessRule:
"""Create a dataset process rule."""
process_rule = DatasetProcessRule(
@ -153,7 +154,7 @@ class DatasetRetrievalTestDataFactory:
content=content,
source=DatasetQuerySource.APP,
source_app_id=None,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=created_by,
)
db_session_with_containers.add(dataset_query)
@ -176,7 +177,7 @@ class DatasetRetrievalTestDataFactory:
"""Create a knowledge tag and bind it to the target dataset."""
tag = Tag(
tenant_id=tenant_id,
type="knowledge",
type=TagType.KNOWLEDGE,
name=f"tag-{uuid4()}",
created_by=created_by,
)

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom
from models.model import MessageFeedback
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.errors.message import (
FirstMessageNotExistsError,
LastMessageNotExistsError,
@ -103,16 +103,16 @@ class TestMessageService:
tenant = account.current_tenant
# Setup app creation arguments
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "advanced-chat", # Use advanced-chat mode to use mocked workflow
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="advanced-chat", # Use advanced-chat mode to use mocked workflow,
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
# Create app
app_service = AppService()

View File

@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
from core.ops.entities.config_entity import TracingProviderEnum
from models.model import TraceAppConfig
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.ops_service import OpsService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -57,14 +57,14 @@ class TestOpsService:
app_service = AppService()
app = app_service.create_app(
tenant.id,
{
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
},
CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
),
account,
)
return app, account

View File

@ -8,7 +8,7 @@ from models import App, CreatorUserRole
from models.enums import ConversationFromSource
from models.model import EndUser, Message
from models.web import SavedMessage
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.saved_message_service import SavedMessageService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -73,16 +73,16 @@ class TestSavedMessageService:
tenant = account.current_tenant
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)

View File

@ -11,7 +11,7 @@ from models.enums import ConversationFromSource
from models.model import Conversation, EndUser
from models.web import PinnedConversation
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.web_conversation_service import WebConversationService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -77,16 +77,16 @@ class TestWebConversationService:
tenant = account.current_tenant
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)

View File

@ -17,7 +17,7 @@ from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
# from services.app_service import AppService, CreateAppParams
from services.workflow_app_service import LogView, WorkflowAppService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -82,20 +82,20 @@ class TestWorkflowAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "workflow",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
# Create app with realistic data
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="workflow",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -146,20 +146,20 @@ class TestWorkflowAppService:
"""
fake = Faker()
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "workflow",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
# Create app with realistic data
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="workflow",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)

View File

@ -13,7 +13,7 @@ from models.model import (
)
from models.workflow import WorkflowRun
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.workflow_run_service import WorkflowRunService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -79,16 +79,16 @@ class TestWorkflowRunService:
tenant = account.current_tenant
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "chat",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="chat",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -535,13 +535,13 @@ class TestWorkflowRunService:
tenant = account.current_tenant
# Create app
app_args = {
"name": "Test App",
"mode": "chat",
"icon_type": "emoji",
"icon": "🚀",
"icon_background": "#4ECDC4",
}
app_args = CreateAppParams(
name="Test App",
mode="chat",
icon_type="emoji",
icon="🚀",
icon_background="#4ECDC4",
)
app = app_service.create_app(tenant.id, app_args, account)
# Create workflow run without node executions
@ -586,13 +586,13 @@ class TestWorkflowRunService:
tenant = account.current_tenant
# Create app
app_args = {
"name": "Test App",
"mode": "chat",
"icon_type": "emoji",
"icon": "🚀",
"icon_background": "#4ECDC4",
}
app_args = CreateAppParams(
name="Test App",
mode="chat",
icon_type="emoji",
icon="🚀",
icon_background="#4ECDC4",
)
app = app_service.create_app(tenant.id, app_args, account)
# Use invalid workflow run ID
@ -637,13 +637,13 @@ class TestWorkflowRunService:
tenant = account.current_tenant
# Create app
app_args = {
"name": "Test App",
"mode": "chat",
"icon_type": "emoji",
"icon": "🚀",
"icon_background": "#4ECDC4",
}
app_args = CreateAppParams(
name="Test App",
mode="chat",
icon_type="emoji",
icon="🚀",
icon_background="#4ECDC4",
)
app = app_service.create_app(tenant.id, app_args, account)
# Create workflow run

View File

@ -11,7 +11,7 @@ from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
from tests.test_containers_integration_tests.helpers import generate_valid_password
@ -94,16 +94,16 @@ class TestWorkflowToolManageService:
tenant = account.current_tenant
# Create app with realistic data
app_args = {
"name": fake.company(),
"description": fake.text(max_nb_chars=100),
"mode": "workflow",
"icon_type": "emoji",
"icon": "🤖",
"icon_background": "#FF6B6B",
"api_rph": 100,
"api_rpm": 10,
}
app_args = CreateAppParams(
name=fake.company(),
description=fake.text(max_nb_chars=100),
mode="workflow",
icon_type="emoji",
icon="🤖",
icon_background="#FF6B6B",
api_rph=100,
api_rpm=10,
)
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)

View File

@ -128,7 +128,6 @@ class TestAddDocumentToIndexTask:
for i in range(3):
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -451,7 +450,6 @@ class TestAddDocumentToIndexTask:
segments = []
for i in range(3):
segment = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -630,7 +628,6 @@ class TestAddDocumentToIndexTask:
# Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
segment1 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -650,7 +647,6 @@ class TestAddDocumentToIndexTask:
# Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED)
# Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED
segment2 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -669,7 +665,6 @@ class TestAddDocumentToIndexTask:
# Segment 3: Should NOT be processed (enabled=False, status="processing")
segment3 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
@ -688,7 +683,6 @@ class TestAddDocumentToIndexTask:
# Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
segment4 = DocumentSegment(
id=fake.uuid4(),
tenant_id=document.tenant_id,
dataset_id=dataset.id,
document_id=document.id,

View File

@ -177,7 +177,6 @@ class TestBatchCleanDocumentTask:
fake = Faker()
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=account.current_tenant.id,
dataset_id=document.dataset_id,
document_id=document.id,
@ -290,10 +289,9 @@ class TestBatchCleanDocumentTask:
account = self._create_test_account(db_session_with_containers)
dataset = self._create_test_dataset(db_session_with_containers, account)
document = self._create_test_document(db_session_with_containers, dataset, account)
assert account.current_tenant
# Create segment with simple content (no image references)
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=account.current_tenant.id,
dataset_id=document.dataset_id,
document_id=document.id,
@ -692,9 +690,9 @@ class TestBatchCleanDocumentTask:
# Create multiple segments for the document
segments = []
assert account.current_tenant
for i in range(3):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=account.current_tenant.id,
dataset_id=document.dataset_id,
document_id=document.id,

View File

@ -220,7 +220,6 @@ class TestCleanDatasetTask:
DocumentSegment: Created document segment instance
"""
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -232,8 +231,6 @@ class TestCleanDatasetTask:
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash",
created_at=datetime.now(),
updated_at=datetime.now(),
)
db_session_with_containers.add(segment)
@ -614,7 +611,6 @@ class TestCleanDatasetTask:
"""
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -626,8 +622,6 @@ class TestCleanDatasetTask:
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash",
created_at=datetime.now(),
updated_at=datetime.now(),
)
db_session_with_containers.add(segment)
@ -729,8 +723,6 @@ class TestCleanDatasetTask:
type=DatasetMetadataType.STRING,
created_by=account.id,
)
metadata.id = str(uuid.uuid4())
metadata.created_at = datetime.now()
metadata_items.append(metadata)
# Create binding for each metadata item
@ -741,8 +733,6 @@ class TestCleanDatasetTask:
document_id=documents[i % len(documents)].id,
created_by=account.id,
)
binding.id = str(uuid.uuid4())
binding.created_at = datetime.now()
bindings.append(binding)
db_session_with_containers.add_all(metadata_items)
@ -946,7 +936,6 @@ class TestCleanDatasetTask:
long_content = "Very long content " * 100 # Long content within reasonable limits
segment_content = f"Segment with special chars: {special_content}\n{long_content}"
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -958,8 +947,6 @@ class TestCleanDatasetTask:
status=SegmentStatus.COMPLETED,
index_node_id=str(uuid.uuid4()),
index_node_hash="test_hash_" + "x" * 50, # Long hash within limits
created_at=datetime.now(),
updated_at=datetime.now(),
)
db_session_with_containers.add(segment)
db_session_with_containers.commit()

View File

@ -132,11 +132,10 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(document)
db_session_with_containers.flush()
document_ids.append(document.id)
assert tenant
# Create segments for each document
for j in range(2):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -297,10 +296,9 @@ class TestCleanNotionDocumentTask:
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
assert tenant
# Create test segment
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -379,12 +377,11 @@ class TestCleanNotionDocumentTask:
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
assert tenant
# Create segments without index_node_ids
segments = []
for i in range(3):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -468,11 +465,10 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(document)
db_session_with_containers.flush()
documents.append(document)
assert tenant
# Create segments for each document
for j in range(2):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -569,10 +565,9 @@ class TestCleanNotionDocumentTask:
segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR]
segments = []
index_node_ids = []
assert tenant
for i, status in enumerate(segment_statuses):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -665,10 +660,9 @@ class TestCleanNotionDocumentTask:
)
db_session_with_containers.add(document)
db_session_with_containers.flush()
assert tenant
# Create segment
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -765,12 +759,11 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(document)
db_session_with_containers.flush()
documents.append(document)
assert tenant
# Create multiple segments for each document
num_segments_per_doc = 5
for j in range(num_segments_per_doc):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -875,7 +868,6 @@ class TestCleanNotionDocumentTask:
# Create segments for each document
for j in range(3):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=account.current_tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -984,11 +976,10 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.add(document)
db_session_with_containers.flush()
documents.append(document)
assert tenant
# Create segments for each document
for j in range(2):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -1093,10 +1084,9 @@ class TestCleanNotionDocumentTask:
# Create segments with metadata
segments = []
index_node_ids = []
assert tenant
for i in range(3):
segment = DocumentSegment(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,

Some files were not shown because too many files have changed in this diff.