Commit 8b664680aa: merge main
@@ -63,7 +63,7 @@ pnpm analyze-component <path> --json

```typescript
// ❌ Before: Complex state logic in component
-const Configuration: FC = () => {
+function Configuration() {
  const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
  const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
  const [completionParams, setCompletionParams] = useState<FormValue>({})
@@ -85,7 +85,7 @@ export const useModelConfig = (appId: string) => {
}

// Component becomes cleaner
-const Configuration: FC = () => {
+function Configuration() {
  const { modelConfig, setModelConfig } = useModelConfig(appId)
  return <div>...</div>
}
@@ -189,8 +189,6 @@ const Template = useMemo(() => {

**Dify Convention**:
- This skill is for component decomposition, not query/mutation design.
- When refactoring data fetching, follow `web/AGENTS.md`.
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
- Do not introduce deprecated `useInvalid` / `useReset`.
- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state.
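For illustration, a minimal sketch of the passthrough versus orchestration distinction; `consoleQuery.tags.bind` and the `bound` field are hypothetical:

```typescript
import { useMutation, useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'

// ❌ Passthrough: only renames queryOptions(), adds no behavior
export function useTagList() {
  return useQuery(consoleQuery.tags.list.queryOptions())
}

// ✅ Orchestration: combines operations and owns shared derived state
export function useTagBindings() {
  const tags = useQuery(consoleQuery.tags.list.queryOptions())
  const bind = useMutation(consoleQuery.tags.bind.mutationOptions()) // hypothetical endpoint
  const boundCount = tags.data?.filter(tag => tag.bound).length ?? 0 // hypothetical field
  return { tags, bind, boundCount }
}
```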
@@ -60,8 +60,10 @@ const Template = useMemo(() => {
**After** (complexity: ~3):

```typescript
+import type { ComponentType } from 'react'
+
// Define lookup table outside component
-const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
+const TEMPLATE_MAP: Record<AppModeEnum, Record<string, ComponentType<TemplateProps>>> = {
  [AppModeEnum.CHAT]: {
    [LanguagesSupported[1]]: TemplateChatZh,
    [LanguagesSupported[7]]: TemplateChatJa,
@@ -65,10 +65,10 @@ interface ConfigurationHeaderProps {
  onPublish: () => void
}

-const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
+function ConfigurationHeader({
  isAdvancedMode,
  onPublish,
-}) => {
+}: ConfigurationHeaderProps) {
  const { t } = useTranslation()

  return (
@@ -136,7 +136,7 @@ const AppInfo = () => {
}

// ✅ After: Separate view components
-const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
+function AppInfoExpanded({ appDetail, onAction }: AppInfoViewProps) {
  return (
    <div className="expanded">
      {/* Clean, focused expanded view */}
@@ -144,7 +144,7 @@ const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
  )
}

-const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
+function AppInfoCollapsed({ appDetail, onAction }: AppInfoViewProps) {
  return (
    <div className="collapsed">
      {/* Clean, focused collapsed view */}
@@ -203,12 +203,12 @@ interface AppInfoModalsProps {
  onSuccess: () => void
}

-const AppInfoModals: FC<AppInfoModalsProps> = ({
+function AppInfoModals({
  appDetail,
  activeModal,
  onClose,
  onSuccess,
-}) => {
+}: AppInfoModalsProps) {
  const handleEdit = async (data) => { /* logic */ }
  const handleDuplicate = async (data) => { /* logic */ }
  const handleDelete = async () => { /* logic */ }
@@ -296,7 +296,7 @@ interface OperationItemProps {
  onAction: (id: string) => void
}

-const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
+function OperationItem({ operation, onAction }: OperationItemProps) {
  return (
    <div className="operation-item">
      <span className="icon">{operation.icon}</span>
@@ -435,7 +435,7 @@ interface ChildProps {
  onSubmit: () => void
}

-const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
+function Child({ value, onChange, onSubmit }: ChildProps) {
  return (
    <div>
      <input value={value} onChange={e => onChange(e.target.value)} />
@@ -112,13 +112,13 @@ export const useModelConfig = ({

```typescript
// Before: 50+ lines of state management
-const Configuration: FC = () => {
+function Configuration() {
  const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
  // ... lots of related state and effects
}

// After: Clean component
-const Configuration: FC = () => {
+function Configuration() {
  const {
    modelConfig,
    setModelConfig,
@@ -159,8 +159,6 @@ const Configuration: FC = () => {

When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns.

- Follow `web/AGENTS.md` first.
- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling.
- Do not introduce deprecated `useInvalid` / `useReset`.
- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks.
@@ -23,7 +23,7 @@ Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
-5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
+5. Re-check official Playwright or Cucumber docs with the available documentation tools before introducing a new framework pattern.

## Local Rules
@@ -9,18 +9,18 @@ Category: Performance

When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.

-## Complex prop memoization
+## Complex prop stability

-IsUrgent: True
+IsUrgent: False
Category: Performance

### Description

-Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
+Only require stable object, array, or map props when there is a clear reason: the child is memoized, the value participates in effect/query dependencies, the value is part of a stable-reference API contract, or profiling/local behavior shows avoidable re-renders. Do not request `useMemo` for every inline object by default; `how-to-write-component` treats memoization as a targeted optimization.

Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.

-Wrong:
+Risky:

```tsx
<HeavyComp
@@ -31,7 +31,7 @@ Wrong:
/>
```

-Right:
+Better when stable identity matters:

```tsx
const config = useMemo(() => ({
@@ -1,46 +0,0 @@
---
name: frontend-query-mutation
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
---

# Frontend Query & Mutation

## Intent

- Keep contract as the single source of truth in `web/contract/*`.
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
- Keep abstractions minimal to preserve TypeScript inference.

## Workflow

1. Identify the change surface.
   - Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
   - Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
   - Read both references when a task spans contract shape and runtime behavior.
2. Implement the smallest abstraction that fits the task.
   - Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
   - Extract a small shared query helper only when multiple call sites share the same extra options.
   - Create or keep feature hooks only for real orchestration or shared domain behavior.
   - When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
3. Preserve Dify conventions.
   - Keep contract inputs in `{ params, query?, body? }` shape.
   - Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
   - Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.

## Files Commonly Touched

- `web/contract/console/*.ts`
- `web/contract/marketplace.ts`
- `web/contract/router.ts`
- `web/service/client.ts`
- legacy `web/service/use-*.ts` files when migrating wrappers away
- component and hook call sites using `consoleQuery` or `marketplaceQuery`

## References

- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference.
- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules.

Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs.
@@ -1,4 +0,0 @@
interface:
  display_name: "Frontend Query & Mutation"
  short_description: "Dify TanStack Query, oRPC, and default option patterns"
  default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."
@@ -1,129 +0,0 @@
# Contract Patterns

## Table of Contents

- Intent
- Minimal structure
- Core workflow
- Query usage decision rule
- Mutation usage decision rule
- Thin hook decision rule
- Anti-patterns
- Contract rules
- Type export

## Intent

- Keep contract as the single source of truth in `web/contract/*`.
- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract.
- Keep abstractions minimal and preserve TypeScript inference.

## Minimal Structure

```text
web/contract/
├── base.ts
├── router.ts
├── marketplace.ts
└── console/
    ├── billing.ts
    └── ...other domains
web/service/client.ts
```

## Core Workflow

1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`.
   - Use `base.route({...}).output(type<...>())` as the baseline.
   - Add `.input(type<...>())` only when the request has `params`, `query`, or `body`.
   - For `GET` without input, omit `.input(...)`; do not use `.input(type<unknown>())`.
2. Register contract in `web/contract/router.ts`.
   - Import directly from domain files and nest by API prefix.
3. Consume from UI call sites via oRPC query utilities.

```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'

const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
  staleTime: 5 * 60 * 1000,
  throwOnError: true,
  select: invoice => invoice.url,
}))
```

## Query Usage Decision Rule

1. Default to direct `*.queryOptions(...)` usage at the call site.
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
3. Create or keep feature hooks only for orchestration.
   - Combine multiple queries or mutations.
   - Share domain-level derived state or invalidation helpers.
   - Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
4. Treat `web/service/use-{domain}.ts` as legacy.
   - Do not create new thin service wrappers for oRPC contracts.
   - When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.

```typescript
const invoicesBaseQueryOptions = () =>
  consoleQuery.billing.invoices.queryOptions({ retry: false })

const invoiceQuery = useQuery({
  ...invoicesBaseQueryOptions(),
  throwOnError: true,
})
```

## Mutation Usage Decision Rule

1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.

```typescript
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
```

## Thin Hook Decision Rule

Remove thin hooks when they only rename a single oRPC query or mutation helper.
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.

Use:

```typescript
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
```

Keep:

```typescript
const applyTagBindingsMutation = useApplyTagBindingsMutation()
```

`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.

## Anti-Patterns

- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.

## Contract Rules

- Input structure: always use `{ params, query?, body? }`.
- No-input `GET`: omit `.input(...)`; do not use `.input(type<unknown>())`.
- Path params: use `{paramName}` in the path and match it in the `params` object.
- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`.
- No barrel files: import directly from specific files.
- Types: import from `@/types/` and use the `type<T>()` helper.
- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools.

## Type Export

```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```
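For context, a minimal contract definition following the rules above; `getInvoice`, the `Invoice` type, and the import paths are assumptions for illustration, not actual Dify endpoints:

```typescript
import { type } from '@orpc/contract' // assumed source of the type<T>() helper
import { base } from '@/contract/base' // assumed path to the shared base

type Invoice = { id: string; url: string }

// {invoiceId} in the path matches the `params` object; a GET with only
// path params declares just `params` in the input shape.
export const getInvoice = base
  .route({ method: 'GET', path: '/billing/invoices/{invoiceId}' })
  .input(type<{ params: { invoiceId: string } }>())
  .output(type<Invoice>())
```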
@@ -1,172 +0,0 @@
# Runtime Rules

## Table of Contents

- Conditional queries
- oRPC default options
- Cache invalidation
- Key API guide
- `mutate` vs `mutateAsync`
- Legacy migration

## Conditional Queries

Prefer contract-shaped `queryOptions(...)`.
When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions.
Use `enabled` only for extra business gating after the input itself is already valid.

```typescript
import { skipToken, useQuery } from '@tanstack/react-query'

// Disable the query by skipping input construction.
function useAccessMode(appId: string | undefined) {
  return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
    input: appId
      ? { params: { appId } }
      : skipToken,
  }))
}

// Avoid runtime-only guards that bypass type checking.
function useBadAccessMode(appId: string | undefined) {
  return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
    input: { params: { appId: appId! } },
    enabled: !!appId,
  }))
}
```

## oRPC Default Options

Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.

Place defaults at the query utility creation point in `web/service/client.ts`:

```typescript
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
  path: ['console'],
  experimental_defaults: {
    tags: {
      create: {
        mutationOptions: {
          onSuccess: (tag, _variables, _result, context) => {
            context.client.setQueryData(
              consoleQuery.tags.list.queryKey({
                input: {
                  query: {
                    type: tag.type,
                  },
                },
              }),
              (oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
            )
          },
        },
      },
    },
  },
})
```

Rules:

- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.

## Cache Invalidation

Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.

Use:

- `.key()` for namespace or prefix invalidation
- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData`
- `queryClient.invalidateQueries(...)` in mutation `onSuccess`

Do not use deprecated `useInvalid` from `use-base.ts`.

```typescript
// Feature orchestration owns cache invalidation only when defaults are not enough.
export const useUpdateAccessMode = () => {
  const queryClient = useQueryClient()

  return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({
    onSuccess: () => {
      queryClient.invalidateQueries({
        queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
      })
    },
  }))
}

// Component only adds UI behavior.
updateAccessMode({ appId, mode }, {
  onSuccess: () => toast.success('...'),
})

// Avoid putting invalidation knowledge in the component.
mutate({ appId, mode }, {
  onSuccess: () => {
    queryClient.invalidateQueries({
      queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(),
    })
  },
})
```
## Key API Guide

- `.key(...)`
  - Use for partial matching operations.
  - Prefer it for invalidation, refetch, and cancel patterns.
  - Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })`
- `.queryKey(...)`
  - Use for a specific query's full key.
  - Prefer it for exact cache addressing and direct reads or writes.
- `.mutationKey(...)`
  - Use for a specific mutation's full key.
  - Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping.
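For illustration, a minimal sketch of the three key builders against the billing operations used above; treat the exact options as assumptions:

```typescript
// Prefix invalidation: matches every cached billing query.
queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })

// Exact cache addressing: reads one concrete query instance.
const invoices = queryClient.getQueryData(
  consoleQuery.billing.invoices.queryKey(),
)

// Mutation defaults registration, keyed by the full mutation key.
queryClient.setMutationDefaults(
  consoleQuery.billing.bindPartnerStack.mutationKey(),
  { retry: 1 },
)
```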
## `mutate` vs `mutateAsync`

Prefer `mutate` by default.
Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies.

Rules:

- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`.
- Every `await mutateAsync(...)` must be wrapped in `try/catch`.
- Do not use `mutateAsync` when callbacks already express the flow clearly.

```typescript
// Default case.
mutation.mutate(data, {
  onSuccess: result => router.push(result.url),
})

// Promise semantics are required.
try {
  const order = await createOrder.mutateAsync(orderData)
  await confirmPayment.mutateAsync({ orderId: order.id, token })
  router.push(`/orders/${order.id}`)
}
catch (error) {
  toast.error(error instanceof Error ? error.message : 'Unknown error')
}
```

## Legacy Migration

When touching old code, migrate it toward these rules:

| Old pattern | New pattern |
|---|---|
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
@@ -5,7 +5,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com

# Dify Frontend Testing Skill

-This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
+This skill enables Codex to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.

> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`).
@@ -24,35 +24,27 @@ Apply this skill when the user:
**Do NOT apply** when:

- User is asking about backend/API tests (Python/pytest)
-- User is asking about E2E tests (Playwright/Cypress)
+- User is asking about E2E tests (Cucumber + Playwright under `e2e/`)
- User is only asking conceptual questions without code context

## Quick Reference

### Tech Stack

| Tool | Version | Purpose |
|------|---------|---------|
| Vitest | 4.0.16 | Test runner |
| React Testing Library | 16.0 | Component testing |
| jsdom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |

### Key Commands

Run these commands from `web/`. From the repository root, prefix them with `pnpm -C web`.

```bash
# Run all tests
pnpm test

# Watch mode
-pnpm test:watch
+pnpm test --watch

# Run specific file
pnpm test path/to/file.spec.tsx

# Generate coverage report
-pnpm test:coverage
+pnpm test --coverage

# Analyze component complexity
pnpm analyze-component <path>
@@ -228,7 +220,10 @@ Every test should clearly separate:
### 2. Black-Box Testing

- Test observable behavior, not implementation details
-- Use semantic queries (getByRole, getByLabelText)
+- Use semantic queries (`getByRole` with accessible `name`, `getByLabelText`, `getByPlaceholderText`, `getByText`, and scoped `within(...)`)
+- Treat `getByTestId` as a last resort. If a control cannot be found by role/name, label, landmark, or dialog scope, fix the component accessibility first instead of adding or relying on `data-testid`.
+- Remove production `data-testid` attributes when semantic selectors can cover the behavior. Keep them only for non-visual mocked boundaries, editor/browser shims such as Monaco, canvas/chart output, or third-party widgets with no accessible DOM in the test environment.
+- Do not assert decorative icons by test id. Assert the named control that contains them, or mark decorative icons `aria-hidden`.
- Avoid testing internal state directly
- **Prefer pattern matching over hardcoded strings** in assertions:
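A minimal sketch of these query preferences, assuming a hypothetical `SettingsDialog` component:

```typescript
import { render, screen, within } from '@testing-library/react'
import userEvent from '@testing-library/user-event'

it('saves from the settings dialog', async () => {
  render(<SettingsDialog />)

  // Scope queries to the dialog instead of reaching for test ids.
  const dialog = screen.getByRole('dialog', { name: /settings/i })
  await userEvent.type(within(dialog).getByLabelText(/api key/i), 'sk-test')
  await userEvent.click(within(dialog).getByRole('button', { name: /save/i }))

  // Pattern matching instead of a hardcoded full string.
  expect(await screen.findByText(/saved/i)).toBeInTheDocument()
})
```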
@@ -56,7 +56,7 @@ See [Zustand Store Testing](#zustand-store-testing) section for full details.

| Location | Purpose |
|----------|---------|
-| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `next/image`, `zustand`) |
+| `web/vitest.setup.ts` | Global mocks shared by all tests (`react-i18next`, `zustand`, clipboard, FloatingPortal, Monaco, `localStorage`) |
| `web/__mocks__/zustand.ts` | Zustand mock implementation (auto-resets stores after each test) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
@@ -216,28 +216,21 @@ describe('Component', () => {
})
```

-### 5. HTTP Mocking with Nock
+### 5. HTTP and `fetch` Mocking

```typescript
-import nock from 'nock'
-
-const GITHUB_HOST = 'https://api.github.com'
-const GITHUB_PATH = '/repos/owner/repo'
-
-const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
-  return nock(GITHUB_HOST)
-    .get(GITHUB_PATH)
-    .delay(delayMs)
-    .reply(status, body)
-}
-
describe('GithubComponent', () => {
-  afterEach(() => {
-    nock.cleanAll()
+  beforeEach(() => {
+    vi.clearAllMocks()
  })

  it('should display repo info', async () => {
-    mockGithubApi(200, { name: 'dify', stars: 1000 })
+    vi.mocked(globalThis.fetch).mockResolvedValueOnce(
+      new Response(JSON.stringify({ name: 'dify', stars: 1000 }), {
+        status: 200,
+        headers: { 'Content-Type': 'application/json' },
+      }),
+    )

    render(<GithubComponent />)
@@ -247,7 +240,12 @@ describe('GithubComponent', () => {
  })

  it('should handle API error', async () => {
-    mockGithubApi(500, { message: 'Server error' })
+    vi.mocked(globalThis.fetch).mockResolvedValueOnce(
+      new Response(JSON.stringify({ message: 'Server error' }), {
+        status: 500,
+        headers: { 'Content-Type': 'application/json' },
+      }),
+    )

    render(<GithubComponent />)
@@ -258,6 +256,8 @@ describe('GithubComponent', () => {
})
```

+Prefer mocking `@/service/*` modules or spying on `global.fetch` / `ky` clients with deterministic responses. Do not introduce an HTTP interception dependency such as `nock` or MSW unless it is already declared in the workspace or adding it is part of the task.
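A minimal sketch of the fetch-spy setup, assuming the test environment exposes a global `fetch`:

```typescript
import { afterEach, vi } from 'vitest'

const fetchSpy = vi
  .spyOn(globalThis, 'fetch')
  .mockResolvedValue(new Response(JSON.stringify({ ok: true }), { status: 200 }))

afterEach(() => {
  fetchSpy.mockRestore() // keep mock state from leaking between tests
})
```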

### 6. Context Providers

```typescript
@@ -332,7 +332,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
-1. Don't forget to clean up nock after each test
+1. Don't leave HTTP mocks or service mock state leaking between tests
1. Don't use `any` types in mocks without necessity

### Mock Decision Tree
@@ -227,12 +227,12 @@ Failing tests compound:

**Fix failures immediately before proceeding.**

-## Integration with Claude's Todo Feature
+## Integration with Codex's Todo Feature

-When using Claude for multi-file testing:
+When using Codex for multi-file testing:

-1. **Ask Claude to create a todo list** before starting
-1. **Request one file at a time** or ensure Claude processes incrementally
+1. **Create a todo list** before starting
+1. **Process one file at a time**
1. **Verify each test passes** before asking for the next
1. **Mark todos complete** as you progress
.agents/skills/how-to-write-component/SKILL.md (new file, 71 lines)
@@ -0,0 +1,71 @@
---
name: how-to-write-component
description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling.
---

# How To Write A Component

Use this as the decision guide for React/TypeScript component structure. Existing code is reference material, not automatic precedent; when it conflicts with these rules, adapt the approach instead of reproducing the violation.

## Core Defaults

- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature.
- Use Tailwind CSS v4.1+ rules via the `tailwind-css-rules` skill. Prefer v4 utilities, `gap`, `text-size/line-height`, `min-h-dvh`, and avoid deprecated utilities and `@apply`.
## Ownership

- Put local state, queries, mutations, handlers, and derived UI data in the lowest component that uses them. Extract a purpose-built owner component only when the logic has no natural home.
- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data. Do not hoist a query only because it is duplicated; TanStack Query handles deduplication and cache sharing.
- Hoist state, queries, or callbacks to a parent only when the parent consumes the data, coordinates shared loading/error/empty UI, needs one consistent snapshot, or owns a workflow spanning children.
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
- Prefer uncontrolled DOM state and CSS variables before adding controlled props.
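A minimal sketch of feature-scoped Jotai UI state replacing repeated prop forwarding; `panelOpenAtom` and both components are hypothetical:

```tsx
import { atom, useAtom } from 'jotai'

// Colocated with the feature that owns it, not lifted to a distant parent.
const panelOpenAtom = atom(false)

function PanelToggle() {
  const [open, setOpen] = useAtom(panelOpenAtom)
  return <button onClick={() => setOpen(!open)}>{open ? 'Hide' : 'Show'}</button>
}

function PanelBody() {
  const [open] = useAtom(panelOpenAtom)
  return open ? <div>Panel content</div> : null
}
```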
## Components, Props, And Types

- Type component signatures directly; do not use `FC` or `React.FC`.
- Prefer `function` for top-level components and module helpers. Use arrow functions for local callbacks, handlers, and lambda-style APIs.
- Prefer named exports. Use default exports only where the framework requires them, such as Next.js route files.
- Type simple one-off props inline. Use a named `Props` type only when reused, exported, complex, or clearer.
- Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers beside the component that needs them.
- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially IDs like `appInstanceId`. Normalize framework or route params at the boundary.
- Keep fallback and invariant checks at the lowest component that already handles that state; callers should pass raw values through instead of duplicating checks.
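A minimal sketch of these conventions; `AppCard` and the `App` type location are assumptions:

```tsx
import type { App } from '@/types/app' // assumed API-returned type location

// Direct signature typing, named export, no React.FC, inline one-off props.
export function AppCard({ app, onOpen }: { app: App; onOpen: (appInstanceId: string) => void }) {
  // Arrow function for the local handler; the domain name stays stable.
  const handleOpen = () => onOpen(app.id)
  return <button onClick={handleOpen}>{app.name}</button>
}
```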
## Queries And Mutations

- Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape.
- Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`.
- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it.
- Keep feature hooks for real orchestration, workflow state, or shared domain behavior.
- For missing required query input, use `input: skipToken`; use `enabled` only for extra business gating after the input is valid.
- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows.
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules.
- Do not use deprecated `useInvalid` or `useReset`.
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`.
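A minimal sketch combining these rules, reusing the `accessControl` operations shown elsewhere in this diff:

```typescript
import { skipToken, useMutation, useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client'

function useAppAccess(appId: string | undefined) {
  // Missing required input: skip instead of non-null assertions plus `enabled`.
  const accessMode = useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({
    input: appId ? { params: { appId } } : skipToken,
  }))

  const updateAccessMode = useMutation(
    consoleQuery.accessControl.updateAccessMode.mutationOptions(),
  )

  // Expose mutate(...); callers add UI feedback in call-site callbacks.
  return { accessMode, updateAccessMode: updateAccessMode.mutate }
}
```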
## Component Boundaries

- Use the first level below a page or tab to organize independent page sections when it adds real structure. This layer is layout/semantic first, not automatically the data owner.
- Split deeper components by the data and state each layer actually needs. Each component should access only necessary data, and ownership should stay at the lowest consumer.
- Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary.
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
- Avoid shallow wrappers and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary.
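A minimal sketch of behavior versus layout ownership; `RowActions` and its markup are hypothetical:

```tsx
import { useState } from 'react'

// The action component owns its trigger, open state, and menu content.
function RowActions({ onDelete }: { onDelete: () => void }) {
  const [open, setOpen] = useState(false)
  return (
    <div>
      <button onClick={() => setOpen(!open)}>Actions</button>
      {open && <button onClick={onDelete}>Delete</button>}
    </div>
  )
}

// The caller owns placement: slots, offsets, and alignment.
function Row() {
  return (
    <div className="flex items-center justify-between gap-2">
      <span>Row title</span>
      <RowActions onDelete={() => { /* feature-owned delete */ }} />
    </div>
  )
}
```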
## You Might Not Need An Effect

- Use Effects only to synchronize with external systems such as browser APIs, non-React widgets, subscriptions, timers, analytics that must run because the component was shown, or imperative DOM integration.
- Do not use Effects to transform props or state for rendering. Calculate derived values during render, and use `useMemo` only when the calculation is actually expensive.
- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known.
- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render.
- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary.
- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components.
- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow.
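A minimal before/after sketch of deriving during render instead of syncing state with an Effect:

```tsx
import { useEffect, useState } from 'react'

type Item = { id: string; name: string }

// ❌ Effect copies derived data into redundant state.
function FilteredListWithEffect({ items, query }: { items: Item[]; query: string }) {
  const [visible, setVisible] = useState<Item[]>([])
  useEffect(() => {
    setVisible(items.filter(item => item.name.includes(query)))
  }, [items, query])
  return <ul>{visible.map(item => <li key={item.id}>{item.name}</li>)}</ul>
}

// ✅ Derive during render; add useMemo only if filtering is actually expensive.
function FilteredList({ items, query }: { items: Item[]; query: string }) {
  const visible = items.filter(item => item.name.includes(query))
  return <ul>{visible.map(item => <li key={item.id}>{item.name}</li>)}</ul>
}
```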
## Navigation And Performance

- Prefer `Link` for normal navigation. Use router APIs only for command-flow side effects such as mutation success, guarded redirects, or form submission.
- Avoid `memo`, `useMemo`, and `useCallback` unless there is a clear performance reason.
.agents/skills/tailwind-css-rules/SKILL.md (new file, 367 lines)
@@ -0,0 +1,367 @@
---
name: tailwind-css-rules
description: Tailwind CSS v4.1+ rules and best practices. Use when writing, reviewing, refactoring, or upgrading Tailwind CSS classes and styles, especially v4 utility migrations, layout spacing, typography, responsive variants, dark mode, gradients, CSS variables, and component styling.
---

# Tailwind CSS Rules and Best Practices

## Core Principles

- **Always use Tailwind CSS v4.1+** - Ensure the codebase is using the latest version
- **Do not use deprecated or removed utilities** - ALWAYS use the replacement
- **Never use `@apply`** - Use CSS variables, the `--spacing()` function, or framework components instead
- **Check for redundant classes** - Remove any classes that aren't necessary
- **Group elements logically** to simplify responsive tweaks later

## Upgrading to Tailwind CSS v4

### Before Upgrading

- **Always read the upgrade documentation first** - Read https://tailwindcss.com/docs/upgrade-guide and https://tailwindcss.com/blog/tailwindcss-v4 before starting an upgrade.
- Ensure the git repository is in a clean state before starting

### Upgrade Process

1. Run the upgrade command: `npx @tailwindcss/upgrade@latest` for both major and minor updates
2. The tool will convert JavaScript config files to the new CSS format
3. Review all changes extensively to clean up any false positives
4. Test thoroughly across your application

## Breaking Changes Reference

### Removed Utilities (NEVER use these in v4)

| ❌ Deprecated | ✅ Replacement |
| ----------------------- | ------------------------------------------------- |
| `bg-opacity-*` | Use opacity modifiers like `bg-black/50` |
| `text-opacity-*` | Use opacity modifiers like `text-black/50` |
| `border-opacity-*` | Use opacity modifiers like `border-black/50` |
| `divide-opacity-*` | Use opacity modifiers like `divide-black/50` |
| `ring-opacity-*` | Use opacity modifiers like `ring-black/50` |
| `placeholder-opacity-*` | Use opacity modifiers like `placeholder-black/50` |
| `flex-shrink-*` | `shrink-*` |
| `flex-grow-*` | `grow-*` |
| `overflow-ellipsis` | `text-ellipsis` |
| `decoration-slice` | `box-decoration-slice` |
| `decoration-clone` | `box-decoration-clone` |

### Renamed Utilities

Use the v4 name when migrating code that still carries Tailwind v3 semantics. Do not blanket-replace existing v4 classes: classes such as `rounded-sm`, `shadow-sm`, `ring-1`, and `ring-2` are valid in this codebase when they intentionally represent the current design scale.

| ❌ v3 pattern | ✅ v4 pattern |
| ------------------- | -------------------------------------------------- |
| `bg-gradient-*` | `bg-linear-*` |
| old shadow scale | verify against the current Tailwind/design scale |
| old blur scale | verify against the current Tailwind/design scale |
| old radius scale | use the Dify radius token mapping when applicable |
| `outline-none` | `outline-hidden` |
| bare `ring` utility | use an explicit ring width such as `ring-1`/`ring-2`/`ring-3` |

For Figma radius tokens, follow `packages/dify-ui/AGENTS.md`. For example, `--radius/xs` maps to `rounded-sm`; do not rewrite it to `rounded-xs`.
## Layout and Spacing Rules

### Flexbox and Grid Spacing

#### Always use gap utilities for internal spacing

Gap provides consistent spacing without edge cases (no extra space on last items). It's cleaner and more maintainable than margins on children.

```html
<!-- ❌ Don't do this -->
<div class="flex">
  <div class="mr-4">Item 1</div>
  <div class="mr-4">Item 2</div>
  <div>Item 3</div>
  <!-- No margin on last -->
</div>

<!-- ✅ Do this instead -->
<div class="flex gap-4">
  <div>Item 1</div>
  <div>Item 2</div>
  <div>Item 3</div>
</div>
```

#### Gap vs Space utilities

- **Never use `space-x-*` or `space-y-*` in flex/grid layouts** - always use gap
- Space utilities add margins to children and have issues with wrapped items
- Gap works correctly with flex-wrap and all flex directions

```html
<!-- ❌ Avoid space utilities in flex containers -->
<div class="flex flex-wrap space-x-4">
  <!-- Space utilities break with wrapped items -->
</div>

<!-- ✅ Use gap for consistent spacing -->
<div class="flex flex-wrap gap-4">
  <!-- Gap works perfectly with wrapping -->
</div>
```

### General Spacing Guidelines

- **Prefer top and left margins** over bottom and right margins (unless conditionally rendered)
- **Use padding on parent containers** instead of bottom margins on the last child
- **Always use `min-h-dvh` instead of `min-h-screen`** - `min-h-screen` is buggy on mobile Safari
- **Prefer `size-*` utilities** over separate `w-*` and `h-*` when setting equal dimensions
- For max-widths, prefer the container scale (e.g., `max-w-2xs` over `max-w-72`)
## Typography Rules

### Line Heights

- **Never use `leading-*` classes** - Always use line height modifiers with text size
- **Always use fixed line heights from the spacing scale** - Don't use named values

```html
<!-- ❌ Don't do this -->
<p class="text-base leading-7">Text with separate line height</p>
<p class="text-lg leading-relaxed">Text with named line height</p>

<!-- ✅ Do this instead -->
<p class="text-base/7">Text with line height modifier</p>
<p class="text-lg/8">Text with specific line height</p>
```

### Font Size Reference

Be precise with font sizes - know the actual pixel values:

- `text-xs` = 12px
- `text-sm` = 14px
- `text-base` = 16px
- `text-lg` = 18px
- `text-xl` = 20px

## Color and Opacity

### Opacity Modifiers

**Never use `bg-opacity-*`, `text-opacity-*`, etc.** - use the opacity modifier syntax:

```html
<!-- ❌ Don't do this -->
<div class="bg-red-500 bg-opacity-60">Old opacity syntax</div>

<!-- ✅ Do this instead -->
<div class="bg-red-500/60">Modern opacity syntax</div>
```

## Responsive Design

### Breakpoint Optimization

- **Check for redundant classes across breakpoints**
- **Only add breakpoint variants when values change**

```html
<!-- ❌ Redundant breakpoint classes -->
<div class="px-4 md:px-4 lg:px-4">
  <!-- md:px-4 and lg:px-4 are redundant -->
</div>

<!-- ✅ Efficient breakpoint usage -->
<div class="px-4 lg:px-8">
  <!-- Only specify when value changes -->
</div>
```
## Dark Mode

### Dark Mode Best Practices

- Use the plain `dark:` variant pattern
- Put light mode styles first, then dark mode styles
- Ensure `dark:` variant comes before other variants

```html
<!-- ✅ Correct dark mode pattern -->
<div class="bg-white text-black dark:bg-black dark:text-white">
  <button class="hover:bg-gray-100 dark:hover:bg-gray-800">Click me</button>
</div>
```

## Gradient Utilities

- **ALWAYS use `bg-linear-*` instead of `bg-gradient-*` utilities** - The gradient utilities were renamed in v4
- Use the new `bg-radial` or `bg-radial-[<position>]` to create radial gradients
- Use the new `bg-conic` or `bg-conic-*` to create conic gradients

```html
<!-- ✅ Use the new gradient utilities -->
<div class="h-14 bg-linear-to-br from-violet-500 to-fuchsia-500"></div>
<div
  class="size-18 bg-radial-[at_50%_75%] from-sky-200 via-blue-400 to-indigo-900 to-90%"
></div>
<div
  class="size-24 bg-conic-180 from-indigo-600 via-indigo-50 to-indigo-600"
></div>

<!-- ❌ Do not use bg-gradient-* utilities -->
<div class="h-14 bg-gradient-to-br from-violet-500 to-fuchsia-500"></div>
```

## Working with CSS Variables

### Accessing Theme Values

Tailwind CSS v4 exposes all theme values as CSS variables:

```css
/* Access colors and other theme values */
.custom-element {
  background: var(--color-red-500);
  border-radius: var(--radius-lg);
}
```

### The `--spacing()` Function

Use the dedicated `--spacing()` function for spacing calculations:

```css
.custom-class {
  margin-top: calc(100vh - --spacing(16));
}
```

### Extending theme values

Use CSS to extend theme values:

```css
@import "tailwindcss";

@theme {
  --color-mint-500: oklch(0.72 0.11 178);
}
```

```html
<div class="bg-mint-500">
  <!-- ... -->
</div>
```
## New v4 Features

### Container Queries

Use the `@container` class and size variants:

```html
<article class="@container">
  <div class="flex flex-col @md:flex-row @lg:gap-8">
    <img class="w-full @md:w-48" />
    <div class="mt-4 @md:mt-0">
      <!-- Content adapts to container size -->
    </div>
  </div>
</article>
```

### Container Query Units

Use container-based units like `cqw` for responsive sizing:

```html
<div class="@container">
  <h1 class="text-[50cqw]">Responsive to container width</h1>
</div>
```

### Text Shadows (v4.1)

Use `text-shadow-*` utilities from `text-shadow-2xs` to `text-shadow-lg`:

```html
<!-- ✅ Text shadow examples -->
<h1 class="text-shadow-lg">Large shadow</h1>
<p class="text-shadow-sm/50">Small shadow with opacity</p>
```

### Masking (v4.1)

Use the new composable mask utilities for image and gradient masks:

```html
<!-- ✅ Linear gradient masks on specific sides -->
<div class="mask-t-from-50%">Top fade</div>
<div class="mask-b-from-20% mask-b-to-80%">Bottom gradient</div>
<div class="mask-linear-from-white mask-linear-to-black/60">
  Fade from white to black
</div>

<!-- ✅ Radial gradient masks -->
<div class="mask-radial-[100%_100%] mask-radial-from-75% mask-radial-at-left">
  Radial mask
</div>
```
## Component Patterns

### Avoiding Utility Inheritance

Don't add utilities to parents that you override in children:

```html
<!-- ❌ Avoid this pattern -->
<div class="text-center">
  <h1>Centered Heading</h1>
  <div class="text-left">Left-aligned content</div>
</div>

<!-- ✅ Better approach -->
<div>
  <h1 class="text-center">Centered Heading</h1>
  <div>Left-aligned content</div>
</div>
```

### Component Extraction

- Extract repeated patterns into framework components, not CSS classes
- Keep utility classes in templates/JSX
- Use data attributes for complex state-based styling
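A minimal sketch of extracting a repeated pattern into a framework component instead of a CSS class; `Badge` is hypothetical:

```tsx
import type { ReactNode } from 'react'

// Utility classes stay in JSX; the repetition lives in one component.
function Badge({ tone, children }: { tone: 'ok' | 'warn'; children: ReactNode }) {
  const toneClass = tone === 'ok' ? 'bg-green-500/10 text-green-700' : 'bg-amber-500/10 text-amber-700'
  return (
    // The data attribute supports state-based styling without extra classes.
    <span data-tone={tone} className={`rounded-full px-2 py-0.5 text-xs ${toneClass}`}>
      {children}
    </span>
  )
}
```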
## CSS Best Practices

### Nesting Guidelines

- Use nesting when styling both parent and children
- Avoid empty parent selectors

```css
/* ✅ Good nesting - parent has styles */
.card {
  padding: --spacing(4);

  > .card-title {
    font-weight: bold;
  }
}

/* ❌ Avoid empty parents */
ul {
  > li {
    /* Parent has no styles */
  }
}
```

## Common Pitfalls to Avoid

1. **Using old opacity utilities** - Always use `/opacity` syntax like `bg-red-500/60`
2. **Redundant breakpoint classes** - Only specify changes
3. **Space utilities in flex/grid** - Always use gap
4. **Leading utilities** - Use line-height modifiers like `text-sm/6`
5. **Arbitrary values** - Always use Tailwind's predefined scale whenever possible (e.g., use `ml-4` over `ml-[16px]`)
6. **@apply directive** - Use components or CSS variables
7. **min-h-screen on mobile** - Use min-h-dvh
8. **Separate width/height** - Use size utilities when equal
.github/workflows/labeler.yml (vendored, 2 lines changed)
@@ -9,6 +9,6 @@ jobs:
      pull-requests: write
    runs-on: depot-ubuntu-24.04
    steps:
-      - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
+      - uses: actions/labeler@f27b608878404679385c85cfa523b85ccb86e213 # v6.1.0
        with:
          sync-labels: true
.github/workflows/translate-i18n-claude.yml (vendored, 2 lines changed)
@@ -158,7 +158,7 @@ jobs:

      - name: Run Claude Code for Translation Sync
        if: steps.context.outputs.CHANGED_FILES != ''
-        uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
+        uses: anthropics/claude-code-action@476e359e6203e73dad705c8b322e333fabbd7416 # v1.0.119
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          github_token: ${{ secrets.GITHUB_TOKEN }}
Makefile (23 lines changed)
@@ -3,6 +3,10 @@ DOCKER_REGISTRY=langgenius
WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest
+DOCKER_DIR=docker
+DOCKER_MIDDLEWARE_ENV=$(DOCKER_DIR)/middleware.env
+DOCKER_MIDDLEWARE_ENV_EXAMPLE=$(DOCKER_DIR)/envs/middleware.env.example
+DOCKER_MIDDLEWARE_PROJECT=dify-middlewares-dev

# Default target - show help
.DEFAULT_GOAL := help
@@ -17,8 +21,13 @@ dev-setup: prepare-docker prepare-web prepare-api
# Step 1: Prepare Docker middleware
prepare-docker:
	@echo "🐳 Setting up Docker middleware..."
-	@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
-	@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
+	@if [ ! -f "$(DOCKER_MIDDLEWARE_ENV)" ]; then \
+		cp "$(DOCKER_MIDDLEWARE_ENV_EXAMPLE)" "$(DOCKER_MIDDLEWARE_ENV)"; \
+		echo "Docker middleware.env created"; \
+	else \
+		echo "Docker middleware.env already exists"; \
+	fi
+	@cd $(DOCKER_DIR) && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p $(DOCKER_MIDDLEWARE_PROJECT) up -d
	@echo "✅ Docker middleware started"

# Step 2: Prepare web environment
@@ -39,12 +48,18 @@ prepare-api:
# Clean dev environment
dev-clean:
	@echo "⚠️ Stopping Docker containers..."
-	@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
+	@if [ -f "$(DOCKER_MIDDLEWARE_ENV)" ]; then \
+		cd $(DOCKER_DIR) && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p $(DOCKER_MIDDLEWARE_PROJECT) down; \
+	else \
+		echo "Docker middleware.env does not exist, skipping compose down"; \
+	fi
	@echo "🗑️ Removing volumes..."
	@rm -rf docker/volumes/db
+	@rm -rf docker/volumes/mysql
	@rm -rf docker/volumes/redis
	@rm -rf docker/volumes/plugin_daemon
	@rm -rf docker/volumes/weaviate
	@rm -rf docker/volumes/sandbox/dependencies
	@rm -rf api/storage
	@echo "✅ Cleanup complete"
@@ -132,7 +147,7 @@ help:
	@echo "  make prepare-docker - Set up Docker middleware"
	@echo "  make prepare-web    - Set up web environment"
	@echo "  make prepare-api    - Set up API environment"
-	@echo "  make dev-clean      - Stop Docker middleware containers"
+	@echo "  make dev-clean      - Stop Docker middleware containers and remove dev data"
	@echo ""
	@echo "Backend Code Quality:"
	@echo "  make format         - Format code with ruff"
@@ -34,7 +34,7 @@ TRIGGER_URL=http://localhost:5001
FILES_ACCESS_TIMEOUT=300

# Collaboration mode toggle
-ENABLE_COLLABORATION_MODE=false
+ENABLE_COLLABORATION_MODE=true

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
@@ -88,6 +88,10 @@ REDIS_HEALTH_CHECK_INTERVAL=30
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis

+# Ops trace retry configuration
+OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES=60
+OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS=5
+
# Database configuration
DB_TYPE=postgresql
DB_USERNAME=postgres
@@ -181,7 +181,6 @@ def initialize_extensions(app: DifyApp):
         ext_import_modules,
         ext_orjson,
         ext_forward_refs,
-        ext_set_secretkey,
         ext_compress,
         ext_code_based_extension,
         ext_database,
@@ -189,6 +188,7 @@ def initialize_extensions(app: DifyApp):
         ext_migrate,
         ext_redis,
         ext_storage,
+        ext_set_secretkey,
         ext_logstore,  # Initialize logstore after storage, before celery
         ext_celery,
         ext_login,
@@ -23,9 +23,9 @@ class SecurityConfig(BaseSettings):
     """
 
     SECRET_KEY: str = Field(
-        description="Secret key for secure session cookie signing."
-        "Make sure you are changing this key for your deployment with a strong key."
-        "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
+        description="Secret key for secure session cookie signing. "
+        "Leave empty to let Dify generate a persistent key in the storage directory, "
+        "or set a strong value via the `SECRET_KEY` environment variable.",
         default="",
     )
@@ -1137,6 +1137,18 @@ class MultiModalTransferConfig(BaseSettings):
     )
 
 
+class OpsTraceConfig(BaseSettings):
+    OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES: PositiveInt = Field(
+        description="Maximum retry attempts for transient ops trace provider dispatch failures.",
+        default=60,
+    )
+
+    OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS: PositiveInt = Field(
+        description="Delay in seconds between transient ops trace provider dispatch retry attempts.",
+        default=5,
+    )
+
+
 class CeleryBeatConfig(BaseSettings):
     CELERY_BEAT_SCHEDULER_TIME: int = Field(
         description="Interval in days for Celery Beat scheduler execution, default to 1 day",
@@ -1298,7 +1310,7 @@ class PositionConfig(BaseSettings):
 class CollaborationConfig(BaseSettings):
     ENABLE_COLLABORATION_MODE: bool = Field(
         description="Whether to enable collaboration mode features across the workspace",
-        default=False,
+        default=True,
     )
@@ -1417,6 +1429,7 @@ class FeatureConfig(
     ModelLoadBalanceConfig,
     ModerationConfig,
     MultiModalTransferConfig,
+    OpsTraceConfig,
     PositionConfig,
     RagEtlConfig,
     RepositoryConfig,
api/configs/secret_key.py | 38 (new file)
@@ -0,0 +1,38 @@
+"""SECRET_KEY persistence helpers for runtime setup."""
+
+from __future__ import annotations
+
+import secrets
+
+from extensions.ext_storage import storage
+
+GENERATED_SECRET_KEY_FILENAME = ".dify_secret_key"
+
+
+def resolve_secret_key(secret_key: str) -> str:
+    """Return an explicit SECRET_KEY or a generated key persisted in storage."""
+    if secret_key:
+        return secret_key
+
+    return _load_or_create_secret_key()
+
+
+def _load_or_create_secret_key() -> str:
+    try:
+        persisted_key = storage.load_once(GENERATED_SECRET_KEY_FILENAME).decode("utf-8").strip()
+        if persisted_key:
+            return persisted_key
+    except FileNotFoundError:
+        pass
+
+    generated_key = secrets.token_urlsafe(48)
+
+    try:
+        storage.save(GENERATED_SECRET_KEY_FILENAME, f"{generated_key}\n".encode())
+    except Exception as exc:
+        raise ValueError(
+            f"SECRET_KEY is not set and could not be generated at {GENERATED_SECRET_KEY_FILENAME}. "
+            "Set SECRET_KEY explicitly or make storage writable."
+        ) from exc
+
+    return generated_key
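Read together with the SecurityConfig change above, the resolution order is: an explicit `SECRET_KEY` always wins; otherwise a previously persisted key is reused; otherwise a fresh key is generated and saved. A minimal standalone sketch of that precedence, using a local file in place of Dify's storage extension (the `key_file` path is illustrative, not the real backend):

```python
import secrets
from pathlib import Path


def resolve_secret_key_sketch(explicit_key: str, key_file: Path) -> str:
    """Mimics resolve_secret_key: explicit value > persisted key > new key."""
    if explicit_key:
        return explicit_key
    if key_file.exists():
        persisted = key_file.read_text(encoding="utf-8").strip()
        if persisted:
            return persisted
    generated = secrets.token_urlsafe(48)
    key_file.write_text(f"{generated}\n", encoding="utf-8")
    return generated


# Repeated calls with an empty explicit key return the same persisted value.
key_file = Path("/tmp/.dify_secret_key_demo")
assert resolve_secret_key_sketch("", key_file) == resolve_secret_key_sketch("", key_file)
```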
@@ -39,7 +39,7 @@ from libs.login import current_account_with_tenant, login_required
 from models import App, DatasetPermissionEnum, Workflow
 from models.model import IconType
 from services.app_dsl_service import AppDslService
-from services.app_service import AppService
+from services.app_service import AppListParams, AppService, CreateAppParams
 from services.enterprise.enterprise_service import EnterpriseService
 from services.entities.dsl_entities import ImportMode, ImportStatus
 from services.entities.knowledge_entities.knowledge_entities import (
@@ -478,11 +478,18 @@ class AppListApi(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
 
         args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
-        args_dict = args.model_dump()
+        params = AppListParams(
+            page=args.page,
+            limit=args.limit,
+            mode=args.mode,
+            name=args.name,
+            tag_ids=args.tag_ids,
+            is_created_by_me=args.is_created_by_me,
+        )
 
         # get app list
         app_service = AppService()
-        app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
+        app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
         if not app_pagination:
             empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
             return empty.model_dump(mode="json"), 200
@@ -546,9 +553,17 @@ class AppListApi(Resource):
         """Create app"""
         current_user, current_tenant_id = current_account_with_tenant()
         args = CreateAppPayload.model_validate(console_ns.payload)
+        params = CreateAppParams(
+            name=args.name,
+            description=args.description,
+            mode=args.mode,
+            icon_type=args.icon_type,
+            icon=args.icon,
+            icon_background=args.icon_background,
+        )
 
         app_service = AppService()
-        app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
+        app = app_service.create_app(current_tenant_id, params, current_user)
         app_detail = AppDetail.model_validate(app, from_attributes=True)
         return app_detail.model_dump(mode="json"), 201
@@ -606,63 +606,63 @@ class DatasetIndexingEstimateApi(Resource):
         # validate args
         DocumentService.estimate_args_validate(args)
         extract_settings = []
-        if args["info_list"]["data_source_type"] == "upload_file":
-            file_ids = args["info_list"]["file_info_list"]["file_ids"]
-            file_details = db.session.scalars(
-                select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
-            ).all()
-
-            if file_details is None:
-                raise NotFound("File not found.")
-
-            if file_details:
-                for file_detail in file_details:
-                    extract_setting = ExtractSetting(
-                        datasource_type=DatasourceType.FILE,
-                        upload_file=file_detail,
-                        document_model=args["doc_form"],
-                    )
-                    extract_settings.append(extract_setting)
-        elif args["info_list"]["data_source_type"] == "notion_import":
-            notion_info_list = args["info_list"]["notion_info_list"]
-            for notion_info in notion_info_list:
-                workspace_id = notion_info["workspace_id"]
-                credential_id = notion_info.get("credential_id")
-                for page in notion_info["pages"]:
-                    extract_setting = ExtractSetting(
-                        datasource_type=DatasourceType.NOTION,
-                        notion_info=NotionInfo.model_validate(
-                            {
-                                "credential_id": credential_id,
-                                "notion_workspace_id": workspace_id,
-                                "notion_obj_id": page["page_id"],
-                                "notion_page_type": page["type"],
-                                "tenant_id": current_tenant_id,
-                            }
-                        ),
-                        document_model=args["doc_form"],
-                    )
-                    extract_settings.append(extract_setting)
-        elif args["info_list"]["data_source_type"] == "website_crawl":
-            website_info_list = args["info_list"]["website_info_list"]
-            for url in website_info_list["urls"]:
-                extract_setting = ExtractSetting(
-                    datasource_type=DatasourceType.WEBSITE,
-                    website_info=WebsiteInfo.model_validate(
-                        {
-                            "provider": website_info_list["provider"],
-                            "job_id": website_info_list["job_id"],
-                            "url": url,
-                            "tenant_id": current_tenant_id,
-                            "mode": "crawl",
-                            "only_main_content": website_info_list["only_main_content"],
-                        }
-                    ),
-                    document_model=args["doc_form"],
-                )
-                extract_settings.append(extract_setting)
-        else:
-            raise ValueError("Data source type not support")
+        match args["info_list"]["data_source_type"]:
+            case "upload_file":
+                file_ids = args["info_list"]["file_info_list"]["file_ids"]
+                file_details = db.session.scalars(
+                    select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
+                ).all()
+                if file_details is None:
+                    raise NotFound("File not found.")
+
+                if file_details:
+                    for file_detail in file_details:
+                        extract_setting = ExtractSetting(
+                            datasource_type=DatasourceType.FILE,
+                            upload_file=file_detail,
+                            document_model=args["doc_form"],
+                        )
+                        extract_settings.append(extract_setting)
+            case "notion_import":
+                notion_info_list = args["info_list"]["notion_info_list"]
+                for notion_info in notion_info_list:
+                    workspace_id = notion_info["workspace_id"]
+                    credential_id = notion_info.get("credential_id")
+                    for page in notion_info["pages"]:
+                        extract_setting = ExtractSetting(
+                            datasource_type=DatasourceType.NOTION,
+                            notion_info=NotionInfo.model_validate(
+                                {
+                                    "credential_id": credential_id,
+                                    "notion_workspace_id": workspace_id,
+                                    "notion_obj_id": page["page_id"],
+                                    "notion_page_type": page["type"],
+                                    "tenant_id": current_tenant_id,
+                                }
+                            ),
+                            document_model=args["doc_form"],
+                        )
+                        extract_settings.append(extract_setting)
+            case "website_crawl":
+                website_info_list = args["info_list"]["website_info_list"]
+                for url in website_info_list["urls"]:
+                    extract_setting = ExtractSetting(
+                        datasource_type=DatasourceType.WEBSITE,
+                        website_info=WebsiteInfo.model_validate(
+                            {
+                                "provider": website_info_list["provider"],
+                                "job_id": website_info_list["job_id"],
+                                "url": url,
+                                "tenant_id": current_tenant_id,
+                                "mode": "crawl",
+                                "only_main_content": website_info_list["only_main_content"],
+                            }
+                        ),
+                        document_model=args["doc_form"],
+                    )
+                    extract_settings.append(extract_setting)
+            case _:
+                raise ValueError("Data source type not support")
         indexing_runner = IndexingRunner()
         try:
             response = indexing_runner.indexing_estimate(
@@ -369,28 +369,31 @@ class DatasetDocumentListApi(Resource):
         else:
             sort_logic = asc
 
-        if sort == "hit_count":
-            sub_query = (
-                sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
-                .where(DocumentSegment.dataset_id == str(dataset_id))
-                .group_by(DocumentSegment.document_id)
-                .subquery()
-            )
-
-            query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
-                sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
-                sort_logic(Document.position),
-            )
-        elif sort == "created_at":
-            query = query.order_by(
-                sort_logic(Document.created_at),
-                sort_logic(Document.position),
-            )
-        else:
-            query = query.order_by(
-                desc(Document.created_at),
-                desc(Document.position),
-            )
+        match sort:
+            case "hit_count":
+                sub_query = (
+                    sa.select(
+                        DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")
+                    )
+                    .where(DocumentSegment.dataset_id == str(dataset_id))
+                    .group_by(DocumentSegment.document_id)
+                    .subquery()
+                )
+
+                query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
+                    sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
+                    sort_logic(Document.position),
+                )
+            case "created_at":
+                query = query.order_by(
+                    sort_logic(Document.created_at),
+                    sort_logic(Document.position),
+                )
+            case _:
+                query = query.order_by(
+                    desc(Document.created_at),
+                    desc(Document.position),
+                )
 
         paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
@@ -106,7 +106,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
 app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
 app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
 
-simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
+simple_account_model = get_or_create_model("TrialSimpleAccount", simple_account_fields)
 conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
 pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
@@ -136,7 +136,7 @@ def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[
     if not dataset:
         raise ValueError("Dataset does not exist.")
 
-    if not dataset.indexing_technique and not args["indexing_technique"]:
+    if not dataset.indexing_technique and not args.get("indexing_technique"):
         raise ValueError("indexing_technique is required.")
 
     embedding_model_provider = payload.embedding_model_provider
@@ -32,7 +32,7 @@ from core.app.entities.task_entities import (
 )
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
 from core.db.session_factory import session_factory
-from core.helper.trace_id_helper import extract_external_trace_id_from_args
+from core.helper.trace_id_helper import extract_external_trace_id_from_args, extract_parent_trace_context_from_args
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.repositories import DifyCoreRepositoryFactory
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@@ -166,6 +166,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
 
         extras = {
             **extract_external_trace_id_from_args(args),
+            **extract_parent_trace_context_from_args(args),
         }
         workflow_run_id = str(workflow_run_id or uuid.uuid4())
         # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
@@ -128,7 +128,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
 
     @staticmethod
    def _secret_key() -> bytes:
-        return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
+        return dify_config.SECRET_KEY.encode()
 
     def _sign_query(self, *, payload: str) -> dict[str, str]:
         timestamp = str(int(time.time()))
@@ -15,6 +15,7 @@ from datetime import datetime
 from typing import Any, Union
 
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
+from core.helper.trace_id_helper import ParentTraceContext
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
@@ -403,8 +404,13 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
 
         conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
         external_trace_id = None
+        parent_trace_context = None
         if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
-            external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
+            extras = self._application_generate_entity.extras
+            external_trace_id = extras.get("external_trace_id")
+            parent_trace_context = extras.get("parent_trace_context")
+            if isinstance(parent_trace_context, ParentTraceContext):
+                parent_trace_context = parent_trace_context.model_dump(exclude_none=True)
 
         trace_task = TraceTask(
             TraceTaskName.WORKFLOW_TRACE,
@@ -412,6 +418,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
             conversation_id=conversation_id,
             user_id=self._trace_manager.user_id,
             external_trace_id=external_trace_id,
+            parent_trace_context=parent_trace_context,
         )
         self._trace_manager.add_trace_task(trace_task)
@@ -35,8 +35,11 @@ class DatasourceFileManager:
         timestamp = str(int(time.time()))
         nonce = os.urandom(16).hex()
         data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
-        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        sign = hmac.new(
+            dify_config.SECRET_KEY.encode(),
+            data_to_sign.encode(),
+            hashlib.sha256,
+        ).digest()
         encoded_sign = base64.urlsafe_b64encode(sign).decode()
 
         return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@@ -47,8 +50,11 @@ class DatasourceFileManager:
         verify signature
         """
         data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
-        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        recalculated_sign = hmac.new(
+            dify_config.SECRET_KEY.encode(),
+            data_to_sign.encode(),
+            hashlib.sha256,
+        ).digest()
         recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
 
         # verify signature
@@ -3,6 +3,17 @@ import re
 from collections.abc import Mapping
 from typing import Any
 
+from pydantic import BaseModel, ConfigDict, StrictStr, ValidationError
+
+
+class ParentTraceContext(BaseModel):
+    """Typed parent trace context propagated from an outer workflow tool node."""
+
+    parent_workflow_run_id: StrictStr
+    parent_node_execution_id: StrictStr | None = None
+
+    model_config = ConfigDict(extra="forbid")
+
 
 def is_valid_trace_id(trace_id: str) -> bool:
     """
@@ -61,6 +72,30 @@ def extract_external_trace_id_from_args(args: Mapping[str, Any]):
     return {}
 
 
+def extract_parent_trace_context_from_args(args: Mapping[str, Any]) -> dict[str, ParentTraceContext]:
+    """
+    Extract 'parent_trace_context' from args.
+
+    Returns a dict suitable for use in extras when both parent identifiers exist.
+    Returns an empty dict if the context is missing or incomplete.
+    """
+    parent_trace_context = args.get("parent_trace_context")
+    if isinstance(parent_trace_context, ParentTraceContext):
+        context = parent_trace_context
+    elif isinstance(parent_trace_context, Mapping):
+        try:
+            context = ParentTraceContext.model_validate(parent_trace_context)
+        except ValidationError:
+            return {}
+    else:
+        return {}
+
+    if context.parent_node_execution_id is None:
+        return {}
+
+    return {"parent_trace_context": context}
+
+
 def get_trace_id_from_otel_context() -> str | None:
     """
     Retrieve the current trace ID from the active OpenTelemetry trace context.
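The extraction helper is strict by design: unless both parent identifiers are present and well-typed, it returns `{}`, so callers can unconditionally splat the result into `extras`. A few illustrative calls, assuming this runs inside the api package where the helper is importable:

```python
from core.helper.trace_id_helper import extract_parent_trace_context_from_args

# Complete context: returned wrapped under the "parent_trace_context" key.
ok = extract_parent_trace_context_from_args(
    {"parent_trace_context": {"parent_workflow_run_id": "run-1", "parent_node_execution_id": "node-1"}}
)
assert "parent_trace_context" in ok

# Missing parent_node_execution_id, wrong type, or absent key: collapses to {}.
assert extract_parent_trace_context_from_args({"parent_trace_context": {"parent_workflow_run_id": "run-1"}}) == {}
assert extract_parent_trace_context_from_args({"parent_trace_context": "run-1"}) == {}
assert extract_parent_trace_context_from_args({}) == {}
```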
@@ -324,9 +324,10 @@ class IndexingRunner:
         # one extract_setting is one source document
         for extract_setting in extract_settings:
             # extract
-            processing_rule = DatasetProcessRule(
-                mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
-            )
+            processing_rule = {
+                "mode": tmp_processing_rule["mode"],
+                "rules": tmp_processing_rule.get("rules"),
+            }
             # Extract document content
             text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
             # Cleaning and segmentation
@@ -334,7 +335,7 @@ class IndexingRunner:
                 text_docs,
                 current_user=None,
                 embedding_model_instance=embedding_model_instance,
-                process_rule=processing_rule.to_dict(),
+                process_rule=processing_rule,
                 tenant_id=tenant_id,
                 doc_language=doc_language,
                 preview=True,
@@ -5,6 +5,8 @@ from typing import Any, Union
 
 from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
 
+from core.helper.trace_id_helper import ParentTraceContext
+
 
 class BaseTraceInfo(BaseModel):
     message_id: str | None = None
@@ -51,8 +53,8 @@ class BaseTraceInfo(BaseModel):
     def resolved_parent_context(self) -> tuple[str | None, str | None]:
         """Resolve cross-workflow parent linking from metadata.
 
-        Extracts typed parent IDs from the untyped ``parent_trace_context``
-        metadata dict (set by tool_node when invoking nested workflows).
+        Extracts typed parent IDs from the ``parent_trace_context`` metadata
+        payload (set by tool_node when invoking nested workflows).
 
         Returns:
             (trace_correlation_override, parent_span_id_source) where
@@ -60,13 +62,18 @@ class BaseTraceInfo(BaseModel):
             parent_span_id_source is the outer node_execution_id.
         """
         parent_ctx = self.metadata.get("parent_trace_context")
-        if not isinstance(parent_ctx, dict):
+        if isinstance(parent_ctx, ParentTraceContext):
+            context = parent_ctx
+        elif isinstance(parent_ctx, Mapping):
+            try:
+                context = ParentTraceContext.model_validate(parent_ctx)
+            except ValueError:
+                return None, None
+        else:
             return None, None
-        trace_override = parent_ctx.get("parent_workflow_run_id")
-        parent_span = parent_ctx.get("parent_node_execution_id")
         return (
-            trace_override if isinstance(trace_override, str) else None,
-            parent_span if isinstance(parent_span, str) else None,
+            context.parent_workflow_run_id,
+            context.parent_node_execution_id,
         )
 
     @field_serializer("start_time", "end_time")
api/core/ops/exceptions.py | 22 (new file)
@@ -0,0 +1,22 @@
+"""Core exceptions shared by ops trace dispatchers and trace providers.
+
+Provider packages may raise these types to request generic task behavior, but
+generic Celery tasks should not import provider-specific exception classes.
+"""
+
+
+class RetryableTraceDispatchError(RuntimeError):
+    """Base class for transient trace dispatch failures that Celery may retry."""
+
+
+class PendingTraceParentContextError(RetryableTraceDispatchError):
+    """Raised when a nested trace arrives before its parent span context is available."""
+
+    parent_node_execution_id: str
+
+    def __init__(self, parent_node_execution_id: str) -> None:
+        self.parent_node_execution_id = parent_node_execution_id
+        super().__init__(
+            "Pending trace parent context for parent_node_execution_id="
+            f"{parent_node_execution_id}. Retry after the parent span context is published."
+        )
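The point of the base class is that a generic Celery task can retry on `RetryableTraceDispatchError` without importing any provider package. A hedged sketch of that wiring; the task name and provider call are invented for illustration, while the retry bounds reuse the new `OPS_TRACE_*` settings added above:

```python
from celery import shared_task

from configs import dify_config
from core.ops.exceptions import RetryableTraceDispatchError


def _send_to_provider(payload: dict) -> None:
    """Stand-in for a provider dispatch call; providers may raise
    PendingTraceParentContextError (a RetryableTraceDispatchError)."""


@shared_task(
    bind=True,
    autoretry_for=(RetryableTraceDispatchError,),  # retry only transient dispatch failures
    max_retries=dify_config.OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES,
    default_retry_delay=dify_config.OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS,
)
def dispatch_trace_sketch(self, payload: dict) -> None:
    _send_to_provider(payload)
```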
@@ -16,6 +16,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
 from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
+from core.helper.trace_id_helper import ParentTraceContext
 from core.ops.entities.config_entity import (
     OPS_FILE_PATH,
     BaseTracingConfig,
@@ -52,6 +53,17 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+def _dump_parent_trace_context(parent_trace_context: Any) -> dict[str, str] | None:
+    if isinstance(parent_trace_context, ParentTraceContext):
+        return parent_trace_context.model_dump(exclude_none=True)
+    if isinstance(parent_trace_context, dict):
+        try:
+            return ParentTraceContext.model_validate(parent_trace_context).model_dump(exclude_none=True)
+        except ValueError:
+            return None
+    return None
+
+
 class _AppTracingConfig(TypedDict, total=False):
     enabled: bool
     tracing_provider: str | None
@@ -857,8 +869,9 @@ class TraceTask:
         }
 
         parent_trace_context = self.kwargs.get("parent_trace_context")
-        if parent_trace_context:
-            metadata["parent_trace_context"] = parent_trace_context
+        dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context)
+        if dumped_parent_trace_context:
+            metadata["parent_trace_context"] = dumped_parent_trace_context
 
         workflow_trace_info = WorkflowTraceInfo(
             trace_id=self.trace_id,
@@ -1371,13 +1384,14 @@ class TraceTask:
         }
 
         parent_trace_context = node_data.get("parent_trace_context")
-        if parent_trace_context:
-            metadata["parent_trace_context"] = parent_trace_context
+        dumped_parent_trace_context = _dump_parent_trace_context(parent_trace_context)
+        if dumped_parent_trace_context:
+            metadata["parent_trace_context"] = dumped_parent_trace_context
 
         message_id: str | None = None
         conversation_id = node_data.get("conversation_id")
         workflow_execution_id = node_data.get("workflow_execution_id")
-        if conversation_id and workflow_execution_id and not parent_trace_context:
+        if conversation_id and workflow_execution_id and not dumped_parent_trace_context:
             with Session(db.engine) as session:
                 msg_id = session.scalar(
                     select(Message.id).where(
@@ -123,12 +123,15 @@ class SimplePromptTransform(PromptTransform):
 
         for v in special_variable_keys:
             # support #context#, #query# and #histories#
-            if v == "#context#":
-                variables["#context#"] = context or ""
-            elif v == "#query#":
-                variables["#query#"] = query or ""
-            elif v == "#histories#":
-                variables["#histories#"] = histories or ""
+            match v:
+                case "#context#":
+                    variables["#context#"] = context or ""
+                case "#query#":
+                    variables["#query#"] = query or ""
+                case "#histories#":
+                    variables["#histories#"] = histories or ""
+                case _:
+                    pass
 
         prompt_template = prompt_template_config["prompt_template"]
         if not isinstance(prompt_template, PromptTemplateParser):
@@ -245,6 +245,7 @@ class Jieba(BaseKeyword):
                 segment = pre_segment_data["segment"]
                 if pre_segment_data["keywords"]:
                     segment.keywords = pre_segment_data["keywords"]
+                    assert segment.index_node_id
                     keyword_table = self._add_text_to_keyword_table(
                         keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                     )
@@ -253,6 +254,7 @@ class Jieba(BaseKeyword):
 
                     keywords = keyword_table_handler.extract_keywords(segment.content, keyword_number)
                    segment.keywords = list(keywords)
+                    assert segment.index_node_id
                    keyword_table = self._add_text_to_keyword_table(
                        keyword_table or {}, segment.index_node_id, list(keywords)
                    )
@@ -1,5 +1,6 @@
 import concurrent.futures
 import logging
+from collections.abc import Sequence
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, NotRequired, TypedDict
 
@@ -526,7 +527,7 @@ class RetrievalService:
         index_node_ids = [i for i in index_node_ids if i]
 
         segment_ids: list[str] = []
-        index_node_segments: list[DocumentSegment] = []
+        index_node_segments: Sequence[DocumentSegment] = []
         segments: list[DocumentSegment] = []
         attachment_map: dict[str, list[AttachmentInfoDict]] = {}
         child_chunk_map: dict[str, list[ChildChunk]] = {}
@@ -568,8 +569,9 @@ class RetrievalService:
                     DocumentSegment.status == "completed",
                     DocumentSegment.index_node_id.in_(index_node_ids),
                 )
-                index_node_segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
+                index_node_segments = session.execute(document_segment_stmt).scalars().all()
                 for index_node_segment in index_node_segments:
+                    assert index_node_segment.index_node_id
                    doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
 
             if segment_ids:
@@ -50,6 +50,7 @@ class DatasetDocumentStore:
 
         output = {}
         for document_segment in document_segments:
+            assert document_segment.index_node_id
             doc_id = document_segment.index_node_id
             output[doc_id] = Document(
                 page_content=document_segment.content,
@@ -103,7 +104,7 @@ class DatasetDocumentStore:
 
             if not segment_document:
                 max_position += 1
-
+            assert self._document_id
             segment_document = DocumentSegment(
                 tenant_id=self._dataset.tenant_id,
                 dataset_id=self._dataset.id,
@@ -84,7 +84,7 @@ class IndexProcessor:
             select(DocumentSegment).where(DocumentSegment.document_id == original_document_id)
         ).all()
         if segments:
-            index_node_ids = [segment.index_node_id for segment in segments]
+            index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
 
             indexing_start_at = time.perf_counter()
             # delete from vector index
@@ -29,6 +29,7 @@ from libs import helper
 from models import Account
 from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
 from models.dataset import Document as DatasetDocument
+from models.enums import ProcessRuleMode
 from services.account_service import AccountService
 from services.summary_index_service import SummaryIndexService
 
@@ -325,7 +326,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
             # update document parent mode
             dataset_process_rule = DatasetProcessRule(
                 dataset_id=dataset.id,
-                mode="hierarchical",
+                mode=ProcessRuleMode.HIERARCHICAL,
                 rules=json.dumps(
                     {
                         "parent_mode": parent_childs.parent_mode,
@@ -8,6 +8,10 @@ import urllib.parse
 
 from configs import dify_config
 
+
+def _secret_key() -> bytes:
+    return dify_config.SECRET_KEY.encode()
+
 
 def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True) -> str:
     """
     sign file to get a temporary url for plugin access
@@ -19,8 +23,7 @@ def sign_tool_file(tool_file_id: str, extension: str, for_external: bool = True)
     timestamp = str(int(time.time()))
     nonce = os.urandom(16).hex()
     data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
-    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-    sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+    sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
     encoded_sign = base64.urlsafe_b64encode(sign).decode()
 
     return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@@ -39,8 +42,7 @@ def sign_upload_file_preview_url(upload_file_id: str, extension: str) -> str:
     timestamp = str(int(time.time()))
     nonce = os.urandom(16).hex()
     data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
-    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-    sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+    sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
     encoded_sign = base64.urlsafe_b64encode(sign).decode()
 
     return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@@ -51,8 +53,7 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s
     verify signature
     """
     data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
-    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+    recalculated_sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
     recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
 
     # verify signature
@@ -71,8 +72,7 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str,
     timestamp = str(int(time.time()))
     nonce = os.urandom(16).hex()
     data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
-    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-    sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+    sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
     encoded_sign = base64.urlsafe_b64encode(sign).decode()
     query = urllib.parse.urlencode(
         {
@@ -92,8 +92,7 @@ def verify_plugin_file_signature(
     """Verify the signature used by the plugin-facing file upload endpoint."""
 
     data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
-    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+    recalculated_sign = hmac.new(_secret_key(), data_to_sign.encode(), hashlib.sha256).digest()
     recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
 
     if sign != recalculated_encoded_sign:
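All of these helpers share one scheme: HMAC-SHA256 over a pipe-delimited payload containing a timestamp and random nonce, with the digest URL-safe base64 encoded; verification recomputes the digest and compares. A self-contained roundtrip sketch (the hard-coded key stands in for `dify_config.SECRET_KEY.encode()`; using `hmac.compare_digest` is an extra hardening choice, not something these hunks add):

```python
import base64
import hashlib
import hmac
import os
import time

SECRET_KEY = b"example-secret"  # stands in for dify_config.SECRET_KEY.encode()


def sign(file_id: str) -> tuple[str, str, str]:
    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    data = f"file-preview|{file_id}|{timestamp}|{nonce}"
    digest = hmac.new(SECRET_KEY, data.encode(), hashlib.sha256).digest()
    return timestamp, nonce, base64.urlsafe_b64encode(digest).decode()


def verify(file_id: str, timestamp: str, nonce: str, signature: str) -> bool:
    data = f"file-preview|{file_id}|{timestamp}|{nonce}"
    expected = base64.urlsafe_b64encode(hmac.new(SECRET_KEY, data.encode(), hashlib.sha256).digest()).decode()
    return hmac.compare_digest(expected, signature)


assert verify("file-123", *sign("file-123"))
```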
@@ -51,8 +51,11 @@ class ToolFileManager:
         timestamp = str(int(time.time()))
         nonce = os.urandom(16).hex()
         data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
-        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        sign = hmac.new(
+            dify_config.SECRET_KEY.encode(),
+            data_to_sign.encode(),
+            hashlib.sha256,
+        ).digest()
         encoded_sign = base64.urlsafe_b64encode(sign).decode()
 
         return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@@ -63,8 +66,11 @@ class ToolFileManager:
         verify signature
         """
         data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
-        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
-        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        recalculated_sign = hmac.new(
+            dify_config.SECRET_KEY.encode(),
+            data_to_sign.encode(),
+            hashlib.sha256,
+        ).digest()
         recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
 
         # verify signature
@@ -9,6 +9,7 @@ from sqlalchemy import select
 
 from core.app.file_access import DatabaseFileAccessController
 from core.db.session_factory import session_factory
+from core.helper.trace_id_helper import ParentTraceContext, extract_parent_trace_context_from_args
 from core.tools.__base.tool import Tool
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.tool_entities import (
@@ -36,6 +37,8 @@ class WorkflowTool(Tool):
     Workflow tool.
     """
 
+    _parent_trace_context: ParentTraceContext | None
+
     def __init__(
         self,
         workflow_app_id: str,
@@ -54,6 +57,7 @@ class WorkflowTool(Tool):
         self.workflow_call_depth = workflow_call_depth
         self.label = label
         self._latest_usage = LLMUsage.empty_usage()
+        self._parent_trace_context = None
 
         super().__init__(entity=entity, runtime=runtime)
 
@@ -94,11 +98,17 @@ class WorkflowTool(Tool):
 
         self._latest_usage = LLMUsage.empty_usage()
 
+        generator_args: dict[str, Any] = {"inputs": tool_parameters, "files": files}
+        if self._parent_trace_context:
+            generator_args.update(
+                extract_parent_trace_context_from_args({"parent_trace_context": self._parent_trace_context})
+            )
+
         result = generator.generate(
             app_model=app,
             workflow=workflow,
             user=user,
-            args={"inputs": tool_parameters, "files": files},
+            args=generator_args,
             invoke_from=self.runtime.invoke_from,
             streaming=False,
             call_depth=self.workflow_call_depth + 1,
@@ -194,7 +204,7 @@ class WorkflowTool(Tool):
 
         :return: the new tool
         """
-        return self.__class__(
+        forked = self.__class__(
             entity=self.entity.model_copy(),
             runtime=runtime,
             workflow_app_id=self.workflow_app_id,
@@ -204,6 +214,24 @@ class WorkflowTool(Tool):
             version=self.version,
             label=self.label,
         )
+        forked._parent_trace_context = self._parent_trace_context.model_copy() if self._parent_trace_context else None
+        return forked
+
+    def set_parent_trace_context(
+        self,
+        *,
+        parent_workflow_run_id: str,
+        parent_node_execution_id: str,
+    ) -> None:
+        """Attach outer workflow trace context without exposing it as tool input."""
+        self._parent_trace_context = ParentTraceContext(
+            parent_workflow_run_id=parent_workflow_run_id,
+            parent_node_execution_id=parent_node_execution_id,
+        )
+
+    def clear_parent_trace_context(self) -> None:
+        """Remove parent trace context before invoking this tool outside a nested workflow."""
+        self._parent_trace_context = None
 
     def _resolve_user(self, user_id: str) -> Account | EndUser | None:
         """
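The context is attached through an explicit setter rather than a tool parameter, so it never leaks into user-visible tool inputs; the tool node in the next file diff decides when to set or clear it. A rough sketch of the intended call pattern, with identifiers invented for illustration:

```python
# Nested invocation: the outer workflow links the inner run to its own trace.
workflow_tool.set_parent_trace_context(
    parent_workflow_run_id="outer-run-id",
    parent_node_execution_id="outer-node-execution-id",
)

# Reusing the same tool outside a nested workflow: drop any stale context.
workflow_tool.clear_parent_trace_context()
```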
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
 from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from core.app.file_access import DatabaseFileAccessController
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.helper.trace_id_helper import ParentTraceContext
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.model_manager import ModelInstance
@@ -358,6 +359,7 @@ class _WorkflowToolRuntimeBinding:
 
     tool: Tool
     conversation_id: str | None = None
+    parent_trace_context: ParentTraceContext | None = None
 
 
 class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
@@ -378,6 +380,7 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
         node_id: str,
         node_data: ToolNodeData,
         variable_pool,
+        node_execution_id: str | None = None,
     ) -> ToolRuntimeHandle:
         try:
             tool_runtime = ToolManager.get_workflow_tool_runtime(
@@ -397,7 +400,25 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
         conversation_id = (
             None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID)
         )
-        return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id))
+        parent_trace_context: ParentTraceContext | None = None
+        if self._is_workflow_tool_provider(node_data):
+            outer_workflow_run_id = (
+                None
+                if variable_pool is None
+                else get_system_text(variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID)
+            )
+            if isinstance(outer_workflow_run_id, str) and isinstance(node_execution_id, str):
+                parent_trace_context = ParentTraceContext(
+                    parent_workflow_run_id=outer_workflow_run_id,
+                    parent_node_execution_id=node_execution_id,
+                )
+        return ToolRuntimeHandle(
+            raw=_WorkflowToolRuntimeBinding(
+                tool=tool_runtime,
+                conversation_id=conversation_id,
+                parent_trace_context=parent_trace_context,
+            )
+        )
 
     def get_runtime_parameters(
         self,
@@ -421,6 +442,13 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
         runtime_binding = self._binding_from_handle(tool_runtime)
         tool = runtime_binding.tool
         callback = DifyWorkflowCallbackHandler()
+        if runtime_binding.parent_trace_context and hasattr(tool, "set_parent_trace_context"):
+            tool.set_parent_trace_context(
+                parent_workflow_run_id=runtime_binding.parent_trace_context.parent_workflow_run_id,
+                parent_node_execution_id=runtime_binding.parent_trace_context.parent_node_execution_id,
+            )
+        elif hasattr(tool, "clear_parent_trace_context"):
+            tool.clear_parent_trace_context()
 
         try:
             messages = ToolEngine.generic_invoke(
@@ -513,6 +541,10 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
             credential_id=node_data.credential_id,
         )
 
+    @staticmethod
+    def _is_workflow_tool_provider(node_data: ToolNodeData) -> bool:
+        return node_data.provider_type.value == CoreToolProviderType.WORKFLOW.value
+
     def _adapt_messages(
         self,
         messages: Generator[CoreToolInvokeMessage, None, None],
@@ -5,6 +5,7 @@ import threading
 from flask import Response
 
 from configs import dify_config
+from controllers.console.admin import admin_required
 from dify_app import DifyApp
 
 
@@ -25,6 +26,7 @@ def init_app(app: DifyApp):
         )
 
     @app.route("/threads")
+    @admin_required
     def threads():  # pyright: ignore[reportUnusedFunction]
         num_threads = threading.active_count()
         threads = threading.enumerate()
@@ -50,6 +52,7 @@ def init_app(app: DifyApp):
         }
 
     @app.route("/db-pool-stat")
+    @admin_required
     def pool_stat():  # pyright: ignore[reportUnusedFunction]
         from extensions.ext_database import db
@@ -1,6 +1,13 @@
 from configs import dify_config
+from configs.secret_key import resolve_secret_key
 from dify_app import DifyApp
 
 
-def init_app(app: DifyApp):
-    app.secret_key = dify_config.SECRET_KEY
+def init_app(app: DifyApp) -> None:
+    """Resolve SECRET_KEY after config loading and before session/login setup."""
+    secret_key = dify_config.SECRET_KEY
+    if not secret_key:
+        secret_key = resolve_secret_key(secret_key)
+    dify_config.SECRET_KEY = secret_key
+    app.config["SECRET_KEY"] = secret_key
+    app.secret_key = secret_key
@@ -1,19 +1,22 @@
 """Workflow comment models."""
 
 from __future__ import annotations
 
 from datetime import datetime
-from typing import Optional
 
 import sqlalchemy as sa
 from sqlalchemy import Index, func
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
+from models.base import TypeBase
+
 from .account import Account
-from .base import Base, gen_uuidv7_string
+from .base import gen_uuidv7_string
 from .engine import db
 from .types import StringUUID
 
 
-class WorkflowComment(Base):
+class WorkflowComment(TypeBase):
     """Workflow comment model for canvas commenting functionality.
 
     Comments are associated with apps rather than specific workflow versions,
@@ -42,27 +45,33 @@ class WorkflowComment(Base):
         Index("workflow_comments_created_at_idx", "created_at"),
     )
 
-    id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
+    id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     position_x: Mapped[float] = mapped_column(sa.Float)
     position_y: Mapped[float] = mapped_column(sa.Float)
     content: Mapped[str] = mapped_column(sa.Text, nullable=False)
     created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(
+        sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+    )
     updated_at: Mapped[datetime] = mapped_column(
-        sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+        sa.DateTime,
+        nullable=False,
+        server_default=func.current_timestamp(),
+        onupdate=func.current_timestamp(),
+        init=False,
     )
-    resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
-    resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime)
-    resolved_by: Mapped[str | None] = mapped_column(StringUUID)
+    resolved_at: Mapped[datetime | None] = mapped_column(sa.DateTime, default=None)
+    resolved_by: Mapped[str | None] = mapped_column(StringUUID, default=None)
+
+    resolved: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False)
     # Relationships
-    replies: Mapped[list["WorkflowCommentReply"]] = relationship(
-        "WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
+    replies: Mapped[list[WorkflowCommentReply]] = relationship(
+        lambda: WorkflowCommentReply, back_populates="comment", cascade="all, delete-orphan", init=False
     )
-    mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
-        "WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
+    mentions: Mapped[list[WorkflowCommentMention]] = relationship(
+        lambda: WorkflowCommentMention, back_populates="comment", cascade="all, delete-orphan", init=False
    )
 
     @property
@@ -131,7 +140,7 @@ class WorkflowComment(Base):
         return participants
 
 
-class WorkflowCommentReply(Base):
+class WorkflowCommentReply(TypeBase):
     """Workflow comment reply model.
 
     Attributes:
@@ -149,18 +158,24 @@ class WorkflowCommentReply(Base):
         Index("comment_replies_created_at_idx", "created_at"),
     )
 
-    id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
+    id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
     comment_id: Mapped[str] = mapped_column(
         StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
     )
     content: Mapped[str] = mapped_column(sa.Text, nullable=False)
     created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+    created_at: Mapped[datetime] = mapped_column(
+        sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
+    )
     updated_at: Mapped[datetime] = mapped_column(
-        sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
+        sa.DateTime,
+        nullable=False,
+        server_default=func.current_timestamp(),
+        onupdate=func.current_timestamp(),
+        init=False,
     )
     # Relationships
-    comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
+    comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="replies", init=False)
 
     @property
     def created_by_account(self):
@@ -174,7 +189,7 @@ class WorkflowCommentReply(Base):
         self._created_by_account_cache = account
 
 
-class WorkflowCommentMention(Base):
+class WorkflowCommentMention(TypeBase):
     """Workflow comment mention model.
 
     Mentions are only for internal accounts since end users
@@ -194,18 +209,18 @@ class WorkflowCommentMention(Base):
         Index("comment_mentions_user_idx", "mentioned_user_id"),
     )
 
-    id: Mapped[str] = mapped_column(StringUUID, default=gen_uuidv7_string)
+    id: Mapped[str] = mapped_column(StringUUID, default_factory=gen_uuidv7_string, init=False)
     comment_id: Mapped[str] = mapped_column(
         StringUUID, sa.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
     )
-    reply_id: Mapped[str | None] = mapped_column(
-        StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
-    )
     mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    reply_id: Mapped[str | None] = mapped_column(
+        StringUUID, sa.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True, default=None
+    )
 
     # Relationships
-    comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
-    reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
+    comment: Mapped[WorkflowComment] = relationship(lambda: WorkflowComment, back_populates="mentions", init=False)
+    reply: Mapped[WorkflowCommentReply | None] = relationship(lambda: WorkflowCommentReply, init=False)
 
     @property
     def mentioned_user_account(self):
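Assuming `TypeBase` is a `MappedAsDataclass` declarative base (which is what the `init=False` / `default_factory` arguments imply), server-generated columns and relationships drop out of the generated `__init__`, and every remaining optional column needs an explicit `default`. A hedged construction sketch with illustrative values:

```python
# id, created_at, updated_at, replies, and mentions are init=False, so they
# are populated by defaults, the database, or the ORM rather than the caller.
comment = WorkflowComment(
    tenant_id="tenant-uuid",      # illustrative value
    app_id="app-uuid",            # illustrative value
    position_x=120.0,
    position_y=64.0,
    content="Looks good to me",
    created_by="account-uuid",    # illustrative value
)
# resolved / resolved_at / resolved_by fall back to their declared defaults.
```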
@ -8,10 +8,9 @@ import os
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, ClassVar, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -441,23 +440,27 @@ class Dataset(Base):
|
||||
return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"
|
||||
|
||||
|
||||
class DatasetProcessRule(Base): # bug
|
||||
class DatasetProcessRule(TypeBase):
|
||||
__tablename__ = "dataset_process_rules"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
|
||||
sa.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
|
||||
)
|
||||
|
||||
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'"))
|
||||
rules = mapped_column(LongText, nullable=True)
|
||||
created_by = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
|
||||
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
mode: Mapped[ProcessRuleMode] = mapped_column(
|
||||
EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'")
|
||||
)
|
||||
rules: Mapped[str | None] = mapped_column(LongText, nullable=True)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
MODES = ["automatic", "custom", "hierarchical"]
|
||||
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
|
||||
AUTOMATIC_RULES: AutomaticRulesConfig = {
|
||||
AUTOMATIC_RULES: ClassVar[AutomaticRulesConfig] = {
|
||||
"pre_processing_rules": [
|
||||
{"id": "remove_extra_spaces", "enabled": True},
|
||||
{"id": "remove_urls_emails", "enabled": False},
|
||||
@ -827,7 +830,7 @@ class Document(Base):
|
||||
)
|
||||
|
||||
|
||||
class DocumentSegment(Base):
|
||||
class DocumentSegment(TypeBase):
|
||||
__tablename__ = "document_segments"
|
||||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="document_segment_pkey"),
|
||||
@ -840,35 +843,40 @@ class DocumentSegment(Base):
|
||||
)
|
||||
|
||||
# initial fields
|
||||
id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
|
||||
tenant_id = mapped_column(StringUUID, nullable=False)
|
||||
dataset_id = mapped_column(StringUUID, nullable=False)
|
||||
document_id = mapped_column(StringUUID, nullable=False)
|
||||
    id: Mapped[str] = mapped_column(StringUUID, nullable=False, default_factory=lambda: str(uuid4()), init=False)
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    position: Mapped[int]
    content = mapped_column(LongText, nullable=False)
    answer = mapped_column(LongText, nullable=True)
    content: Mapped[str] = mapped_column(LongText, nullable=False)
    word_count: Mapped[int]
    tokens: Mapped[int]

    # indexing fields
    keywords = mapped_column(sa.JSON, nullable=True)
    index_node_id = mapped_column(String(255), nullable=True)
    index_node_hash = mapped_column(String(255), nullable=True)

    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    # basic fields
    # indexing fields
    index_node_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
    index_node_hash: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None)
    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"), default=True)
    answer: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
    keywords: Mapped[Any] = mapped_column(sa.JSON, nullable=True, default=None)
    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
    disabled_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
    status: Mapped[SegmentStatus] = mapped_column(
        EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"), default=SegmentStatus.WAITING
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )
    updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )
    indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
    error: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
    stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None)
    hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    disabled_by = mapped_column(StringUUID, nullable=True)
    status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'"))
    created_by = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    error = mapped_column(LongText, nullable=True)
    stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
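The field block above migrates legacy untyped `mapped_column` declarations to typed `Mapped[...]` columns carrying both a Python-side `default` and a DB-side `server_default`. A minimal runnable sketch of that target pattern, assuming a `MappedAsDataclass` base like Dify's; the table and field names here are illustrative, not taken from the diff:

```python
from datetime import datetime
from uuid import uuid4

from sqlalchemy import DateTime, String, Text, func, text
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class Base(DeclarativeBase, MappedAsDataclass):
    pass


class Segment(Base):
    __tablename__ = "segments"

    # init=False + default_factory: generated at construction, never a caller argument
    id: Mapped[str] = mapped_column(String(36), primary_key=True, default_factory=lambda: str(uuid4()), init=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # default=True covers new Python objects; server_default covers raw SQL inserts
    enabled: Mapped[bool] = mapped_column(nullable=False, server_default=text("true"), default=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), init=False)
```

With this shape, `Segment(content="hello")` constructs a valid row: the id is generated, `enabled` defaults to `True`, and `created_at` is filled by the database.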
    @property
    def dataset(self):

@@ -895,7 +903,7 @@ class DocumentSegment(Base):
        )

    @property
    def child_chunks(self) -> Sequence[Any]:
    def child_chunks(self):
        if not self.document:
            return []
        process_rule = self.document.dataset_process_rule

@@ -910,7 +918,7 @@ class DocumentSegment(Base):
            return child_chunks or []
        return []

    def get_child_chunks(self) -> Sequence[Any]:
    def get_child_chunks(self):
        if not self.document:
            return []
        process_rule = self.document.dataset_process_rule

@@ -941,7 +949,7 @@ class DocumentSegment(Base):
        nonce = os.urandom(16).hex()
        timestamp = str(int(time.time()))
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        secret_key = dify_config.SECRET_KEY.encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

@@ -958,7 +966,7 @@ class DocumentSegment(Base):
        nonce = os.urandom(16).hex()
        timestamp = str(int(time.time()))
        data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        secret_key = dify_config.SECRET_KEY.encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

@@ -977,7 +985,7 @@ class DocumentSegment(Base):
        nonce = os.urandom(16).hex()
        timestamp = str(int(time.time()))
        data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        secret_key = dify_config.SECRET_KEY.encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

@@ -1015,7 +1023,7 @@ class DocumentSegment(Base):
        nonce = os.urandom(16).hex()
        timestamp = str(int(time.time()))
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        secret_key = dify_config.SECRET_KEY.encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()
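The signing hunks above produce an HMAC-SHA256 token over `kind|file_id|timestamp|nonce`. For context, a hedged sketch of the matching verification step (the function name and tolerance window are assumptions, not taken from this diff):

```python
import base64
import hashlib
import hmac
import time


def verify_preview_sign(upload_file_id: str, timestamp: str, nonce: str, sign: str,
                        secret_key: bytes, max_age_seconds: int = 300) -> bool:
    # recompute the signature exactly as the producer built it
    data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    expected = base64.urlsafe_b64encode(
        hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    ).decode()
    # constant-time comparison avoids timing side channels
    if not hmac.compare_digest(expected, sign):
        return False
    # reject stale links outside the assumed validity window
    return int(time.time()) - int(timestamp) <= max_age_seconds
```

Dropping the `if dify_config.SECRET_KEY else b""` fallback, as the diff does, means a missing secret now fails loudly instead of silently signing with an empty key.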
@@ -13786,6 +13786,14 @@ Tag type
| unit | string |  | No |
| variable | string |  | No |

#### TrialSimpleAccount

| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| email | string |  | No |
| id | string |  | No |
| name | string |  | No |

#### TrialSite

| Name | Type | Description | Required |
@@ -13829,7 +13837,7 @@ Tag type
| ---- | ---- | ----------- | -------- |
| conversation_variables | [ [TrialConversationVariable](#trialconversationvariable) ] |  | No |
| created_at | object |  | No |
| created_by | [SimpleAccount](#simpleaccount) |  | No |
| created_by | [TrialSimpleAccount](#trialsimpleaccount) |  | No |
| environment_variables | [ object ] |  | No |
| features | object |  | No |
| graph | object |  | No |
@@ -13840,7 +13848,7 @@ Tag type
| rag_pipeline_variables | [ [TrialPipelineVariable](#trialpipelinevariable) ] |  | No |
| tool_published | boolean |  | No |
| updated_at | object |  | No |
| updated_by | [SimpleAccount](#simpleaccount) |  | No |
| updated_by | [TrialSimpleAccount](#trialsimpleaccount) |  | No |
| version | string |  | No |

#### TrialWorkflowPartial
@@ -1,9 +1,11 @@
import json
import logging
import os
import re
import traceback
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Any, Union, cast
from typing import Any, Protocol, Union, cast
from urllib.parse import urlparse

from openinference.semconv.trace import (
@@ -19,7 +21,7 @@ from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.semconv.attributes import exception_attributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace import Span, Status, StatusCode, get_current_span, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
from sqlalchemy.orm import sessionmaker
@@ -36,16 +38,106 @@ from core.ops.entities.trace_entity import (
    TraceTaskName,
    WorkflowTraceInfo,
)
from core.ops.exceptions import PendingTraceParentContextError
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
from dify_trace_arize_phoenix.config import ArizeConfig, PhoenixConfig
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.enums import WorkflowNodeExecutionStatus
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)

# This parent-span carrier store is intentionally Phoenix-local for the current
# nested workflow tracing feature. If other trace providers need the same
# cross-task parent restoration behavior, move the storage and retry signaling
# behind a core trace coordination interface instead of duplicating it here.
_PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS = 300
_TRACEPARENT_PATTERN = re.compile(
    r"^(?P<version>[0-9a-f]{2})-(?P<trace_id>[0-9a-f]{32})-(?P<span_id>[0-9a-f]{16})-(?P<flags>[0-9a-f]{2})$"
)


def _phoenix_parent_span_redis_key(parent_node_execution_id: str) -> str:
    """Build the Redis key that stores a restorable Phoenix parent span carrier."""
    return f"trace:phoenix:parent_span:{parent_node_execution_id}"


def _publish_parent_span_context(parent_node_execution_id: str, carrier: Mapping[str, str]) -> None:
    """Persist a tracecontext carrier so nested workflow spans can restore the tool span parent."""
    redis_client.setex(
        _phoenix_parent_span_redis_key(parent_node_execution_id),
        _PHOENIX_PARENT_SPAN_CONTEXT_TTL_SECONDS,
        safe_json_dumps(dict(carrier)),
    )


def _resolve_published_parent_span_context(parent_node_execution_id: str) -> dict[str, str]:
    """Load a previously published tool-span carrier for nested workflow parenting."""
    raw_carrier = redis_client.get(_phoenix_parent_span_redis_key(parent_node_execution_id))
    if raw_carrier is None:
        raise PendingTraceParentContextError(parent_node_execution_id)

    if isinstance(raw_carrier, bytes):
        raw_carrier = raw_carrier.decode("utf-8")

    carrier = json.loads(raw_carrier)
    if not isinstance(carrier, dict):
        raise ValueError(
            "Phoenix parent span context must be stored as a JSON object: "
            f"parent_node_execution_id={parent_node_execution_id}"
        )

    normalized_carrier = {str(key): str(value) for key, value in carrier.items()}
    if not normalized_carrier:
        raise ValueError(
            f"Phoenix parent span context payload is empty: parent_node_execution_id={parent_node_execution_id}"
        )

    traceparent = normalized_carrier.get("traceparent")
    if not isinstance(traceparent, str):
        raise ValueError(
            "Phoenix parent span context payload is missing traceparent: "
            f"parent_node_execution_id={parent_node_execution_id}"
        )

    traceparent_match = _TRACEPARENT_PATTERN.fullmatch(traceparent)
    if traceparent_match is None:
        raise ValueError(
            "Phoenix parent span context payload has invalid traceparent format: "
            f"parent_node_execution_id={parent_node_execution_id}"
        )

    if traceparent_match.group("version") == "ff":
        raise ValueError(
            "Phoenix parent span context payload has unsupported traceparent version: "
            f"parent_node_execution_id={parent_node_execution_id}"
        )

    if traceparent_match.group("trace_id") == "0" * 32:
        raise ValueError(
            "Phoenix parent span context payload has zero trace_id in traceparent: "
            f"parent_node_execution_id={parent_node_execution_id}"
        )

    if traceparent_match.group("span_id") == "0" * 16:
        raise ValueError(
            "Phoenix parent span context payload has zero span_id in traceparent: "
            f"parent_node_execution_id={parent_node_execution_id}"
        )

    extracted_context = TraceContextTextMapPropagator().extract(carrier=normalized_carrier)
    extracted_span_context = get_current_span(extracted_context).get_span_context()
    if not extracted_span_context.is_valid or not extracted_span_context.is_remote:
        raise ValueError(
            "Phoenix parent span context payload could not be restored into a valid parent span: "
            f"parent_node_execution_id={parent_node_execution_id}"
        )

    return normalized_carrier
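The publish/resolve pair above is a W3C tracecontext carrier round-tripped through Redis. A self-contained sketch of that round-trip, with an in-memory dict standing in for the `redis_client.setex`/`get` calls (the key and span names are illustrative):

```python
import json

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

trace.set_tracer_provider(TracerProvider())   # real SDK provider so spans record
tracer = trace.get_tracer(__name__)
propagator = TraceContextTextMapPropagator()
store: dict[str, str] = {}                    # stand-in for Redis

with tracer.start_as_current_span("tool"):
    carrier: dict[str, str] = {}
    propagator.inject(carrier=carrier)        # {"traceparent": "00-<trace>-<span>-01"}
    store["trace:phoenix:parent_span:exec-1"] = json.dumps(carrier)

# later, possibly in another worker process: restore the tool span as remote parent
restored = json.loads(store["trace:phoenix:parent_span:exec-1"])
parent_context = propagator.extract(carrier=restored)
nested_span = tracer.start_span("nested-workflow", context=parent_context)
nested_span.end()
```

The strict traceparent validation in `_resolve_published_parent_span_context` guards exactly this extract step: a zero trace ID, zero span ID, or `ff` version would otherwise yield a silently invalid parent.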
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
    """Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
@@ -177,6 +269,246 @@ def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
    return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)


def _resolve_workflow_session_id(trace_info: WorkflowTraceInfo) -> str:
    """Resolve the workflow session ID for Phoenix workflow spans."""
    if trace_info.conversation_id:
        return trace_info.conversation_id

    parent_workflow_run_id, _ = _resolve_workflow_parent_context(trace_info)
    if parent_workflow_run_id:
        return parent_workflow_run_id

    return trace_info.workflow_run_id


def _resolve_workflow_parent_context(trace_info: BaseTraceInfo) -> tuple[str | None, str | None]:
    """Expose the typed parent context already resolved on the trace info."""
    return trace_info.resolved_parent_context


def _resolve_workflow_root_trace_id(trace_info: WorkflowTraceInfo) -> str:
    """Resolve the canonical root trace ID for Phoenix workflow spans."""
    trace_correlation_override, _ = _resolve_workflow_parent_context(trace_info)
    return trace_correlation_override or trace_info.resolved_trace_id or trace_info.workflow_run_id


class _NodeExecutionIdentityLike(Protocol):
    @property
    def node_execution_id(self) -> str | None: ...

    @property
    def node_id(self) -> str: ...

    @property
    def predecessor_node_id(self) -> str | None: ...


class _NodeExecutionLike(_NodeExecutionIdentityLike, Protocol):
    @property
    def id(self) -> str: ...

    @property
    def node_type(self) -> str: ...

    @property
    def title(self) -> str | None: ...

    @property
    def inputs(self) -> Mapping[str, Any] | None: ...

    @property
    def process_data(self) -> Mapping[str, Any] | None: ...

    @property
    def outputs(self) -> Mapping[str, Any] | None: ...

    @property
    def status(self) -> WorkflowNodeExecutionStatus: ...

    @property
    def error(self) -> str | None: ...

    @property
    def elapsed_time(self) -> float | None: ...

    @property
    def metadata(self) -> Mapping[Any, Any] | None: ...

    @property
    def created_at(self) -> datetime | None: ...


_PHOENIX_STRUCTURED_NODE_TYPES = frozenset({"start", "end", "loop", "iteration"})


def _resolve_workflow_span_name(trace_info: WorkflowTraceInfo) -> str:
    """Resolve the Phoenix workflow span display name."""
    workflow_run_id = trace_info.workflow_run_id.strip() if trace_info.workflow_run_id else ""
    if workflow_run_id:
        return f"{TraceTaskName.WORKFLOW_TRACE.value}_{workflow_run_id}"
    return TraceTaskName.WORKFLOW_TRACE.value


def _build_node_title_by_id(trace_info: WorkflowTraceInfo) -> dict[str, str]:
    """Build an authoritative node-title index from the persisted workflow graph."""
    workflow_data = trace_info.workflow_data
    workflow_graph = getattr(workflow_data, "graph_dict", None)
    if not isinstance(workflow_graph, Mapping):
        workflow_graph = workflow_data.get("graph") if isinstance(workflow_data, Mapping) else None
    if not isinstance(workflow_graph, Mapping):
        return {}

    graph_nodes = workflow_graph.get("nodes")
    if not isinstance(graph_nodes, Sequence):
        return {}

    node_title_by_id: dict[str, str] = {}
    for graph_node in graph_nodes:
        if not isinstance(graph_node, Mapping):
            continue
        node_id = graph_node.get("id")
        node_data = graph_node.get("data")
        if not isinstance(node_id, str) or not isinstance(node_data, Mapping):
            continue
        node_title = node_data.get("title")
        if isinstance(node_title, str) and node_title.strip():
            node_title_by_id[node_id] = node_title.strip()

    return node_title_by_id


def _resolve_workflow_node_span_name(
    node_execution: _NodeExecutionLike,
    node_title_by_id: Mapping[str, str] | None = None,
) -> str:
    """Resolve the Phoenix workflow node span display name."""
    node_type = str(node_execution.node_type or "")
    graph_node_title = None
    if node_title_by_id is not None and isinstance(node_execution.node_id, str):
        graph_node_title = node_title_by_id.get(node_execution.node_id)

    node_title = graph_node_title or (node_execution.title.strip() if isinstance(node_execution.title, str) else "")
    if node_title:
        return f"{node_type}_{node_title}"
    return node_type


def _get_node_execution_id(node_execution: _NodeExecutionIdentityLike) -> str:
    """Return the stable execution identifier for a workflow node execution."""
    return str(getattr(node_execution, "id", None) or node_execution.node_execution_id)


def _build_execution_id_by_node_id(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]:
    """Index unique workflow graph node ids by execution id.

    This Phoenix-local hierarchy reconstruction intentionally drops ambiguous
    node ids instead of guessing based on repository order. That keeps parent
    selection deterministic until upstream tracing exposes explicit parent span
    data for repeated executions.
    """
    execution_id_by_node_id: dict[str, str] = {}
    ambiguous_node_ids: set[str] = set()

    for node_execution in node_executions:
        node_id = node_execution.node_id
        if not isinstance(node_id, str):
            continue
        execution_id = _get_node_execution_id(node_execution)

        if node_id in ambiguous_node_ids:
            continue

        existing_execution_id = execution_id_by_node_id.get(node_id)
        if existing_execution_id is None:
            execution_id_by_node_id[node_id] = execution_id
            continue

        if existing_execution_id != execution_id:
            ambiguous_node_ids.add(node_id)
            execution_id_by_node_id.pop(node_id, None)

    return execution_id_by_node_id
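A worked example of the ambiguity rule the docstring above describes, using a tiny stand-in for the node-execution objects (names are illustrative only):

```python
from dataclasses import dataclass


@dataclass
class FakeExecution:
    id: str
    node_id: str
    node_execution_id: str | None = None
    predecessor_node_id: str | None = None


executions = [
    FakeExecution(id="exec-1", node_id="start"),
    FakeExecution(id="exec-2", node_id="llm"),   # first run of graph node "llm"
    FakeExecution(id="exec-3", node_id="llm"),   # second run: "llm" is now ambiguous
]

index = _build_execution_id_by_node_id(executions)
assert index == {"start": "exec-1"}              # "llm" was dropped, not guessed
```

Dropping the repeated node id means a predecessor edge pointing at `"llm"` simply resolves to no parent, and the span falls back to the workflow root rather than being attached to the wrong iteration.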
def _build_graph_parent_index(node_executions: Sequence[_NodeExecutionIdentityLike]) -> dict[str, str]:
    """Build an execution-id parent index from predecessor node ids."""
    execution_id_by_node_id = _build_execution_id_by_node_id(node_executions)
    graph_parent_index: dict[str, str] = {}

    for node_execution in node_executions:
        predecessor_node_id = node_execution.predecessor_node_id
        if not isinstance(predecessor_node_id, str):
            continue

        predecessor_execution_id = execution_id_by_node_id.get(predecessor_node_id)
        if predecessor_execution_id is not None:
            execution_id = _get_node_execution_id(node_execution)
            graph_parent_index[execution_id] = predecessor_execution_id

    return graph_parent_index


def _resolve_structured_parent_execution_id(
    node_execution: object, execution_id_by_node_id: Mapping[str, str]
) -> str | None:
    """Resolve Phoenix-local structured parents from loop/iteration node ids.

    Any execution carrying ``iteration_id`` or ``loop_id`` belongs to an
    enclosing structured node. When predecessor node ids are ambiguous because
    the graph node repeats inside that structure, Phoenix can still keep the
    child span under the enclosing loop/iteration span without relying on
    execution-order heuristics.
    """
    execution_metadata = getattr(node_execution, "execution_metadata_dict", None)
    if not isinstance(execution_metadata, Mapping):
        execution_metadata = getattr(node_execution, "metadata", None)
    if not isinstance(execution_metadata, Mapping):
        execution_metadata = {}

    for enclosing_node_id in (
        getattr(node_execution, "iteration_id", None),
        getattr(node_execution, "loop_id", None),
        execution_metadata.get("iteration_id"),
        execution_metadata.get("loop_id"),
    ):
        if not isinstance(enclosing_node_id, str):
            continue

        enclosing_execution_id = execution_id_by_node_id.get(enclosing_node_id)
        if enclosing_execution_id is not None:
            return enclosing_execution_id

    return None


def _resolve_node_parent(
    execution_id: str,
    predecessor_execution_id: str | None,
    structured_parent_execution_id: str | None,
    span_by_execution_id: Mapping[str, Span],
    graph_parent_index: Mapping[str, str],
    workflow_span: Span,
) -> Span:
    """Resolve the parent span for a workflow node execution."""
    if predecessor_execution_id is not None:
        predecessor_span = span_by_execution_id.get(predecessor_execution_id)
        if predecessor_span is not None:
            return predecessor_span

    graph_parent_execution_id = graph_parent_index.get(execution_id)
    if graph_parent_execution_id is not None:
        graph_parent_span = span_by_execution_id.get(graph_parent_execution_id)
        if graph_parent_span is not None:
            return graph_parent_span

    if structured_parent_execution_id is not None:
        structured_parent_span = span_by_execution_id.get(structured_parent_execution_id)
        if structured_parent_span is not None:
            return structured_parent_span

    return workflow_span
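An illustration of the fallback order `_resolve_node_parent` implements: explicit predecessor span, then graph-derived parent, then enclosing loop/iteration span, then the workflow root. Since the helper only performs dictionary lookups, plain sentinel objects are enough to exercise it (ids here are invented):

```python
loop_span = object()       # stands in for an open loop/iteration Span
workflow_span = object()   # stands in for the workflow root Span

parent = _resolve_node_parent(
    execution_id="exec-9",
    predecessor_execution_id=None,                 # no explicit predecessor span
    structured_parent_execution_id="exec-loop",
    span_by_execution_id={"exec-loop": loop_span},
    graph_parent_index={},                         # no graph edge for exec-9 either
    workflow_span=workflow_span,
)
assert parent is loop_span                         # falls through to the structured parent
```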
class ArizePhoenixDataTrace(BaseTraceInstance):
    def __init__(
        self,
@@ -189,6 +521,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
        self.propagator = TraceContextTextMapPropagator()
        self.dify_trace_ids: set[str] = set()
        self.root_span_carriers: dict[str, dict[str, str]] = {}
        self.carrier: dict[str, str] = {}

    def trace(self, trace_info: BaseTraceInfo):
        logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
@@ -235,13 +569,41 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            file_list=safe_json_dumps(file_list),
            query=trace_info.query or "",
        )
        workflow_session_id = _resolve_workflow_session_id(trace_info)
        parent_workflow_run_id, parent_node_execution_id = _resolve_workflow_parent_context(trace_info)
        logger.info(
            "[Arize/Phoenix] Workflow session resolution: workflow_run_id=%s conversation_id=%s "
            "parent_workflow_run_id=%s parent_node_execution_id=%s resolved_session_id=%s",
            trace_info.workflow_run_id,
            trace_info.conversation_id,
            parent_workflow_run_id,
            parent_node_execution_id,
            workflow_session_id,
        )

        dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
        self.ensure_root_span(dify_trace_id)
        root_span_context = self.propagator.extract(carrier=self.carrier)
        if parent_node_execution_id:
            workflow_parent_carrier = _resolve_published_parent_span_context(parent_node_execution_id)
        else:
            root_trace_id = _resolve_workflow_root_trace_id(trace_info)
            workflow_root_span_name: str | None = trace_info.workflow_run_id
            if not isinstance(workflow_root_span_name, str) or not workflow_root_span_name.strip():
                workflow_root_span_name = None

            workflow_parent_carrier = self.ensure_root_span(
                root_trace_id,
                root_span_name=workflow_root_span_name,
                root_span_attributes={
                    SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
                    SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                    SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
                    SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                },
            )

        workflow_span_context = self.propagator.extract(carrier=workflow_parent_carrier)

        workflow_span = self.tracer.start_span(
            name=TraceTaskName.WORKFLOW_TRACE.value,
            name=_resolve_workflow_span_name(trace_info),
            attributes={
                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
                SpanAttributes.INPUT_VALUE: safe_json_dumps(trace_info.workflow_run_inputs),
@@ -249,10 +611,10 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                SpanAttributes.OUTPUT_VALUE: safe_json_dumps(trace_info.workflow_run_outputs),
                SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                SpanAttributes.METADATA: safe_json_dumps(metadata),
                SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
                SpanAttributes.SESSION_ID: workflow_session_id or "",
            },
            start_time=datetime_to_nanos(trace_info.start_time),
            context=root_span_context,
            context=workflow_span_context,
        )

        # Through workflow_run_id, get all_nodes_execution using repository
@@ -276,16 +638,50 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution(
            workflow_execution_id=trace_info.workflow_run_id
        )
        node_title_by_id = _build_node_title_by_id(trace_info)
        execution_id_by_node_id = _build_execution_id_by_node_id(workflow_node_executions)
        graph_parent_index = _build_graph_parent_index(workflow_node_executions)
        node_execution_by_execution_id = {
            _get_node_execution_id(node_execution): node_execution for node_execution in workflow_node_executions
        }
        span_by_execution_id: dict[str, Span] = {}
        emitting_execution_ids: set[str] = set()

        workflow_span_error: Exception | str | None = trace_info.error
        try:
            for node_execution in workflow_node_executions:

            def emit_node_span(node_execution: _NodeExecutionLike) -> Span:
                execution_id = _get_node_execution_id(node_execution)
                existing_span = span_by_execution_id.get(execution_id)
                if existing_span is not None:
                    return existing_span

                graph_parent_execution_id = graph_parent_index.get(execution_id)
                structured_parent_execution_id = _resolve_structured_parent_execution_id(
                    node_execution, execution_id_by_node_id
                )

                if execution_id not in emitting_execution_ids:
                    emitting_execution_ids.add(execution_id)
                    try:
                        for parent_execution_id in (graph_parent_execution_id, structured_parent_execution_id):
                            if parent_execution_id is None or parent_execution_id == execution_id:
                                continue
                            if parent_execution_id in span_by_execution_id:
                                continue
                            parent_node_execution = node_execution_by_execution_id.get(parent_execution_id)
                            if parent_node_execution is not None:
                                emit_node_span(parent_node_execution)
                    finally:
                        emitting_execution_ids.discard(execution_id)

                tenant_id = trace_info.tenant_id  # Use from trace_info instead
                app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
                inputs_value = node_execution.inputs or {}
                outputs_value = node_execution.outputs or {}

                created_at = node_execution.created_at or datetime.now()
                elapsed_time = node_execution.elapsed_time
                elapsed_time = node_execution.elapsed_time or 0
                finished_at = created_at + timedelta(seconds=elapsed_time)

                process_data = node_execution.process_data or {}
@@ -324,9 +720,17 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                    node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
                    node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)

                workflow_span_context = set_span_in_context(workflow_span)
                parent_span = _resolve_node_parent(
                    execution_id=execution_id,
                    predecessor_execution_id=None,
                    structured_parent_execution_id=structured_parent_execution_id,
                    span_by_execution_id=span_by_execution_id,
                    graph_parent_index=graph_parent_index,
                    workflow_span=workflow_span,
                )
                workflow_span_context = set_span_in_context(parent_span)
                node_span = self.tracer.start_span(
                    name=node_execution.node_type,
                    name=_resolve_workflow_node_span_name(node_execution, node_title_by_id),
                    attributes={
                        SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
                        SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
@@ -334,13 +738,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                        SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
                        SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
                        SpanAttributes.METADATA: safe_json_dumps(node_metadata),
                        SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
                        SpanAttributes.SESSION_ID: workflow_session_id or "",
                    },
                    start_time=datetime_to_nanos(created_at),
                    context=workflow_span_context,
                )

                span_by_execution_id[execution_id] = node_span
                node_span_error: Exception | str | None = None
                try:
                    if node_execution.node_type == "tool":
                        parent_span_carrier: dict[str, str] = {}
                        with use_span(node_span, end_on_exit=False):
                            self.propagator.inject(carrier=parent_span_carrier)
                        _publish_parent_span_context(execution_id, parent_span_carrier)

                    if node_execution.node_type == "llm":
                        llm_attributes: dict[str, Any] = {
                            SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
@@ -362,17 +773,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                        )
                        llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
                        node_span.set_attributes(llm_attributes)
                except Exception as e:
                    node_span_error = e
                    raise
                finally:
                    if node_execution.status == WorkflowNodeExecutionStatus.FAILED:
                    if node_span_error is not None:
                        set_span_status(node_span, node_span_error)
                    elif node_execution.status == WorkflowNodeExecutionStatus.FAILED:
                        set_span_status(node_span, node_execution.error)
                    else:
                        set_span_status(node_span)
                    node_span.end(end_time=datetime_to_nanos(finished_at))
                return node_span

            for node_execution in workflow_node_executions:
                emit_node_span(node_execution)
        except Exception as e:
            workflow_span_error = e
            raise
        finally:
            if trace_info.error:
                set_span_status(workflow_span, trace_info.error)
            else:
                set_span_status(workflow_span)
            set_span_status(workflow_span, workflow_span_error)
            workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
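`emit_node_span` above emits parents before children and uses `emitting_execution_ids` as a re-entrancy guard so a malformed parent index cannot recurse forever. A minimal sketch of that guard pattern in isolation (names and the deliberately cyclic index are illustrative):

```python
emitted: dict[str, str] = {}
in_progress: set[str] = set()
parents = {"a": "b", "b": "a"}  # deliberately cyclic parent index


def emit(node: str) -> str:
    if node in emitted:
        return emitted[node]
    if node not in in_progress:        # skip parent emission while already emitting this node
        in_progress.add(node)
        try:
            parent = parents.get(node)
            if parent is not None and parent not in emitted:
                emit(parent)           # ensure the parent span exists first
        finally:
            in_progress.discard(node)
    emitted[node] = f"span:{node}"
    return emitted[node]


emit("a")  # terminates despite the a <-> b cycle
```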
    def message_trace(self, trace_info: MessageTraceInfo):
@@ -735,22 +1155,39 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        finally:
            span.end(end_time=datetime_to_nanos(trace_info.end_time))

    def ensure_root_span(self, dify_trace_id: str | None):
    def ensure_root_span(
        self,
        dify_trace_id: str | None,
        *,
        root_span_name: str | None = None,
        root_span_attributes: Mapping[str, AttributeValue] | None = None,
    ):
        """Ensure a unique root span exists for the given Dify trace ID."""
        if str(dify_trace_id) not in self.dify_trace_ids:
            self.carrier: dict[str, str] = {}
        trace_key = str(dify_trace_id)
        if trace_key not in self.dify_trace_ids:
            carrier: dict[str, str] = {}

            root_span = self.tracer.start_span(name="Dify")
            root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
            root_span.set_attribute("dify_project_name", str(self.project))
            root_span.set_attribute("dify_trace_id", str(dify_trace_id))
            span_name = root_span_name.strip() if isinstance(root_span_name, str) and root_span_name.strip() else "Dify"
            root_span_attributes_dict: dict[str, AttributeValue] = {
                SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
                "dify_project_name": str(self.project),
                "dify_trace_id": trace_key,
            }
            if root_span_attributes:
                root_span_attributes_dict.update(root_span_attributes)

            root_span = self.tracer.start_span(name=span_name, attributes=root_span_attributes_dict)

            with use_span(root_span, end_on_exit=False):
                self.propagator.inject(carrier=self.carrier)
                self.propagator.inject(carrier=carrier)

            set_span_status(root_span)
            root_span.end()
            self.dify_trace_ids.add(str(dify_trace_id))
            self.dify_trace_ids.add(trace_key)
            self.root_span_carriers[trace_key] = carrier

        self.carrier = self.root_span_carriers[trace_key]
        return self.carrier
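The `ensure_root_span` rewrite keeps one injected carrier per Dify trace ID in `root_span_carriers`, so interleaved traces no longer overwrite a single shared `self.carrier`. A minimal sketch of just that caching behavior (the stub traceparent replaces the real `propagator.inject()` under the root span):

```python
carriers: dict[str, dict[str, str]] = {}


def carrier_for(trace_id: str) -> dict[str, str]:
    if trace_id not in carriers:
        # in ensure_root_span this is propagator.inject() under the freshly
        # started root span; a stub value keeps the sketch self-contained
        carriers[trace_id] = {"traceparent": f"00-{trace_id.ljust(32, '0')[:32]}-{'a' * 16}-01"}
    return carriers[trace_id]


assert carrier_for("abc") is carrier_for("abc")   # cached per trace ID
assert carrier_for("abc") != carrier_for("xyz")   # no cross-trace clobbering
```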
    def api_check(self):
        try:

File diff suppressed because it is too large
@@ -1,36 +0,0 @@
from dify_trace_arize_phoenix.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from openinference.semconv.trace import OpenInferenceSpanKindValues

from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes


class TestGetNodeSpanKind:
    """Tests for _get_node_span_kind helper."""

    def test_all_node_types_are_mapped_correctly(self):
        """Ensure every built-in node type is mapped to the correct span kind."""
        # Mappings for node types that have a specialised span kind.
        special_mappings = {
            BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM,
            BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER,
            BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL,
            BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT,
        }

        # Test that every built-in node type is mapped to the correct span kind.
        # Node types not in `special_mappings` should default to CHAIN.
        for node_type in BUILT_IN_NODE_TYPES:
            expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN)
            actual_span_kind = _get_node_span_kind(node_type)
            assert actual_span_kind == expected_span_kind, (
                f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected."
            )

    def test_unknown_string_defaults_to_chain(self):
        """An unrecognised node type string should still return CHAIN."""
        assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN

    def test_stale_dataset_retrieval_not_in_mapping(self):
        """The old 'dataset_retrieval' string was never a valid NodeType value;
        make sure it is not present in the mapping dictionary."""
        assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND
@@ -65,35 +65,18 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
        }
        file_list = values.get("file_list", [])
        if isinstance(v, str):
            if field_name == "inputs":
                return {
                    "messages": {
                        "role": "user",
                        "content": v,
                        "usage_metadata": usage_metadata,
                        "file_list": file_list,
                    },
                }
            elif field_name == "outputs":
                return {
                    "choices": {
                        "role": "ai",
                        "content": v,
                        "usage_metadata": usage_metadata,
                        "file_list": file_list,
                    },
                }
        elif isinstance(v, list):
            data = {}
            if len(v) > 0 and isinstance(v[0], dict):
                # rename text to content
                v = replace_text_with_content(data=v)
                if field_name == "inputs":
                    data = {
                        "messages": v,
            match field_name:
                case "inputs":
                    return {
                        "messages": {
                            "role": "user",
                            "content": v,
                            "usage_metadata": usage_metadata,
                            "file_list": file_list,
                        },
                    }
                elif field_name == "outputs":
                    data = {
                case "outputs":
                    return {
                        "choices": {
                            "role": "ai",
                            "content": v,
@@ -101,6 +84,29 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
                            "file_list": file_list,
                        },
                    }
                case _:
                    pass
        elif isinstance(v, list):
            data = {}
            if len(v) > 0 and isinstance(v[0], dict):
                # rename text to content
                v = replace_text_with_content(data=v)
                match field_name:
                    case "inputs":
                        data = {
                            "messages": v,
                        }
                    case "outputs":
                        data = {
                            "choices": {
                                "role": "ai",
                                "content": v,
                                "usage_metadata": usage_metadata,
                                "file_list": file_list,
                            },
                        }
                    case _:
                        pass
            return data
        else:
            return {
@@ -64,7 +64,9 @@ class LangSmithDataTrace(BaseTraceInstance):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
        # trace_id must equal the root run's run_id (LangSmith protocol); external trace_id
        # cannot be used here as it would cause HTTP 400.
        trace_id = trace_info.message_id or trace_info.workflow_run_id
        if trace_info.start_time is None:
            trace_info.start_time = datetime.now()
        message_dotted_order = (
@@ -77,6 +79,8 @@ class LangSmithDataTrace(BaseTraceInstance):
        )
        metadata = trace_info.metadata
        metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
        if trace_info.trace_id:
            metadata["external_trace_id"] = trace_info.trace_id

        if trace_info.message_id:
            message_run = LangSmithRunModel(
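Concretely, the constraint the comment above describes is that every run in a LangSmith run tree must share the root run's own id as its `trace_id`; an externally supplied trace id can only travel in metadata. A hedged sketch of the resulting run shapes (ids are illustrative):

```python
root_run = {
    "id": "msg-abc",
    "trace_id": "msg-abc",     # must equal its own id, or LangSmith returns HTTP 400
    "extra": {"metadata": {"external_trace_id": "external-999"}},
}
child_run = {
    "id": "run-xyz",
    "parent_run_id": "msg-abc",
    "trace_id": "msg-abc",     # inherits the root run's id, not the external one
}
```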
@@ -208,13 +208,17 @@ def test_workflow_trace(trace_instance, monkeypatch: pytest.MonkeyPatch):

    assert call_args[0].id == "msg-1"
    assert call_args[0].name == TraceTaskName.MESSAGE_TRACE
    # trace_id must equal root run's id (message_id), not the external trace_id "trace-1"
    assert call_args[0].trace_id == "msg-1"

    assert call_args[1].id == "run-1"
    assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE
    assert call_args[1].parent_run_id == "msg-1"
    assert call_args[1].trace_id == "msg-1"

    assert call_args[2].id == "node-llm"
    assert call_args[2].run_type == LangSmithRunType.llm
    assert call_args[2].trace_id == "msg-1"

    assert call_args[3].id == "node-other"
    assert call_args[3].run_type == LangSmithRunType.tool
@@ -604,3 +608,83 @@ def test_get_project_url_error(trace_instance):
    trace_instance.langsmith_client.get_run_url.side_effect = Exception("error")
    with pytest.raises(ValueError, match="LangSmith get run url failed: error"):
        trace_instance.get_project_url()


def _make_workflow_trace_info(
    *, message_id: str | None, workflow_run_id: str, trace_id: str | None
) -> WorkflowTraceInfo:
    workflow_data = MagicMock()
    workflow_data.created_at = _dt()
    workflow_data.finished_at = _dt() + timedelta(seconds=1)
    return WorkflowTraceInfo(
        tenant_id="tenant-1",
        workflow_id="wf-1",
        workflow_run_id=workflow_run_id,
        workflow_run_inputs={},
        workflow_run_outputs={},
        workflow_run_status="succeeded",
        workflow_run_version="1.0",
        workflow_run_elapsed_time=1.0,
        total_tokens=0,
        file_list=[],
        query="q",
        message_id=message_id,
        conversation_id="conv-1" if message_id else None,
        start_time=_dt(),
        end_time=_dt() + timedelta(seconds=1),
        trace_id=trace_id,
        metadata={"app_id": "app-1"},
        workflow_app_log_id=None,
        error=None,
        workflow_data=workflow_data,
    )


def _patch_workflow_trace_deps(monkeypatch, trace_instance):
    monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock())
    monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.db", MagicMock(engine="engine"))
    repo = MagicMock()
    repo.get_by_workflow_execution.return_value = []
    factory = MagicMock()
    factory.create_workflow_node_execution_repository.return_value = repo
    monkeypatch.setattr("dify_trace_langsmith.langsmith_trace.DifyCoreRepositoryFactory", factory)
    monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock())
    trace_instance.add_run = MagicMock()


def test_workflow_trace_id_uses_message_id_not_external(trace_instance, monkeypatch):
    """Chatflow with external trace_id: LangSmith trace_id must be message_id, not external."""
    trace_info = _make_workflow_trace_info(
        message_id="msg-abc",
        workflow_run_id="run-xyz",
        trace_id="external-999",
    )
    _patch_workflow_trace_deps(monkeypatch, trace_instance)

    trace_instance.workflow_trace(trace_info)

    calls = [c[0][0] for c in trace_instance.add_run.call_args_list]
    # message run (root) and workflow run (child) must both use message_id as trace_id
    assert calls[0].id == "msg-abc"
    assert calls[0].trace_id == "msg-abc"
    assert calls[1].id == "run-xyz"
    assert calls[1].trace_id == "msg-abc"
    # external_trace_id preserved in metadata
    assert trace_info.metadata.get("external_trace_id") == "external-999"


def test_workflow_trace_id_pure_workflow_uses_run_id(trace_instance, monkeypatch):
    """Pure workflow (no message_id) with external trace_id: trace_id must be workflow_run_id."""
    trace_info = _make_workflow_trace_info(
        message_id=None,
        workflow_run_id="run-xyz",
        trace_id="external-999",
    )
    _patch_workflow_trace_deps(monkeypatch, trace_instance)

    trace_instance.workflow_trace(trace_info)

    calls = [c[0][0] for c in trace_instance.add_run.call_args_list]
    # workflow run is the root; trace_id must equal its run_id
    assert calls[0].id == "run-xyz"
    assert calls[0].trace_id == "run-xyz"
@@ -81,14 +81,15 @@ class OpenSearchConfig(BaseModel):
            pool_maxsize=20,
        )

        if self.auth_method == "basic":
            logger.info("Using basic authentication for OpenSearch Vector DB")
        match self.auth_method:
            case AuthMethod.BASIC:
                logger.info("Using basic authentication for OpenSearch Vector DB")

            params["http_auth"] = (self.user, self.password)
        elif self.auth_method == "aws_managed_iam":
            logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
                params["http_auth"] = (self.user, self.password)
            case AuthMethod.AWS_MANAGED_IAM:
                logger.info("Using AWS managed IAM role for OpenSearch Vector DB")

            params["http_auth"] = self.create_aws_managed_iam_auth()
                params["http_auth"] = self.create_aws_managed_iam_auth()

        return params
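The hunk above swaps string comparisons for a `match` over the `AuthMethod` enum. A self-contained sketch of the same dispatch, assuming `AuthMethod` is a str-valued enum, with a placeholder standing in for the SigV4 signer:

```python
from enum import Enum


class AuthMethod(str, Enum):
    BASIC = "basic"
    AWS_MANAGED_IAM = "aws_managed_iam"


def auth_params(method: AuthMethod, user: str, password: str) -> dict:
    params: dict = {}
    match method:
        case AuthMethod.BASIC:
            params["http_auth"] = (user, password)
        case AuthMethod.AWS_MANAGED_IAM:
            # placeholder for create_aws_managed_iam_auth(), which builds a SigV4 signer
            params["http_auth"] = "aws-sigv4-signer"
    return params


assert auth_params(AuthMethod.BASIC, "u", "p") == {"http_auth": ("u", "p")}
```

Matching on enum members instead of raw strings means a typo in a case arm fails at import review rather than silently skipping the branch.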
@@ -1,12 +1,12 @@
[project]
name = "dify-api"
version = "1.14.0"
version = "1.14.1"
requires-python = "~=3.12.0"

dependencies = [
    # Legacy: mature and widely deployed
    "bleach>=6.3.0",
    "boto3>=1.43.3",
    "boto3>=1.43.6",
    "celery>=5.6.3",
    "croniter>=6.2.2",
    "flask>=3.1.3,<4.0.0",
@@ -14,8 +14,8 @@ dependencies = [
    "gevent>=26.4.0",
    "gevent-websocket>=0.10.1",
    "gmpy2>=2.3.0",
    "google-api-python-client>=2.195.0",
    "gunicorn>=25.3.0",
    "google-api-python-client>=2.196.0",
    "gunicorn>=26.0.0",
    "psycogreen>=1.0.2",
    "psycopg2-binary>=2.9.12",
    "python-socketio>=5.13.0",
@@ -31,7 +31,7 @@ dependencies = [
    "flask-migrate>=4.1.0,<5.0.0",
    "flask-orjson>=2.0.0,<3.0.0",
    "flask-restx>=1.3.2,<2.0.0",
    "google-cloud-aiplatform>=1.149.0,<2.0.0",
    "google-cloud-aiplatform>=1.151.0,<2.0.0",
    "httpx[socks]>=0.28.1,<1.0.0",
    "opentelemetry-distro>=0.62b1,<1.0.0",
    "opentelemetry-instrumentation-celery>=0.62b0,<1.0.0",
@@ -45,7 +45,7 @@ dependencies = [

    # Emerging: newer and fast-moving, use compatible pins
    "fastopenapi[flask]~=0.7.0",
    "graphon~=0.3.0",
    "graphon~=0.3.1",
    "httpx-sse~=0.4.0",
    "json-repair~=0.59.4",
]
@@ -191,7 +191,7 @@ storage = [
    "google-cloud-storage>=3.10.1",
    "opendal>=0.46.0",
    "oss2>=2.19.1",
    "supabase>=2.29.0",
    "supabase>=2.30.0",
    "tos>=2.9.0",
]
@@ -1,9 +1,10 @@
import json
import logging
from typing import Any, TypedDict, cast
from typing import Any, Literal, TypedDict, cast

import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination
from pydantic import BaseModel, Field
from sqlalchemy import select

from configs import dify_config
@@ -31,39 +32,59 @@ from tasks.remove_app_and_related_data_task import remove_app_and_related_data_t
logger = logging.getLogger(__name__)


class AppListParams(BaseModel):
    page: int = Field(default=1, ge=1)
    limit: int = Field(default=20, ge=1, le=100)
    mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = "all"
    name: str | None = None
    tag_ids: list[str] | None = None
    is_created_by_me: bool | None = None


class CreateAppParams(BaseModel):
    name: str = Field(min_length=1)
    description: str | None = None
    mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
    icon_type: str | None = None
    icon: str | None = None
    icon_background: str | None = None
    api_rph: int = 0
    api_rpm: int = 0
    max_active_requests: int | None = None
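A usage sketch for `AppListParams`: raw query args go through `model_validate`, which coerces types and enforces the declared bounds before the service sees them (the example values are invented):

```python
from pydantic import ValidationError

params = AppListParams.model_validate({"page": "2", "limit": "50", "mode": "chat"})
assert (params.page, params.limit, params.mode) == (2, 50, "chat")   # strings coerced to ints

try:
    AppListParams.model_validate({"limit": 1000})   # exceeds le=100
except ValidationError as e:
    print(e.errors()[0]["loc"])                     # ('limit',)
```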
class AppService:
    def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict[str, Any]) -> Pagination | None:
    def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
        """
        Get app list with pagination
        :param user_id: user id
        :param tenant_id: tenant id
        :param args: request args
        :param params: query parameters
        :return:
        """
        filters = [App.tenant_id == tenant_id, App.is_universal == False]

        if args["mode"] == "workflow":
        if params.mode == "workflow":
            filters.append(App.mode == AppMode.WORKFLOW)
        elif args["mode"] == "completion":
        elif params.mode == "completion":
            filters.append(App.mode == AppMode.COMPLETION)
        elif args["mode"] == "chat":
        elif params.mode == "chat":
            filters.append(App.mode == AppMode.CHAT)
        elif args["mode"] == "advanced-chat":
        elif params.mode == "advanced-chat":
            filters.append(App.mode == AppMode.ADVANCED_CHAT)
        elif args["mode"] == "agent-chat":
        elif params.mode == "agent-chat":
            filters.append(App.mode == AppMode.AGENT_CHAT)

        if args.get("is_created_by_me", False):
        if params.is_created_by_me:
            filters.append(App.created_by == user_id)
        if args.get("name"):
        if params.name:
            from libs.helper import escape_like_pattern

            name = args["name"][:30]
            name = params.name[:30]
            escaped_name = escape_like_pattern(name)
            filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
        # Check if tag_ids is not empty to avoid WHERE false condition
        if args.get("tag_ids") and len(args["tag_ids"]) > 0:
            target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
        if params.tag_ids and len(params.tag_ids) > 0:
            target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, params.tag_ids)
            if target_ids and len(target_ids) > 0:
                filters.append(App.id.in_(target_ids))
            else:
@@ -71,21 +92,21 @@ class AppService:

        app_models = db.paginate(
            sa.select(App).where(*filters).order_by(App.created_at.desc()),
            page=args["page"],
            per_page=args["limit"],
            page=params.page,
            per_page=params.limit,
            error_out=False,
        )

        return app_models

    def create_app(self, tenant_id: str, args: dict[str, Any], account: Account) -> App:
    def create_app(self, tenant_id: str, params: CreateAppParams, account: Account) -> App:
        """
        Create app
        :param tenant_id: tenant id
        :param args: request args
        :param params: app creation parameters
        :param account: Account instance
        """
        app_mode = AppMode.value_of(args["mode"])
        app_mode = AppMode.value_of(params.mode)
        app_template = default_app_templates[app_mode]

        # get model config
@@ -143,15 +164,16 @@ class AppService:
            default_model_config["model"] = json.dumps(default_model_dict)

        app = App(**app_template["app"])
        app.name = args["name"]
        app.description = args.get("description", "")
        app.mode = args["mode"]
        app.icon_type = args.get("icon_type", "emoji")
        app.icon = args["icon"]
        app.icon_background = args["icon_background"]
        app.name = params.name
        app.description = params.description or ""
        app.mode = app_mode
        app.icon_type = IconType(params.icon_type) if params.icon_type else IconType.EMOJI
        app.icon = params.icon
        app.icon_background = params.icon_background
        app.tenant_id = tenant_id
        app.api_rph = args.get("api_rph", 0)
        app.api_rpm = args.get("api_rpm", 0)
        app.api_rph = params.api_rph
        app.api_rpm = params.api_rpm
        app.max_active_requests = params.max_active_requests
        app.created_by = account.id
        app.updated_by = account.id
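The matching call-site shape for the new `create_app` signature: the controller validates once and hands the typed object down. The validation part of this sketch is runnable; the final service call is shown as a comment because it needs a database session and a loaded `Account` row:

```python
from pydantic import ValidationError

params = CreateAppParams.model_validate({
    "name": "support-bot",
    "mode": "chat",
    "icon_type": "emoji",
    "icon": "🤖",
})
assert params.api_rph == 0                      # defaults applied

try:
    CreateAppParams.model_validate({"name": "", "mode": "chat"})
except ValidationError:
    pass                                        # empty name rejected by min_length=1

# downstream call (account: a loaded Account row):
#   AppService().create_app(tenant_id=..., params=params, account=account)
```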
@@ -7,9 +7,10 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, TypedDict, cast
from typing import Annotated, Any, Literal, TypedDict, cast

import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from redis.exceptions import LockNotOwnedError
from sqlalchemy import delete, exists, func, select, update
from sqlalchemy.orm import Session, sessionmaker
@@ -108,7 +109,7 @@ logger = logging.getLogger(__name__)


class ProcessRulesDict(TypedDict):
    mode: str
    mode: ProcessRuleMode
    rules: dict[str, Any]


@@ -117,6 +118,86 @@ class AutoDisableLogsDict(TypedDict):
    count: int


class _EstimatePreProcessingRule(BaseModel):
    id: str = Field(min_length=1)
    enabled: bool

    @field_validator("id")
    @classmethod
    def _validate_id(cls, v: str) -> str:
        if v not in DatasetProcessRule.PRE_PROCESSING_RULES:
            raise ValueError("Process rule pre_processing_rules id is invalid")
        return v


class _EstimateSegmentation(BaseModel):
    separator: str = Field(min_length=1)
    max_tokens: int = Field(gt=0)


class _EstimateRules(BaseModel):
    pre_processing_rules: list[_EstimatePreProcessingRule]
    segmentation: _EstimateSegmentation

    @field_validator("pre_processing_rules")
    @classmethod
    def _deduplicate(cls, v: list[_EstimatePreProcessingRule]) -> list[_EstimatePreProcessingRule]:
        seen: dict[str, _EstimatePreProcessingRule] = {}
        for rule in v:
            seen[rule.id] = rule
        return list(seen.values())


class _SummaryIndexSettingDisabled(BaseModel):
    enable: Literal[False] = False


class _SummaryIndexSettingEnabled(BaseModel):
    enable: Literal[True]
    model_name: str = Field(min_length=1)
    model_provider_name: str = Field(min_length=1)


_SummaryIndexSetting = Annotated[
    _SummaryIndexSettingDisabled | _SummaryIndexSettingEnabled,
    Field(discriminator="enable"),
]


class _AutomaticProcessRule(BaseModel):
    model_config = ConfigDict(extra="allow")

    mode: Literal[ProcessRuleMode.AUTOMATIC]
    summary_index_setting: _SummaryIndexSetting | None = None


class _CustomProcessRule(BaseModel):
    model_config = ConfigDict(extra="allow")

    mode: Literal[ProcessRuleMode.CUSTOM]
    rules: _EstimateRules
    summary_index_setting: _SummaryIndexSetting | None = None


class _HierarchicalProcessRule(BaseModel):
    model_config = ConfigDict(extra="allow")

    mode: Literal[ProcessRuleMode.HIERARCHICAL]
    rules: _EstimateRules
    summary_index_setting: _SummaryIndexSetting | None = None


_EstimateProcessRule = Annotated[
    _AutomaticProcessRule | _CustomProcessRule | _HierarchicalProcessRule,
    Field(discriminator="mode"),
]


class _EstimateArgs(BaseModel):
    info_list: dict[str, Any]
    process_rule: _EstimateProcessRule
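How the discriminated unions above behave at parse time: `mode` picks which process-rule model applies (so `custom` without `rules` fails fast), and `enable` picks the summary-index shape. A hedged usage sketch, assuming `"remove_extra_spaces"` is one of `DatasetProcessRule.PRE_PROCESSING_RULES` and `ProcessRuleMode` is the enum imported in this module:

```python
from typing import Any

args: dict[str, Any] = {
    "info_list": {"data_source_type": "upload_file"},
    "process_rule": {
        "mode": ProcessRuleMode.CUSTOM,
        "rules": {
            "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
            "segmentation": {"separator": "\n", "max_tokens": 500},
        },
        "summary_index_setting": {"enable": False},
    },
}
validated = _EstimateArgs.model_validate(args)
assert validated.process_rule.mode == ProcessRuleMode.CUSTOM
assert isinstance(validated.process_rule.summary_index_setting, _SummaryIndexSettingDisabled)
```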
class DatasetService:
    @staticmethod
    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
@@ -204,7 +285,7 @@ class DatasetService:
            mode = dataset_process_rule.mode
            rules = dataset_process_rule.rules_dict or {}
        else:
            mode = str(DocumentService.DEFAULT_RULES["mode"])
            mode = ProcessRuleMode(DocumentService.DEFAULT_RULES["mode"])
            rules = dict(DocumentService.DEFAULT_RULES.get("rules") or {})
        return {"mode": mode, "rules": rules}

@@ -1984,7 +2065,7 @@ class DocumentService:
            if process_rule.rules:
                dataset_process_rule = DatasetProcessRule(
                    dataset_id=dataset.id,
                    mode=process_rule.mode,
                    mode=ProcessRuleMode(process_rule.mode),
                    rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                    created_by=account.id,
                )
@@ -1995,7 +2076,7 @@ class DocumentService:
        elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
            dataset_process_rule = DatasetProcessRule(
                dataset_id=dataset.id,
                mode=process_rule.mode,
                mode=ProcessRuleMode.AUTOMATIC,
                rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                created_by=account.id,
            )
@@ -2572,14 +2653,14 @@ class DocumentService:
        if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}:
            dataset_process_rule = DatasetProcessRule(
                dataset_id=dataset.id,
                mode=process_rule.mode,
                mode=ProcessRuleMode(process_rule.mode),
                rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                created_by=account.id,
            )
        elif process_rule.mode == ProcessRuleMode.AUTOMATIC:
            dataset_process_rule = DatasetProcessRule(
                dataset_id=dataset.id,
                mode=process_rule.mode,
                mode=ProcessRuleMode.AUTOMATIC,
                rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
                created_by=account.id,
            )
@@ -2851,94 +2932,16 @@ class DocumentService:

    @classmethod
    def estimate_args_validate(cls, args: dict[str, Any]):
        if "info_list" not in args or not args["info_list"]:
            raise ValueError("Data source info is required")

        if not isinstance(args["info_list"], dict):
            raise ValueError("Data info is invalid")

        if "process_rule" not in args or not args["process_rule"]:
            raise ValueError("Process rule is required")

        if not isinstance(args["process_rule"], dict):
            raise ValueError("Process rule is invalid")

        if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]:
            raise ValueError("Process rule mode is required")

        if args["process_rule"]["mode"] not in DatasetProcessRule.MODES:
            raise ValueError("Process rule mode is invalid")

        if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC:
            args["process_rule"]["rules"] = {}
        else:
            if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]:
                raise ValueError("Process rule rules is required")

            if not isinstance(args["process_rule"]["rules"], dict):
                raise ValueError("Process rule rules is invalid")

            if (
                "pre_processing_rules" not in args["process_rule"]["rules"]
                or args["process_rule"]["rules"]["pre_processing_rules"] is None
            ):
                raise ValueError("Process rule pre_processing_rules is required")

            if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list):
                raise ValueError("Process rule pre_processing_rules is invalid")

            unique_pre_processing_rule_dicts = {}
            for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]:
                if "id" not in pre_processing_rule or not pre_processing_rule["id"]:
                    raise ValueError("Process rule pre_processing_rules id is required")

                if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES:
                    raise ValueError("Process rule pre_processing_rules id is invalid")

                if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None:
                    raise ValueError("Process rule pre_processing_rules enabled is required")

                if not isinstance(pre_processing_rule["enabled"], bool):
                    raise ValueError("Process rule pre_processing_rules enabled is invalid")

                unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule

            args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values())

            if (
                "segmentation" not in args["process_rule"]["rules"]
                or args["process_rule"]["rules"]["segmentation"] is None
            ):
                raise ValueError("Process rule segmentation is required")

            if not isinstance(args["process_rule"]["rules"]["segmentation"], dict):
                raise ValueError("Process rule segmentation is invalid")

            if (
                "separator" not in args["process_rule"]["rules"]["segmentation"]
                or not args["process_rule"]["rules"]["segmentation"]["separator"]
            ):
                raise ValueError("Process rule segmentation separator is required")

            if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str):
                raise ValueError("Process rule segmentation separator is invalid")

            if (
                "max_tokens" not in args["process_rule"]["rules"]["segmentation"]
                or not args["process_rule"]["rules"]["segmentation"]["max_tokens"]
            ):
                raise ValueError("Process rule segmentation max_tokens is required")

            if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
                raise ValueError("Process rule segmentation max_tokens is invalid")

        # valid summary index setting
        summary_index_setting = args["process_rule"].get("summary_index_setting")
        if summary_index_setting and summary_index_setting.get("enable"):
            if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
                raise ValueError("Summary index model name is required")
            if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
                raise ValueError("Summary index model provider name is required")
        try:
            validated = _EstimateArgs.model_validate(args)
        except ValidationError as e:
            first = e.errors()[0]
            original = first.get("ctx", {}).get("error")
            raise ValueError(str(original) if isinstance(original, ValueError) else first["msg"]) from e
        process_rule_dict = validated.process_rule.model_dump(exclude_none=True)
        if validated.process_rule.mode == ProcessRuleMode.AUTOMATIC:
            process_rule_dict["rules"] = {}
        args["process_rule"] = process_rule_dict
@staticmethod
|
||||
def batch_update_document_status(
|
||||
|
||||
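
The replacement body above leans on one subtle Pydantic detail: a `ValueError` raised inside a field validator is wrapped by `ValidationError`, and the original exception is recoverable from the first error's `ctx`. A minimal, self-contained sketch of that unwrapping pattern, using a hypothetical `_Args` model rather than the real `_EstimateArgs`:

```python
from pydantic import BaseModel, ValidationError, field_validator


class _Args(BaseModel):
    """Hypothetical stand-in for a validated args model."""

    mode: str

    @field_validator("mode")
    @classmethod
    def _check_mode(cls, v: str) -> str:
        if v not in {"automatic", "custom"}:
            raise ValueError("Process rule mode is invalid")
        return v


def validate(args: dict) -> _Args:
    try:
        return _Args.model_validate(args)
    except ValidationError as e:
        first = e.errors()[0]
        # Pydantic stores a ValueError raised in a validator under ctx["error"];
        # surface it unchanged so existing callers see the same message.
        original = first.get("ctx", {}).get("error")
        raise ValueError(str(original) if isinstance(original, ValueError) else first["msg"]) from e
```

This keeps the public contract (`ValueError` with the legacy message) while the ~80 lines of hand-rolled dict checks collapse into the model definition.
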
@ -166,7 +166,7 @@ class SystemFeatureModel(BaseModel):
    enable_email_code_login: bool = False
    enable_email_password_login: bool = True
    enable_social_oauth_login: bool = False
    enable_collaboration_mode: bool = False
    enable_collaboration_mode: bool = True
    is_allow_register: bool = False
    is_allow_create_workspace: bool = False
    is_email_setup: bool = False
@ -111,6 +111,7 @@ class VectorService:
                "dataset_id": segment.dataset_id,
            },
        )
        assert segment.index_node_id
        if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
            # update vector index
            vector = Vector(dataset=dataset)

@ -138,6 +139,7 @@ class VectorService:
        regenerate: bool = False,
    ):
        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
        assert segment.index_node_id
        if regenerate:
            # delete child chunks
            index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)
@ -1066,8 +1066,13 @@ class WorkflowService:
        )

        rendered_content = node.render_form_content_before_submission()
        selected_action = next(
            (user_action for user_action in node_data.user_actions if user_action.id == action),
            None,
        )
        outputs: dict[str, Any] = dict(form_inputs)
        outputs["__action_id"] = action
        outputs["__action_value"] = selected_action.title if selected_action else ""
        outputs["__rendered_content"] = node.render_form_content_with_outputs(
            rendered_content, outputs, node_data.outputs_field_names()
        )
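
The `selected_action` lookup above uses `next` with a generator and a `None` default, the idiomatic first-match-or-none scan. A tiny sketch of the pattern with hypothetical action objects:

```python
from dataclasses import dataclass


@dataclass
class UserAction:
    id: str
    title: str


user_actions = [UserAction("approve", "Approve"), UserAction("reject", "Reject")]
action = "reject"

# First element matching the predicate, or None if nothing matches.
selected_action = next((a for a in user_actions if a.id == action), None)
assert selected_action is not None and selected_action.title == "Reject"
```
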
@ -50,7 +50,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
    ).all()

    if segments:
        index_node_ids = [segment.index_node_id for segment in segments]
        index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
        segment_ids = [segment.id for segment in segments]

        # Collect image file IDs from segment content
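
This same guard recurs across the cleanup and indexing tasks below: `index_node_id` is nullable on `DocumentSegment`, so the comprehensions now drop `None` before the IDs reach the index processor. A minimal sketch of the pattern (hypothetical types) and why it also satisfies static checkers:

```python
from dataclasses import dataclass


@dataclass
class Segment:
    id: str
    index_node_id: str | None  # nullable column mirrored in the model


def collect_index_node_ids(segments: list[Segment]) -> list[str]:
    # The `if` clause both skips rows that were never indexed and narrows the
    # element type from `str | None` to `str` for type checkers.
    return [s.index_node_id for s in segments if s.index_node_id]


segments = [Segment("a", "node-1"), Segment("b", None)]
assert collect_index_node_ids(segments) == ["node-1"]
```
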
@ -19,6 +19,7 @@ from graphon.model_runtime.entities.model_entities import ModelType
from libs import helper
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.enums import SegmentStatus
from models.model import UploadFile
from services.vector_service import VectorService

@ -156,7 +157,7 @@ def batch_create_segment_to_index_task(
            tokens=tokens,
            created_by=user_id,
            indexing_at=naive_utc_now(),
            status="completed",
            status=SegmentStatus.COMPLETED,
            completed_at=naive_utc_now(),
        )
        if document_config["doc_form"] == IndexStructureType.QA_INDEX:
@ -53,7 +53,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
            binding_ids = [binding.id for binding, _ in attachments_with_bindings]
            total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])

        index_node_ids = [segment.index_node_id for segment in segments]
        index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
        segment_contents = [segment.content for segment in segments]
    except Exception:
        logger.exception("Cleaned document when document deleted failed")

@ -38,7 +38,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):

        for document_id in document_ids:
            segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
            total_index_node_ids.extend([segment.index_node_id for segment in segments])
            total_index_node_ids.extend([segment.index_node_id for segment in segments if segment.index_node_id])

        # Wrap vector / keyword index cleanup in try/except so that a transient
        # failure here (e.g. billing API hiccup propagated via FeatureService when
@ -9,6 +9,7 @@ from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
from models.enums import SegmentStatus

logger = logging.getLogger(__name__)

@ -30,7 +31,7 @@ def disable_segment_from_index_task(segment_id: str):
        logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
        return

    if segment.status != "completed":
    if segment.status != SegmentStatus.COMPLETED:
        logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
        return

@ -59,6 +60,7 @@ def disable_segment_from_index_task(segment_id: str):

        index_type = dataset_document.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        assert segment.index_node_id
        index_processor.clean(dataset, [segment.index_node_id])

        # Disable summary index for this segment
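
Comparing `segment.status` against `SegmentStatus.COMPLETED` instead of the literal `"completed"` is transparent here only because the enum is string-valued; assuming `SegmentStatus` follows the string-enum pattern used elsewhere in `models.enums`, a sketch of why the comparison still works against a plain string column:

```python
from enum import StrEnum


class SegmentStatus(StrEnum):
    """Assumed shape; the real enum lives in models/enums.py."""

    COMPLETED = "completed"
    INDEXING = "indexing"
    ERROR = "error"


status_from_db = "completed"  # raw string value from the ORM column

# A StrEnum member compares equal to its underlying string, so swapping the
# literal for the enum changes nothing at runtime but centralizes the values.
assert status_from_db == SegmentStatus.COMPLETED
```
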
@ -55,7 +55,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
        return

    try:
        index_node_ids = [segment.index_node_id for segment in segments]
        index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
        if dataset.is_multimodal:
            segment_ids = [segment.id for segment in segments]
            segment_attachment_bindings = session.scalars(
@ -69,7 +69,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
            index_type = document.doc_form

            segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
            index_node_ids = [segment.index_node_id for segment in segments]
            index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]

            # Get credentials from datasource provider
            datasource_provider_service = DatasourceProviderService()

@ -45,7 +45,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):

        index_type = document.doc_form
        segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
        index_node_ids = [segment.index_node_id for segment in segments]
        index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]

        clean_success = False
        try:

@ -137,7 +137,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
                select(DocumentSegment).where(DocumentSegment.document_id == document.id)
            ).all()
            if segments:
                index_node_ids = [segment.index_node_id for segment in segments]
                index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]

                # delete from vector index
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
@ -1,11 +1,31 @@
"""
Celery task for asynchronous ops trace dispatch.

Trace providers may report explicitly retryable dispatch failures through the
core retryable exception contract. The task preserves the payload file only
when Celery accepts the retry request; successful dispatches and terminal
failures clean up the stored payload.

One concrete producer today is Phoenix nested workflow tracing. The outer
workflow tool span publishes a restorable parent span context asynchronously,
while the nested workflow trace may be picked up by Celery first. In that
ordering window, the provider raises a retryable core exception instead of
dropping the trace or emitting it under the wrong parent. The task intentionally
does not know that the provider is Phoenix; it only honors the core retryable
dispatch contract.
"""

import json
import logging

from celery import shared_task
from celery.exceptions import Retry
from flask import current_app

from configs import dify_config
from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
from core.ops.entities.trace_entity import trace_info_info_map
from core.ops.exceptions import RetryableTraceDispatchError
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage

@ -14,9 +34,17 @@ from models.workflow import WorkflowRun

logger = logging.getLogger(__name__)

_RETRYABLE_TRACE_DISPATCH_LIMIT = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_MAX_RETRIES
_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS = dify_config.OPS_TRACE_RETRYABLE_DISPATCH_DELAY_SECONDS

@shared_task(queue="ops_trace")
def process_trace_tasks(file_info):

@shared_task(
    queue="ops_trace",
    bind=True,
    max_retries=_RETRYABLE_TRACE_DISPATCH_LIMIT,
    default_retry_delay=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS,
)
def process_trace_tasks(self, file_info):
    """
    Async process trace tasks
    Usage: process_trace_tasks.delay(tasks_data)

@ -29,6 +57,7 @@ def process_trace_tasks(file_info):
    file_data = json.loads(storage.load(file_path))
    trace_info = file_data.get("trace_info")
    trace_info_type = file_data.get("trace_info_type")
    enterprise_trace_dispatched = bool(file_data.get("_enterprise_trace_dispatched"))
    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)

    if trace_info.get("message_data"):

@ -38,6 +67,8 @@ def process_trace_tasks(file_info):
    if trace_info.get("documents"):
        trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]

    should_delete_file = True

    try:
        trace_type = trace_info_info_map.get(trace_info_type)
        if trace_type:

@ -45,30 +76,66 @@ def process_trace_tasks(file_info):

        from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled

        if is_ee_telemetry_enabled():
        if is_ee_telemetry_enabled() and not enterprise_trace_dispatched:
            from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace

            try:
                EnterpriseOtelTrace().trace(trace_info)
            except Exception:
                logger.exception("Enterprise trace failed for app_id: %s", app_id)
            else:
                file_data["_enterprise_trace_dispatched"] = True
                enterprise_trace_dispatched = True

        if trace_instance:
            with current_app.app_context():
                trace_instance.trace(trace_info)

        logger.info("Processing trace tasks success, app_id: %s", app_id)
    except RetryableTraceDispatchError as e:
        # Retryable dispatch failures represent a transient provider-side
        # ordering gap, not corrupt payload data. Keep the payload only after
        # Celery accepts the retry request; otherwise this attempt becomes a
        # terminal failure and the stored file is cleaned up in `finally`.
        #
        # Enterprise telemetry runs before provider dispatch. If it already ran
        # and provider dispatch asks for a retry, persist that private flag so
        # the next attempt does not emit the same enterprise trace twice.
        if self.request.retries >= _RETRYABLE_TRACE_DISPATCH_LIMIT:
            logger.exception("Retryable trace dispatch budget exhausted, app_id: %s", app_id)
            failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
            redis_client.incr(failed_key)
        else:
            logger.warning(
                "Retryable trace dispatch failure, scheduling retry %s/%s for app_id %s: %s",
                self.request.retries + 1,
                _RETRYABLE_TRACE_DISPATCH_LIMIT,
                app_id,
                e,
            )
            try:
                if enterprise_trace_dispatched:
                    storage.save(file_path, json.dumps(file_data).encode("utf-8"))
                raise self.retry(exc=e, countdown=_RETRYABLE_TRACE_DISPATCH_DELAY_SECONDS)
            except Retry:
                should_delete_file = False
                raise
            except Exception:
                logger.exception("Failed to schedule trace dispatch retry, app_id: %s", app_id)
                failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
                redis_client.incr(failed_key)
    except Exception as e:
        logger.exception("Processing trace tasks failed, app_id: %s", app_id)
        failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
        redis_client.incr(failed_key)
    finally:
        try:
            storage.delete(file_path)
        except Exception as e:
            logger.warning(
                "Failed to delete trace file %s for app_id %s: %s",
                file_path,
                app_id,
                e,
            )
        if should_delete_file:
            try:
                storage.delete(file_path)
            except Exception as e:
                logger.warning(
                    "Failed to delete trace file %s for app_id %s: %s",
                    file_path,
                    app_id,
                    e,
                )
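
The retry plumbing above hinges on Celery's bound-task contract: `self.retry()` raises a `Retry` exception when the broker accepts the rescheduling request, which is why the task re-raises it after flipping `should_delete_file`. A minimal, self-contained sketch of that control flow (hypothetical task name and error type, not Dify's actual code):

```python
from celery import shared_task
from celery.exceptions import Retry


class TransientDispatchError(Exception):
    """Hypothetical stand-in for a retryable provider error."""


@shared_task(bind=True, max_retries=3, default_retry_delay=5)
def dispatch(self, payload: dict) -> None:
    keep_payload = False
    try:
        raise TransientDispatchError("provider not ready")
    except TransientDispatchError as e:
        try:
            # retry() raises Retry when Celery accepts the request; falling
            # past this line means the retry was rejected (budget exhausted).
            raise self.retry(exc=e)
        except Retry:
            keep_payload = True  # the attempt will run again; keep the payload
            raise
    finally:
        if not keep_payload:
            ...  # terminal outcome: clean up the stored payload
```

Keying cleanup off the `Retry` exception, rather than off raising `retry()` itself, is what lets terminal failures and successes share a single `finally` path.
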
@ -61,7 +61,7 @@ def remove_document_from_index_task(document_id: str):
        except Exception as e:
            logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e))

        index_node_ids = [segment.index_node_id for segment in segments]
        index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
        if index_node_ids:
            try:
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)

@ -85,7 +85,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
            ).all()
            if segments:
                index_node_ids = [segment.index_node_id for segment in segments]
                index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
                # delete from vector index
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

@ -70,7 +70,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):

        segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            index_node_ids = [segment.index_node_id for segment in segments if segment.index_node_id]
            # delete from vector index
            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
@ -91,7 +91,11 @@ def init_llm_node(config: dict) -> LLMNode:
    return node


def test_execute_llm():
def _mock_db_session_close(monkeypatch) -> None:
    monkeypatch.setattr(db.session, "close", MagicMock())


def test_execute_llm(monkeypatch):
    node = init_llm_node(
        config={
            "id": "llm",

@ -118,7 +122,7 @@ def test_execute_llm():
        },
    )

    db.session.close = MagicMock()
    _mock_db_session_close(monkeypatch)

    def build_mock_model_instance() -> MagicMock:
        from decimal import Decimal

@ -195,7 +199,7 @@ def test_execute_llm():
        assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0


def test_execute_llm_with_jinja2():
def test_execute_llm_with_jinja2(monkeypatch):
    """
    Test execute LLM node with jinja2
    """

@ -233,8 +237,7 @@ def test_execute_llm_with_jinja2():
        },
    )

    # Mock db.session.close()
    db.session.close = MagicMock()
    _mock_db_session_close(monkeypatch)

    def build_mock_model_instance() -> MagicMock:
        from decimal import Decimal
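
These test edits replace bare attribute assignment (`db.session.close = MagicMock()`), which leaks the mock into every later test in the process, with pytest's `monkeypatch`, which undoes the patch automatically at teardown. A minimal sketch of the difference, using a hypothetical session object:

```python
from unittest.mock import MagicMock


class _Session:
    def close(self) -> None:
        print("real close")


session = _Session()


def test_close_is_mocked(monkeypatch):
    # monkeypatch restores the original attribute when this test finishes,
    # so subsequent tests still see the real close().
    monkeypatch.setattr(session, "close", MagicMock())
    session.close()  # calls the mock, not the real method
```

Hoisting the call into `_mock_db_session_close` also gives the suite one place to change if the session API moves.
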
@ -83,7 +83,11 @@ def init_parameter_extractor_node(config: dict, memory=None):
    return node


def test_function_calling_parameter_extractor(setup_model_mock):
def _mock_db_session_close(monkeypatch) -> None:
    monkeypatch.setattr(db.session, "close", MagicMock())


def test_function_calling_parameter_extractor(setup_model_mock, monkeypatch):
    """
    Test function calling for parameter extractor.
    """

@ -114,7 +118,7 @@ def test_function_calling_parameter_extractor(setup_model_mock):
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )()
    db.session.close = MagicMock()
    _mock_db_session_close(monkeypatch)

    result = node._run()

@ -124,7 +128,7 @@ def test_function_calling_parameter_extractor(setup_model_mock):
    assert result.outputs.get("__reason") == None


def test_instructions(setup_model_mock):
def test_instructions(setup_model_mock, monkeypatch):
    """
    Test chat parameter extractor.
    """

@ -155,7 +159,7 @@ def test_instructions(setup_model_mock):
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )()
    db.session.close = MagicMock()
    _mock_db_session_close(monkeypatch)

    result = node._run()

@ -174,7 +178,7 @@ def test_instructions(setup_model_mock):
    assert "what's the weather in SF" in prompt.get("text")


def test_chat_parameter_extractor(setup_model_mock):
def test_chat_parameter_extractor(setup_model_mock, monkeypatch):
    """
    Test chat parameter extractor.
    """

@ -205,7 +209,7 @@ def test_chat_parameter_extractor(setup_model_mock):
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )()
    db.session.close = MagicMock()
    _mock_db_session_close(monkeypatch)

    result = node._run()

@ -225,7 +229,7 @@ def test_chat_parameter_extractor(setup_model_mock):
    assert '<structure>\n{"type": "object"' in prompt.get("text")


def test_completion_parameter_extractor(setup_model_mock):
def test_completion_parameter_extractor(setup_model_mock, monkeypatch):
    """
    Test completion parameter extractor.
    """

@ -256,7 +260,7 @@ def test_completion_parameter_extractor(setup_model_mock):
        mode="completion",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )()
    db.session.close = MagicMock()
    _mock_db_session_close(monkeypatch)

    result = node._run()

@ -350,7 +354,7 @@ def test_extract_json_from_tool_call():
    assert result["location"] == "kawaii"


def test_chat_parameter_extractor_with_memory(setup_model_mock):
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
    """
    Test chat parameter extractor with memory.
    """

@ -382,7 +386,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock):
        mode="chat",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
    )()
    db.session.close = MagicMock()
    _mock_db_session_close(monkeypatch)

    result = node._run()
@ -168,6 +168,7 @@ def test_node_variable_collection_get_success(
    account, tenant = create_console_account_and_tenant(db_session_with_containers)
    app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
    node_variable = _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_123")
    node_variable_id = node_variable.id
    _create_node_variable(db_session_with_containers, app.id, account.id, node_id="node_456", name="other")

    response = test_client_with_containers.get(

@ -178,7 +179,7 @@
    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    assert [item["id"] for item in payload["items"]] == [node_variable.id]
    assert [item["id"] for item in payload["items"]] == [node_variable_id]


def test_node_variable_collection_get_invalid_node_id(

@ -377,6 +378,7 @@ def test_system_variable_collection_get(
    account, tenant = create_console_account_and_tenant(db_session_with_containers)
    app = create_console_app(db_session_with_containers, tenant.id, account.id, AppMode.WORKFLOW)
    variable = _create_system_variable(db_session_with_containers, app.id, account.id)
    variable_id = variable.id

    response = test_client_with_containers.get(
        f"/console/api/apps/{app.id}/workflows/draft/system-variables",

@ -386,7 +388,7 @@
    assert response.status_code == 200
    payload = response.get_json()
    assert payload is not None
    assert [item["id"] for item in payload["items"]] == [variable.id]
    assert [item["id"] for item in payload["items"]] == [variable_id]


def test_environment_variable_collection_get(
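
The pattern in these two tests, copying `variable.id` into a local before issuing the HTTP request, avoids reading an ORM instance after the session that loaded it has committed or closed, which can raise `DetachedInstanceError` or trigger a surprise refresh mid-test. A minimal sketch of the hazard, assuming a stock SQLAlchemy 2.0 session:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Variable(Base):
    __tablename__ = "variables"
    id: Mapped[int] = mapped_column(primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    variable = Variable(id=1)
    session.add(variable)
    session.commit()
    variable_id = variable.id  # capture while the session is still live

# After the session is gone, `variable.id` would need a refresh from a
# session that no longer exists; the captured plain int is always safe.
assert variable_id == 1
```
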
@ -17,6 +17,8 @@ def test_get_oauth_url_successful(
    test_client_with_containers: FlaskClient,
) -> None:
    account, tenant = create_console_account_and_tenant(db_session_with_containers)
    tenant_id = tenant.id
    current_tenant_id = account.current_tenant_id
    provider = MagicMock()
    provider.get_authorization_url.return_value = "http://oauth.provider/auth"

@ -29,7 +31,7 @@
        headers=authenticate_console_client(test_client_with_containers, account),
    )

    assert tenant.id == account.current_tenant_id
    assert tenant_id == current_tenant_id
    assert response.status_code == 200
    assert response.get_json() == {"data": "http://oauth.provider/auth"}
    provider.get_authorization_url.assert_called_once()
@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch

import pytest
from flask import Flask
from sqlalchemy.orm import Session

from controllers.console.auth.error import (
    EmailCodeError,

@ -20,13 +21,15 @@ from controllers.console.auth.forgot_password import (
    ForgotPasswordSendEmailApi,
)
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from tests.test_containers_integration_tests.controllers.console.helpers import ensure_dify_setup


class TestForgotPasswordSendEmailApi:
    """Test cases for sending password reset emails."""

    @pytest.fixture
    def app(self, flask_app_with_containers: Flask):
    def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session):
        ensure_dify_setup(db_session_with_containers)
        return flask_app_with_containers

    @pytest.fixture

@ -139,7 +142,8 @@ class TestForgotPasswordCheckApi:
    """Test cases for verifying password reset codes."""

    @pytest.fixture
    def app(self, flask_app_with_containers: Flask):
    def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session):
        ensure_dify_setup(db_session_with_containers)
        return flask_app_with_containers

    @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit")

@ -322,7 +326,8 @@ class TestForgotPasswordResetApi:
    """Test cases for resetting password with verified token."""

    @pytest.fixture
    def app(self, flask_app_with_containers: Flask):
    def app(self, flask_app_with_containers: Flask, db_session_with_containers: Session):
        ensure_dify_setup(db_session_with_containers)
        return flask_app_with_containers

    @pytest.fixture
@ -11,7 +11,7 @@ from models.enums import ConversationFromSource, MessageFileBelongsTo
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from tests.test_containers_integration_tests.helpers import generate_valid_password


@ -119,16 +119,16 @@ class TestAgentService:
        tenant = account.current_tenant

        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "agent-chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="agent-chat",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
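
The recurring test change through the rest of this merge swaps raw `app_args` dicts for a typed `CreateAppParams` object, so malformed fields fail loudly at construction time instead of deep inside `create_app`. The real model lives in `services/app_service.py`; a plausible sketch of its shape, inferred from the fields these tests pass (field names are real, defaults are assumptions):

```python
from typing import Literal

from pydantic import BaseModel


class CreateAppParams(BaseModel):
    """Sketch inferred from test usage; defaults are assumptions."""

    name: str
    description: str = ""
    mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
    icon_type: str = "emoji"
    icon: str = ""
    icon_background: str = ""
    api_rph: int = 0
    api_rpm: int = 0
    max_active_requests: int | None = None


# Construction validates immediately; a bad mode raises pydantic.ValidationError.
params = CreateAppParams(name="Demo", mode="chat", icon="🤖")
```
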
@ -9,7 +9,7 @@ from models import Account
from models.enums import ConversationFromSource, InvokeFrom
from models.model import MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from tests.test_containers_integration_tests.helpers import generate_valid_password


@ -86,16 +86,16 @@ class TestAnnotationService:
        tenant = account.current_tenant

        # Setup app creation arguments
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        # Create app
        app_service = AppService()
@ -37,7 +37,7 @@ from services.app_dsl_service import (
    PendingData,
    _check_version_compatibility,
)
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from tests.test_containers_integration_tests.helpers import generate_valid_password

_DEFAULT_TENANT_ID = "00000000-0000-0000-0000-000000000001"

@ -147,16 +147,16 @@ class TestAppDslService:
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        return app, account
@ -1,4 +1,5 @@
import uuid
from typing import Literal
from unittest.mock import ANY, MagicMock, patch

import pytest

@ -133,7 +134,10 @@ class TestAppGenerateService:
    }

    def _create_test_app_and_account(
        self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat"
        self,
        db_session_with_containers: Session,
        mock_external_service_dependencies,
        mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = "chat",
    ):
        """
        Helper method to create a test app and account for testing.

@ -165,20 +169,20 @@ class TestAppGenerateService:
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": mode,
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
            "max_active_requests": 5,
        }
        from services.app_service import AppService, CreateAppParams

        from services.app_service import AppService
        # Create app with realistic data
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode=mode,
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
            max_active_requests=5,
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
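
Typing the helper's `mode` parameter as a `Literal` is what forced the signature expansion above: a static checker can now verify every call site against the allowed app modes. A quick sketch of the mechanic:

```python
from typing import Literal

AppMode = Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"]


def create_test_app(mode: AppMode = "chat") -> str:
    # A type checker rejects create_test_app("invalid") at analysis time;
    # at runtime, Literal adds no validation and no cost.
    return f"app created in {mode} mode"


create_test_app("workflow")  # OK
```
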
@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch

import pytest
from faker import Faker
from pydantic import ValidationError
from sqlalchemy.orm import Session

from constants.model_template import default_app_templates

@ -12,7 +13,7 @@ from services.account_service import AccountService, TenantService
from tests.test_containers_integration_tests.helpers import generate_valid_password

# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
# from services.app_service import AppService, AppListParams, CreateAppParams


class TestAppService:
@ -64,34 +65,34 @@ class TestAppService:
        tenant = account.current_tenant

        # Setup app creation arguments
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        # Import here to avoid circular dependency
        from services.app_service import AppService, CreateAppParams

        app_params = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        # Create app
        # Import here to avoid circular dependency
        from services.app_service import AppService

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        app = app_service.create_app(tenant.id, app_params, account)

        # Verify app was created correctly
        assert app.name == app_args["name"]
        assert app.description == app_args["description"]
        assert app.mode == app_args["mode"]
        assert app.icon_type == app_args["icon_type"]
        assert app.icon == app_args["icon"]
        assert app.icon_background == app_args["icon_background"]
        assert app.name == app_params.name
        assert app.description == app_params.description
        assert app.mode == app_params.mode
        assert app.icon_type == app_params.icon_type
        assert app.icon == app_params.icon
        assert app.icon_background == app_params.icon_background
        assert app.tenant_id == tenant.id
        assert app.api_rph == app_args["api_rph"]
        assert app.api_rpm == app_args["api_rpm"]
        assert app.api_rph == app_params.api_rph
        assert app.api_rpm == app_params.api_rpm
        assert app.created_by == account.id
        assert app.updated_by == account.id
        assert app.status == "normal"

@ -120,7 +121,7 @@ class TestAppService:
        tenant = account.current_tenant

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_service = AppService()
@ -129,20 +130,20 @@ class TestAppService:
        app_modes = [v.value for v in default_app_templates]

        for mode in app_modes:
            app_args = {
                "name": f"{fake.company()} {mode}",
                "description": f"Test app for {mode} mode",
                "mode": mode,
                "icon_type": "emoji",
                "icon": "🚀",
                "icon_background": "#4ECDC4",
            }
            app_params = CreateAppParams(
                name=f"{fake.company()} {mode}",
                description=f"Test app for {mode} mode",
                mode=mode,
                icon_type="emoji",
                icon="🚀",
                icon_background="#4ECDC4",
            )

            app = app_service.create_app(tenant.id, app_args, account)
            app = app_service.create_app(tenant.id, app_params, account)

            # Verify app mode was set correctly
            assert app.mode == mode
            assert app.name == app_args["name"]
            assert app.name == app_params.name
            assert app.tenant_id == tenant.id
            assert app.created_by == account.id

@ -163,20 +164,20 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_params = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🎯",
            icon_background="#45B7D1",
        )

        app_service = AppService()
        created_app = app_service.create_app(tenant.id, app_args, account)
        created_app = app_service.create_app(tenant.id, app_params, account)

        # Get app using the service - needs current_user mock
        mock_current_user = create_autospec(Account, instance=True)
@ -211,31 +212,27 @@ class TestAppService:
        tenant = account.current_tenant

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppListParams, AppService, CreateAppParams

        app_service = AppService()

        # Create multiple apps
        app_names = [fake.company() for _ in range(5)]
        for name in app_names:
            app_args = {
                "name": name,
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "📱",
                "icon_background": "#96CEB4",
            }
            app_service.create_app(tenant.id, app_args, account)
            app_params = CreateAppParams(
                name=name,
                description=fake.text(max_nb_chars=100),
                mode="chat",
                icon_type="emoji",
                icon="📱",
                icon_background="#96CEB4",
            )
            app_service.create_app(tenant.id, app_params, account)

        # Get paginated apps
        args = {
            "page": 1,
            "limit": 10,
            "mode": "chat",
        }
        params = AppListParams(page=1, limit=10, mode="chat")

        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)

        # Verify pagination results
        assert paginated_apps is not None
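
Pagination arguments get the same treatment: a dict of loosely typed keys becomes an `AppListParams` model. A plausible sketch of that model, inferred from the filters the tests below exercise (field names are real, defaults are assumptions):

```python
from pydantic import BaseModel


class AppListParams(BaseModel):
    """Sketch inferred from test usage; defaults are assumptions."""

    page: int = 1
    limit: int = 20
    mode: str | None = None
    name: str | None = None
    is_created_by_me: bool = False
    tag_ids: list[str] | None = None


# Call sites now read as one typed expression instead of an ad-hoc dict:
params = AppListParams(page=1, limit=10, mode="chat", name="Chat")
```
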
@ -267,60 +264,47 @@ class TestAppService:
        tenant = account.current_tenant

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppListParams, AppService, CreateAppParams

        app_service = AppService()

        # Create apps with different modes
        chat_app_args = {
            "name": "Chat App",
            "description": "A chat application",
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "💬",
            "icon_background": "#FF6B6B",
        }
        completion_app_args = {
            "name": "Completion App",
            "description": "A completion application",
            "mode": "completion",
            "icon_type": "emoji",
            "icon": "✍️",
            "icon_background": "#4ECDC4",
        }
        chat_app_params = CreateAppParams(
            name="Chat App",
            description="A chat application",
            mode="chat",
            icon_type="emoji",
            icon="💬",
            icon_background="#FF6B6B",
        )
        completion_app_params = CreateAppParams(
            name="Completion App",
            description="A completion application",
            mode="completion",
            icon_type="emoji",
            icon="✍️",
            icon_background="#4ECDC4",
        )

        chat_app = app_service.create_app(tenant.id, chat_app_args, account)
        completion_app = app_service.create_app(tenant.id, completion_app_args, account)
        chat_app = app_service.create_app(tenant.id, chat_app_params, account)
        completion_app = app_service.create_app(tenant.id, completion_app_params, account)

        # Test filter by mode
        chat_args = {
            "page": 1,
            "limit": 10,
            "mode": "chat",
        }
        chat_apps = app_service.get_paginate_apps(account.id, tenant.id, chat_args)
        chat_apps = app_service.get_paginate_apps(account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat"))
        assert len(chat_apps.items) == 1
        assert chat_apps.items[0].mode == "chat"

        # Test filter by name
        name_args = {
            "page": 1,
            "limit": 10,
            "mode": "chat",
            "name": "Chat",
        }
        filtered_apps = app_service.get_paginate_apps(account.id, tenant.id, name_args)
        filtered_apps = app_service.get_paginate_apps(
            account.id, tenant.id, AppListParams(page=1, limit=10, mode="chat", name="Chat")
        )
        assert len(filtered_apps.items) == 1
        assert "Chat" in filtered_apps.items[0].name

        # Test filter by created_by_me
        created_by_me_args = {
            "page": 1,
            "limit": 10,
            "mode": "completion",
            "is_created_by_me": True,
        }
        my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
        my_apps = app_service.get_paginate_apps(
            account.id, tenant.id, AppListParams(page=1, limit=10, mode="completion", is_created_by_me=True)
        )
        assert len(my_apps.items) == 1

    def test_get_paginate_apps_with_tag_filters(
@ -342,34 +326,29 @@ class TestAppService:
        tenant = account.current_tenant

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppListParams, AppService, CreateAppParams

        app_service = AppService()

        # Create an app
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🏷️",
            "icon_background": "#FFEAA7",
        }
        app = app_service.create_app(tenant.id, app_args, account)
        app_params = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🏷️",
            icon_background="#FFEAA7",
        )
        app = app_service.create_app(tenant.id, app_params, account)

        # Mock TagService to return the app ID for tag filtering
        with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
            mock_tag_service.return_value = [app.id]

            # Test with tag filter
            args = {
                "page": 1,
                "limit": 10,
                "mode": "chat",
                "tag_ids": ["tag1", "tag2"],
            }
            params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["tag1", "tag2"])

            paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
            paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)

            # Verify tag service was called
            mock_tag_service.assert_called_once_with("app", tenant.id, ["tag1", "tag2"])

@ -383,14 +362,9 @@ class TestAppService:
        with patch("services.app_service.TagService.get_target_ids_by_tag_ids") as mock_tag_service:
            mock_tag_service.return_value = []

            args = {
                "page": 1,
                "limit": 10,
                "mode": "chat",
                "tag_ids": ["nonexistent_tag"],
            }
            params = AppListParams(page=1, limit=10, mode="chat", tag_ids=["nonexistent_tag"])

            paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
            paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, params)

            # Should return None when no apps match tag filter
            assert paginated_apps is None
@ -412,20 +386,20 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_params = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🎯",
            icon_background="#45B7D1",
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        app = app_service.create_app(tenant.id, app_params, account)

        # Store original values
        original_name = app.name
@ -481,19 +455,19 @@ class TestAppService:
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_service = AppService()
        app = app_service.create_app(
            tenant.id,
            {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🎯",
                "icon_background": "#45B7D1",
            },
            CreateAppParams(
                name=fake.company(),
                description=fake.text(max_nb_chars=100),
                mode="chat",
                icon_type="emoji",
                icon="🎯",
                icon_background="#45B7D1",
            ),
            account,
        )

@ -533,19 +507,19 @@ class TestAppService:
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_service = AppService()
        app = app_service.create_app(
            tenant.id,
            {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🎯",
                "icon_background": "#45B7D1",
            },
            CreateAppParams(
                name=fake.company(),
                description=fake.text(max_nb_chars=100),
                mode="chat",
                icon_type="emoji",
                icon="🎯",
                icon_background="#45B7D1",
            ),
            account,
        )
@ -584,20 +558,20 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_params = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🎯",
            icon_background="#45B7D1",
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        app = app_service.create_app(tenant.id, app_params, account)

        # Store original name
        original_name = app.name

@ -637,20 +611,20 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🎯",
            "icon_background": "#45B7D1",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_params = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🎯",
            icon_background="#45B7D1",
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
        app = app_service.create_app(tenant.id, app_params, account)

        # Store original values
        original_icon = app.icon
@ -698,18 +672,17 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🌐",
            "icon_background": "#74B9FF",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🌐",
            icon_background="#74B9FF",
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

@ -758,18 +731,17 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🔌",
            "icon_background": "#A29BFE",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🔌",
            icon_background="#A29BFE",
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

@ -818,18 +790,17 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🔄",
            "icon_background": "#FD79A8",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🔄",
            icon_background="#FD79A8",
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -869,18 +840,17 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🗑️",
            "icon_background": "#E17055",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🗑️",
            icon_background="#E17055",
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

@ -921,18 +891,17 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🧹",
            "icon_background": "#00B894",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🧹",
            icon_background="#00B894",
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

@ -981,18 +950,17 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "📊",
            "icon_background": "#6C5CE7",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="📊",
            icon_background="#6C5CE7",
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -1020,18 +988,17 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🔗",
            "icon_background": "#FDCB6E",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🔗",
            icon_background="#FDCB6E",
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)

@ -1060,18 +1027,17 @@ class TestAppService:
        tenant = account.current_tenant

        # Create app first
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🆔",
            "icon_background": "#E84393",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🆔",
            icon_background="#E84393",
        )
        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -1107,26 +1073,20 @@ class TestAppService:
            password=generate_valid_password(fake),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Setup app creation arguments with invalid mode
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "invalid_mode",  # Invalid mode
            "icon_type": "emoji",
            "icon": "❌",
            "icon_background": "#D63031",
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import CreateAppParams

        app_service = AppService()

        # Attempt to create app with invalid mode
        with pytest.raises(ValueError, match="invalid mode value"):
            app_service.create_app(tenant.id, app_args, account)
        # Attempt to create app with invalid mode - Pydantic will reject invalid literal
        with pytest.raises(ValidationError):
            CreateAppParams(
                name=fake.company(),
                description=fake.text(max_nb_chars=100),
                mode="invalid_mode",  # type: ignore[arg-type]
                icon_type="emoji",
                icon="❌",
                icon_background="#D63031",
            )

    def test_get_apps_with_special_characters_in_name(
        self, db_session_with_containers: Session, mock_external_service_dependencies
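With `mode` typed as a `Literal`, validation moves out of `AppService.create_app` and into object construction, which is why the assertion changes from a `ValueError` to Pydantic's `ValidationError`. A minimal sketch of the new failure mode, using the hypothetical model sketched earlier:

```python
import pytest
from pydantic import ValidationError


def test_invalid_mode_fails_at_construction():
    # The bad value never reaches the service layer: Pydantic rejects it
    # while building the params object.
    with pytest.raises(ValidationError):
        CreateAppParams(name="demo", mode="invalid_mode")  # type: ignore[arg-type]
```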
@ -1152,99 +1112,103 @@ class TestAppService:
        tenant = account.current_tenant

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppListParams, AppService, CreateAppParams

        app_service = AppService()

        # Create apps with special characters in names
        app_with_percent = app_service.create_app(
            tenant.id,
            {
                "name": "App with 50% discount",
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            CreateAppParams(
                name="App with 50% discount",
                description=fake.text(max_nb_chars=100),
                mode="chat",
                icon_type="emoji",
                icon="🤖",
                icon_background="#FF6B6B",
                api_rph=100,
                api_rpm=10,
            ),
            account,
        )

        app_with_underscore = app_service.create_app(
            tenant.id,
            {
                "name": "test_data_app",
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            CreateAppParams(
                name="test_data_app",
                description=fake.text(max_nb_chars=100),
                mode="chat",
                icon_type="emoji",
                icon="🤖",
                icon_background="#FF6B6B",
                api_rph=100,
                api_rpm=10,
            ),
            account,
        )

        app_with_backslash = app_service.create_app(
            tenant.id,
            {
                "name": "path\\to\\app",
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            CreateAppParams(
                name="path\\to\\app",
                description=fake.text(max_nb_chars=100),
                mode="chat",
                icon_type="emoji",
                icon="🤖",
                icon_background="#FF6B6B",
                api_rph=100,
                api_rpm=10,
            ),
            account,
        )

        # Create app that should NOT match
        app_no_match = app_service.create_app(
            tenant.id,
            {
                "name": "100% different",
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            CreateAppParams(
                name="100% different",
                description=fake.text(max_nb_chars=100),
                mode="chat",
                icon_type="emoji",
                icon="🤖",
                icon_background="#FF6B6B",
                api_rph=100,
                api_rpm=10,
            ),
            account,
        )

        # Test 1: Search with % character
        args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        paginated_apps = app_service.get_paginate_apps(
            account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
        )
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert len(paginated_apps.items) == 1
        assert paginated_apps.items[0].name == "App with 50% discount"

        # Test 2: Search with _ character
        args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        paginated_apps = app_service.get_paginate_apps(
            account.id, tenant.id, AppListParams(name="test_data", mode="chat", page=1, limit=10)
        )
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert len(paginated_apps.items) == 1
        assert paginated_apps.items[0].name == "test_data_app"

        # Test 3: Search with \ character
        args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        paginated_apps = app_service.get_paginate_apps(
            account.id, tenant.id, AppListParams(name="path\\to\\app", mode="chat", page=1, limit=10)
        )
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert len(paginated_apps.items) == 1
        assert paginated_apps.items[0].name == "path\\to\\app"

        # Test 4: Search with % should NOT match 100% (verifies escaping works)
        args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
        paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
        paginated_apps = app_service.get_paginate_apps(
            account.id, tenant.id, AppListParams(name="50%", mode="chat", page=1, limit=10)
        )
        assert paginated_apps is not None
        assert paginated_apps.total == 1
        assert all("50%" in app.name for app in paginated_apps.items)
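Test 4 above only passes if `get_paginate_apps` escapes SQL `LIKE` wildcards, so that searching for `50%` does not also match `100% different`. A sketch of the kind of escaping this implies (the helper name and query shape are assumptions, not the actual service code):

```python
from sqlalchemy import select

from models.model import App  # assumed import path


def escape_like(term: str) -> str:
    # Escape the escape character first, then the LIKE wildcards, so that
    # "%" and "_" in user input match themselves literally.
    return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def name_filter(name: str):
    # escape="\\" tells the database which character marks an escaped wildcard.
    return select(App).where(App.name.ilike(f"%{escape_like(name)}%", escape="\\"))
```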
@ -13,9 +13,9 @@ from uuid import uuid4
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from models.enums import DataSourceType, DocumentCreatedFrom
from models.enums import DataSourceType, DocumentCreatedFrom, SegmentStatus
from services.dataset_service import SegmentService


@ -35,13 +35,13 @@ class SegmentServiceTestDataFactory:
                email=f"{uuid4()}@example.com",
                name=f"user-{uuid4()}",
                interface_language="en-US",
                status="active",
                status=AccountStatus.ACTIVE,
            )
            db_session_with_containers.add(account)
            db_session_with_containers.commit()

        if tenant is None:
            tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
            tenant = Tenant(name=f"tenant-{uuid4()}", status=TenantStatus.NORMAL)
            db_session_with_containers.add(tenant)
            db_session_with_containers.commit()

@ -103,7 +103,7 @@ class SegmentServiceTestDataFactory:
        created_by: str,
        position: int = 1,
        content: str = "Test content",
        status: str = "completed",
        status: SegmentStatus = SegmentStatus.COMPLETED,
        word_count: int = 10,
        tokens: int = 15,
    ) -> DocumentSegment:
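The hunks that follow swap raw status strings for enum members. A sketch of what `SegmentStatus` presumably looks like (the four member names appear later in this diff; using `StrEnum` keeps the stored values identical to the old strings, assuming Python 3.11+):

```python
from enum import StrEnum


class SegmentStatus(StrEnum):
    # Values mirror the raw strings previously used in the tests, so
    # comparisons against existing database rows keep working.
    WAITING = "waiting"
    INDEXING = "indexing"
    COMPLETED = "completed"
    ERROR = "error"
```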
@ -203,7 +203,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=1,
            status="completed",
            status=SegmentStatus.COMPLETED,
        )
        SegmentServiceTestDataFactory.create_segment(
            db_session_with_containers,
@ -212,7 +212,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=2,
            status="indexing",
            status=SegmentStatus.INDEXING,
        )
        SegmentServiceTestDataFactory.create_segment(
            db_session_with_containers,
@ -221,7 +221,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=3,
            status="waiting",
            status=SegmentStatus.WAITING,
        )

        # Act
@ -257,7 +257,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=1,
            status="completed",
            status=SegmentStatus.COMPLETED,
        )
        SegmentServiceTestDataFactory.create_segment(
            db_session_with_containers,
@ -266,7 +266,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=2,
            status="indexing",
            status=SegmentStatus.INDEXING,
        )

        # Act
@ -415,7 +415,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=1,
            status="completed",
            status=SegmentStatus.COMPLETED,
            content="This is important information",
        )
        SegmentServiceTestDataFactory.create_segment(
@ -425,7 +425,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=2,
            status="indexing",
            status=SegmentStatus.INDEXING,
            content="This is also important",
        )
        SegmentServiceTestDataFactory.create_segment(
@ -435,7 +435,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=3,
            status="completed",
            status=SegmentStatus.COMPLETED,
            content="This is irrelevant",
        )

@ -477,7 +477,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=1,
            status="completed",
            status=SegmentStatus.COMPLETED,
        )
        SegmentServiceTestDataFactory.create_segment(
            db_session_with_containers,
@ -486,7 +486,7 @@ class TestSegmentServiceGetSegments:
            document_id=document.id,
            created_by=owner.id,
            position=2,
            status="waiting",
            status=SegmentStatus.WAITING,
        )

        # Act
@ -16,6 +16,7 @@ from uuid import uuid4
from sqlalchemy.orm import Session

from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models import AccountStatus, CreatorUserRole, TenantStatus
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
    AppDatasetJoin,
@ -25,7 +26,7 @@ from models.dataset import (
    DatasetProcessRule,
    DatasetQuery,
)
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode
from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode, TagType
from models.model import Tag, TagBinding
from services.dataset_service import DatasetService, DocumentService

@ -42,11 +43,11 @@ class DatasetRetrievalTestDataFactory:
            email=f"{uuid4()}@example.com",
            name=f"user-{uuid4()}",
            interface_language="en-US",
            status="active",
            status=AccountStatus.ACTIVE,
        )
        tenant = Tenant(
            name=f"tenant-{uuid4()}",
            status="normal",
            status=TenantStatus.NORMAL,
        )
        db_session_with_containers.add_all([account, tenant])
        db_session_with_containers.flush()
@ -72,7 +73,7 @@ class DatasetRetrievalTestDataFactory:
            email=f"{uuid4()}@example.com",
            name=f"user-{uuid4()}",
            interface_language="en-US",
            status="active",
            status=AccountStatus.ACTIVE,
        )
        db_session_with_containers.add(account)
        db_session_with_containers.flush()
@ -130,7 +131,7 @@ class DatasetRetrievalTestDataFactory:

    @staticmethod
    def create_process_rule(
        db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict
        db_session_with_containers: Session, dataset_id: str, created_by: str, mode: ProcessRuleMode, rules: dict
    ) -> DatasetProcessRule:
        """Create a dataset process rule."""
        process_rule = DatasetProcessRule(
@ -153,7 +154,7 @@ class DatasetRetrievalTestDataFactory:
            content=content,
            source=DatasetQuerySource.APP,
            source_app_id=None,
            created_by_role="account",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=created_by,
        )
        db_session_with_containers.add(dataset_query)
@ -176,7 +177,7 @@ class DatasetRetrievalTestDataFactory:
        """Create a knowledge tag and bind it to the target dataset."""
        tag = Tag(
            tenant_id=tenant_id,
            type="knowledge",
            type=TagType.KNOWLEDGE,
            name=f"tag-{uuid4()}",
            created_by=created_by,
        )
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session

from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom
from models.model import MessageFeedback
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.errors.message import (
    FirstMessageNotExistsError,
    LastMessageNotExistsError,
@ -103,16 +103,16 @@ class TestMessageService:
        tenant = account.current_tenant

        # Setup app creation arguments
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "advanced-chat",  # Use advanced-chat mode to use mocked workflow
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="advanced-chat",  # Use advanced-chat mode to use mocked workflow
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        # Create app
        app_service = AppService()
@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
from core.ops.entities.config_entity import TracingProviderEnum
from models.model import TraceAppConfig
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.ops_service import OpsService
from tests.test_containers_integration_tests.helpers import generate_valid_password

@ -57,14 +57,14 @@ class TestOpsService:
        app_service = AppService()
        app = app_service.create_app(
            tenant.id,
            {
                "name": fake.company(),
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
            },
            CreateAppParams(
                name=fake.company(),
                description=fake.text(max_nb_chars=100),
                mode="chat",
                icon_type="emoji",
                icon="🤖",
                icon_background="#FF6B6B",
            ),
            account,
        )
        return app, account
@ -8,7 +8,7 @@ from models import App, CreatorUserRole
from models.enums import ConversationFromSource
from models.model import EndUser, Message
from models.web import SavedMessage
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.saved_message_service import SavedMessageService
from tests.test_containers_integration_tests.helpers import generate_valid_password

@ -73,16 +73,16 @@ class TestSavedMessageService:
        tenant = account.current_tenant

        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -11,7 +11,7 @@ from models.enums import ConversationFromSource
from models.model import Conversation, EndUser
from models.web import PinnedConversation
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.web_conversation_service import WebConversationService
from tests.test_containers_integration_tests.helpers import generate_valid_password

@ -77,16 +77,16 @@ class TestWebConversationService:
        tenant = account.current_tenant

        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -17,7 +17,7 @@ from models.workflow import WorkflowAppLogCreatedFrom
from services.account_service import AccountService, TenantService

# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
# from services.app_service import AppService, CreateAppParams
from services.workflow_app_service import LogView, WorkflowAppService
from tests.test_containers_integration_tests.helpers import generate_valid_password

@ -82,20 +82,20 @@ class TestWorkflowAppService:
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "workflow",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        # Create app with realistic data
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="workflow",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -146,20 +146,20 @@ class TestWorkflowAppService:
        """
        fake = Faker()

        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "workflow",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }

        # Import here to avoid circular dependency
        from services.app_service import AppService
        from services.app_service import AppService, CreateAppParams

        # Create app with realistic data
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="workflow",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -13,7 +13,7 @@ from models.model import (
)
from models.workflow import WorkflowRun
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.workflow_run_service import WorkflowRunService
from tests.test_containers_integration_tests.helpers import generate_valid_password

@ -79,16 +79,16 @@ class TestWorkflowRunService:
        tenant = account.current_tenant

        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="chat",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -535,13 +535,13 @@ class TestWorkflowRunService:
        tenant = account.current_tenant

        # Create app
        app_args = {
            "name": "Test App",
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🚀",
            "icon_background": "#4ECDC4",
        }
        app_args = CreateAppParams(
            name="Test App",
            mode="chat",
            icon_type="emoji",
            icon="🚀",
            icon_background="#4ECDC4",
        )
        app = app_service.create_app(tenant.id, app_args, account)

        # Create workflow run without node executions
@ -586,13 +586,13 @@ class TestWorkflowRunService:
        tenant = account.current_tenant

        # Create app
        app_args = {
            "name": "Test App",
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🚀",
            "icon_background": "#4ECDC4",
        }
        app_args = CreateAppParams(
            name="Test App",
            mode="chat",
            icon_type="emoji",
            icon="🚀",
            icon_background="#4ECDC4",
        )
        app = app_service.create_app(tenant.id, app_args, account)

        # Use invalid workflow run ID
@ -637,13 +637,13 @@ class TestWorkflowRunService:
        tenant = account.current_tenant

        # Create app
        app_args = {
            "name": "Test App",
            "mode": "chat",
            "icon_type": "emoji",
            "icon": "🚀",
            "icon_background": "#4ECDC4",
        }
        app_args = CreateAppParams(
            name="Test App",
            mode="chat",
            icon_type="emoji",
            icon="🚀",
            icon_background="#4ECDC4",
        )
        app = app_service.create_app(tenant.id, app_args, account)

        # Create workflow run
@ -11,7 +11,7 @@ from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from models.tools import WorkflowToolProvider
from models.workflow import Workflow as WorkflowModel
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.app_service import AppService, CreateAppParams
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
from tests.test_containers_integration_tests.helpers import generate_valid_password

@ -94,16 +94,16 @@ class TestWorkflowToolManageService:
        tenant = account.current_tenant

        # Create app with realistic data
        app_args = {
            "name": fake.company(),
            "description": fake.text(max_nb_chars=100),
            "mode": "workflow",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#FF6B6B",
            "api_rph": 100,
            "api_rpm": 10,
        }
        app_args = CreateAppParams(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            mode="workflow",
            icon_type="emoji",
            icon="🤖",
            icon_background="#FF6B6B",
            api_rph=100,
            api_rpm=10,
        )

        app_service = AppService()
        app = app_service.create_app(tenant.id, app_args, account)
@ -128,7 +128,6 @@ class TestAddDocumentToIndexTask:

        for i in range(3):
            segment = DocumentSegment(
                id=fake.uuid4(),
                tenant_id=document.tenant_id,
                dataset_id=dataset.id,
                document_id=document.id,
@ -451,7 +450,6 @@ class TestAddDocumentToIndexTask:
        segments = []
        for i in range(3):
            segment = DocumentSegment(
                id=fake.uuid4(),
                tenant_id=document.tenant_id,
                dataset_id=dataset.id,
                document_id=document.id,
@ -630,7 +628,6 @@ class TestAddDocumentToIndexTask:

        # Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
        segment1 = DocumentSegment(
            id=fake.uuid4(),
            tenant_id=document.tenant_id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -650,7 +647,6 @@ class TestAddDocumentToIndexTask:
        # Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED)
        # Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED
        segment2 = DocumentSegment(
            id=fake.uuid4(),
            tenant_id=document.tenant_id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -669,7 +665,6 @@ class TestAddDocumentToIndexTask:

        # Segment 3: Should NOT be processed (enabled=False, status="processing")
        segment3 = DocumentSegment(
            id=fake.uuid4(),
            tenant_id=document.tenant_id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -688,7 +683,6 @@ class TestAddDocumentToIndexTask:

        # Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED)
        segment4 = DocumentSegment(
            id=fake.uuid4(),
            tenant_id=document.tenant_id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -177,7 +177,6 @@ class TestBatchCleanDocumentTask:
        fake = Faker()

        segment = DocumentSegment(
            id=str(uuid.uuid4()),
            tenant_id=account.current_tenant.id,
            dataset_id=document.dataset_id,
            document_id=document.id,
@ -290,10 +289,9 @@ class TestBatchCleanDocumentTask:
        account = self._create_test_account(db_session_with_containers)
        dataset = self._create_test_dataset(db_session_with_containers, account)
        document = self._create_test_document(db_session_with_containers, dataset, account)

        assert account.current_tenant
        # Create segment with simple content (no image references)
        segment = DocumentSegment(
            id=str(uuid.uuid4()),
            tenant_id=account.current_tenant.id,
            dataset_id=document.dataset_id,
            document_id=document.id,
@ -692,9 +690,9 @@ class TestBatchCleanDocumentTask:

        # Create multiple segments for the document
        segments = []
        assert account.current_tenant
        for i in range(3):
            segment = DocumentSegment(
                id=str(uuid.uuid4()),
                tenant_id=account.current_tenant.id,
                dataset_id=document.dataset_id,
                document_id=document.id,
@ -220,7 +220,6 @@ class TestCleanDatasetTask:
            DocumentSegment: Created document segment instance
        """
        segment = DocumentSegment(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -232,8 +231,6 @@ class TestCleanDatasetTask:
            status=SegmentStatus.COMPLETED,
            index_node_id=str(uuid.uuid4()),
            index_node_hash="test_hash",
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )

        db_session_with_containers.add(segment)
@ -614,7 +611,6 @@ class TestCleanDatasetTask:
        """

        segment = DocumentSegment(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -626,8 +622,6 @@ class TestCleanDatasetTask:
            status=SegmentStatus.COMPLETED,
            index_node_id=str(uuid.uuid4()),
            index_node_hash="test_hash",
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )

        db_session_with_containers.add(segment)
@ -729,8 +723,6 @@ class TestCleanDatasetTask:
                type=DatasetMetadataType.STRING,
                created_by=account.id,
            )
            metadata.id = str(uuid.uuid4())
            metadata.created_at = datetime.now()
            metadata_items.append(metadata)

            # Create binding for each metadata item
@ -741,8 +733,6 @@ class TestCleanDatasetTask:
                document_id=documents[i % len(documents)].id,
                created_by=account.id,
            )
            binding.id = str(uuid.uuid4())
            binding.created_at = datetime.now()
            bindings.append(binding)

        db_session_with_containers.add_all(metadata_items)
@ -946,7 +936,6 @@ class TestCleanDatasetTask:
        long_content = "Very long content " * 100  # Long content within reasonable limits
        segment_content = f"Segment with special chars: {special_content}\n{long_content}"
        segment = DocumentSegment(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -958,8 +947,6 @@ class TestCleanDatasetTask:
            status=SegmentStatus.COMPLETED,
            index_node_id=str(uuid.uuid4()),
            index_node_hash="test_hash_" + "x" * 50,  # Long hash within limits
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )
        db_session_with_containers.add(segment)
        db_session_with_containers.commit()
@ -132,11 +132,10 @@ class TestCleanNotionDocumentTask:
            db_session_with_containers.add(document)
            db_session_with_containers.flush()
            document_ids.append(document.id)

            assert tenant
            # Create segments for each document
            for j in range(2):
                segment = DocumentSegment(
                    id=str(uuid.uuid4()),
                    tenant_id=tenant.id,
                    dataset_id=dataset.id,
                    document_id=document.id,
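The `assert tenant` lines added throughout this file (and the `assert account.current_tenant` lines earlier) are type-narrowing guards: `tenant` is optional at these points, and the assert both documents the test invariant and lets mypy/pyright treat it as non-None below. An illustrative reduction, with placeholder names:

```python
from typing import Optional


class Tenant:  # placeholder standing in for models.account.Tenant
    id: str


def tenant_id_of(tenant: Optional[Tenant]) -> str:
    # Without the assert, `tenant.id` is a type error because tenant may be
    # None; with it, the checker narrows Optional[Tenant] to Tenant.
    assert tenant
    return tenant.id
```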
@ -297,10 +296,9 @@ class TestCleanNotionDocumentTask:
        )
        db_session_with_containers.add(document)
        db_session_with_containers.flush()

        assert tenant
        # Create test segment
        segment = DocumentSegment(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -379,12 +377,11 @@ class TestCleanNotionDocumentTask:
        )
        db_session_with_containers.add(document)
        db_session_with_containers.flush()

        assert tenant
        # Create segments without index_node_ids
        segments = []
        for i in range(3):
            segment = DocumentSegment(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                document_id=document.id,
@ -468,11 +465,10 @@ class TestCleanNotionDocumentTask:
            db_session_with_containers.add(document)
            db_session_with_containers.flush()
            documents.append(document)

            assert tenant
            # Create segments for each document
            for j in range(2):
                segment = DocumentSegment(
                    id=str(uuid.uuid4()),
                    tenant_id=tenant.id,
                    dataset_id=dataset.id,
                    document_id=document.id,
@ -569,10 +565,9 @@ class TestCleanNotionDocumentTask:
        segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR]
        segments = []
        index_node_ids = []

        assert tenant
        for i, status in enumerate(segment_statuses):
            segment = DocumentSegment(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                document_id=document.id,
@ -665,10 +660,9 @@ class TestCleanNotionDocumentTask:
        )
        db_session_with_containers.add(document)
        db_session_with_containers.flush()

        assert tenant
        # Create segment
        segment = DocumentSegment(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            document_id=document.id,
@ -765,12 +759,11 @@ class TestCleanNotionDocumentTask:
            db_session_with_containers.add(document)
            db_session_with_containers.flush()
            documents.append(document)

            assert tenant
            # Create multiple segments for each document
            num_segments_per_doc = 5
            for j in range(num_segments_per_doc):
                segment = DocumentSegment(
                    id=str(uuid.uuid4()),
                    tenant_id=tenant.id,
                    dataset_id=dataset.id,
                    document_id=document.id,
@ -875,7 +868,6 @@ class TestCleanNotionDocumentTask:
            # Create segments for each document
            for j in range(3):
                segment = DocumentSegment(
                    id=str(uuid.uuid4()),
                    tenant_id=account.current_tenant.id,
                    dataset_id=dataset.id,
                    document_id=document.id,
@ -984,11 +976,10 @@ class TestCleanNotionDocumentTask:
            db_session_with_containers.add(document)
            db_session_with_containers.flush()
            documents.append(document)

            assert tenant
            # Create segments for each document
            for j in range(2):
                segment = DocumentSegment(
                    id=str(uuid.uuid4()),
                    tenant_id=tenant.id,
                    dataset_id=dataset.id,
                    document_id=document.id,
@ -1093,10 +1084,9 @@ class TestCleanNotionDocumentTask:
        # Create segments with metadata
        segments = []
        index_node_ids = []

        assert tenant
        for i in range(3):
            segment = DocumentSegment(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                document_id=document.id,