Compare commits

..

7 Commits

Author SHA1 Message Date
-LAN- 57dc7e0b2c
chore: bump version to 1.10.1-fix.1 (#29176)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-12-09 11:05:56 +08:00
Wu Tianwei cbcdecf9a8 refactor: update useNodes import to use reactflow across multiple components (#29195) 2025-12-05 16:47:03 +08:00
kenwoodjw 7b514b1147
fix: bump pyarrow to 17.0.0, werkzeug to 3.1.4, urllib3 to 2.5.0 (#29089)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2025-12-05 12:08:25 +08:00
NFish 37d4371901
chore: upgrade React to 19.2.1,fix cve-2025-55182 (#29121)
Co-authored-by: zhsama <torvalds@linux.do>
2025-12-05 12:06:39 +08:00
NFish 0eb5b8a4eb
chore: update Next.js dev dependencies to 15.5.7 (#29120) 2025-12-05 12:05:54 +08:00
dependabot[bot] 7ee57f34ce
chore(deps): bump next from 15.5.6 to 15.5.7 in /web (#29105)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-05 12:05:00 +08:00
zyssyz123 0f99a7e3f1 Fix/app list compatible (#29123) 2025-12-04 14:56:15 +08:00
5699 changed files with 245333 additions and 438317 deletions

View File

@ -1,8 +0,0 @@
{
"enabledPlugins": {
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
}
}

View File

@ -0,0 +1,19 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -1,483 +0,0 @@
---
name: component-refactoring
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
---
# Dify Component Refactoring Skill
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
## Quick Reference
### Commands (run from `web/`)
Use paths relative to `web/` (e.g., `app/components/...`).
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
```bash
cd web
# Generate refactoring prompt
pnpm refactor-component <path>
# Output refactoring analysis as JSON
pnpm refactor-component <path> --json
# Generate testing prompt (after refactoring)
pnpm analyze-component <path>
# Output testing analysis as JSON
pnpm analyze-component <path> --json
```
### Complexity Analysis
```bash
# Analyze component complexity
pnpm analyze-component <path> --json
# Key metrics to check:
# - complexity: normalized score 0-100 (target < 50)
# - maxComplexity: highest single function complexity
# - lineCount: total lines (target < 300)
```
### Complexity Score Interpretation
| Score | Level | Action |
|-------|-------|--------|
| 0-25 | 🟢 Simple | Ready for testing |
| 26-50 | 🟡 Medium | Consider minor refactoring |
| 51-75 | 🟠 Complex | **Refactor before testing** |
| 76-100 | 🔴 Very Complex | **Must refactor** |
## Core Refactoring Patterns
### Pattern 1: Extract Custom Hooks
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
```typescript
// ❌ Before: Complex state logic in component
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// 50+ lines of state management logic...
return <div>...</div>
}
// ✅ After: Extract to custom hook
// hooks/use-model-config.ts
export const useModelConfig = (appId: string) => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// Related state management logic here
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
}
// Component becomes cleaner
const Configuration: FC = () => {
const { modelConfig, setModelConfig } = useModelConfig(appId)
return <div>...</div>
}
```
**Dify Examples**:
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
- `web/app/components/app/configuration/debug/hooks.tsx`
- `web/app/components/workflow/hooks/use-workflow.ts`
### Pattern 2: Extract Sub-Components
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
```typescript
// ❌ Before: Monolithic JSX with multiple sections
const AppInfo = () => {
return (
<div>
{/* 100 lines of header UI */}
{/* 100 lines of operations UI */}
{/* 100 lines of modals */}
</div>
)
}
// ✅ After: Split into focused components
// app-info/
// ├── index.tsx (orchestration only)
// ├── app-header.tsx (header UI)
// ├── app-operations.tsx (operations UI)
// └── app-modals.tsx (modal management)
const AppInfo = () => {
const { showModal, setShowModal } = useAppInfoModals()
return (
<div>
<AppHeader appDetail={appDetail} />
<AppOperations onAction={handleAction} />
<AppModals show={showModal} onClose={() => setShowModal(null)} />
</div>
)
}
```
**Dify Examples**:
- `web/app/components/app/configuration/` directory structure
- `web/app/components/workflow/nodes/` per-node organization
### Pattern 3: Simplify Conditional Logic
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
```typescript
// ❌ Before: Deeply nested conditionals
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh />
case LanguagesSupported[7]:
return <TemplateChatJa />
default:
return <TemplateChatEn />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
// Another 15 lines...
}
// More conditions...
}, [appDetail, locale])
// ✅ After: Use lookup tables + early returns
const TEMPLATE_MAP = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
// ...
},
}
const Template = useMemo(() => {
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
if (!modeTemplates) return null
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
### Pattern 4: Extract API/Data Logic
**When**: Component directly handles API calls, data transformation, or complex async operations.
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks.
```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
const [basicAppConfig, setBasicAppConfig] = useState({})
useEffect(() => {
if (isBasicApp && appId) {
(async () => {
const res = await fetchAppDetail({ url: '/apps', id: appId })
setBasicAppConfig(res?.model_config || {})
})()
}
}, [appId, isBasicApp])
// More API-related logic...
}
// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
const NAME_SPACE = 'appConfig'
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
return useQuery({
enabled: isBasicApp && !!appId,
queryKey: [NAME_SPACE, 'detail', appId],
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || {},
})
}
// Component becomes cleaner
const MCPServiceCard = () => {
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
// UI only
}
```
**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx`
**Dify Examples**:
- `web/service/use-workflow.ts`
- `web/service/use-common.ts`
- `web/service/knowledge/use-dataset.ts`
- `web/service/knowledge/use-document.ts`
### Pattern 5: Extract Modal/Dialog Management
**When**: Component manages multiple modals with complex open/close states.
**Dify Convention**: Modals should be extracted with their state management.
```typescript
// ❌ Before: Multiple modal states in component
const AppInfo = () => {
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
const [showSwitchModal, setShowSwitchModal] = useState(false)
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
// 5+ more modal states...
}
// ✅ After: Extract to modal management hook
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
const useAppInfoModals = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
const closeModal = useCallback(() => setActiveModal(null), [])
return {
activeModal,
openModal,
closeModal,
isOpen: (type: ModalType) => activeModal === type,
}
}
```
### Pattern 6: Extract Form Logic
**When**: Complex form validation, submission handling, or field transformation.
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
```typescript
// ✅ Use existing form infrastructure
import { useAppForm } from '@/app/components/base/form'
const ConfigForm = () => {
const form = useAppForm({
defaultValues: { name: '', description: '' },
onSubmit: handleSubmit,
})
return <form.Provider>...</form.Provider>
}
```
## Dify-Specific Refactoring Guidelines
### 1. Context Provider Extraction
**When**: Component provides complex context values with multiple states.
```typescript
// ❌ Before: Large context value object
const value = {
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
// 50+ more properties...
}
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
// ✅ After: Split into domain-specific contexts
<ModelConfigProvider value={modelConfigValue}>
<DatasetConfigProvider value={datasetConfigValue}>
<UIConfigProvider value={uiConfigValue}>
{children}
</UIConfigProvider>
</DatasetConfigProvider>
</ModelConfigProvider>
```
**Dify Reference**: `web/context/` directory structure
### 2. Workflow Node Components
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
**Conventions**:
- Keep node logic in `use-interactions.ts`
- Extract panel UI to separate files
- Use `_base` components for common patterns
```
nodes/<node-type>/
├── index.tsx # Node registration
├── node.tsx # Node visual component
├── panel.tsx # Configuration panel
├── use-interactions.ts # Node-specific hooks
└── types.ts # Type definitions
```
### 3. Configuration Components
**When**: Refactoring app configuration components.
**Conventions**:
- Separate config sections into subdirectories
- Use existing patterns from `web/app/components/app/configuration/`
- Keep feature toggles in dedicated components
### 4. Tool/Plugin Components
**When**: Refactoring tool-related components (`web/app/components/tools/`).
**Conventions**:
- Follow existing modal patterns
- Use service hooks from `web/service/use-tools.ts`
- Keep provider-specific logic isolated
## Refactoring Workflow
### Step 1: Generate Refactoring Prompt
```bash
pnpm refactor-component <path>
```
This command will:
- Analyze component complexity and features
- Identify specific refactoring actions needed
- Generate a prompt for AI assistant (auto-copied to clipboard on macOS)
- Provide detailed requirements based on detected patterns
### Step 2: Analyze Details
```bash
pnpm analyze-component <path> --json
```
Identify:
- Total complexity score
- Max function complexity
- Line count
- Features detected (state, effects, API, etc.)
### Step 3: Plan
Create a refactoring plan based on detected features:
| Detected Feature | Refactoring Action |
|------------------|-------------------|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
| `hasAPI: true` | Extract data/service hook |
| `hasEvents: true` (many) | Extract event handlers |
| `lineCount > 300` | Split into sub-components |
| `maxComplexity > 50` | Simplify conditional logic |
### Step 4: Execute Incrementally
1. **Extract one piece at a time**
2. **Run lint, type-check, and tests after each extraction**
3. **Verify functionality before next step**
```
For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check:tsgo │
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │
│ FAIL? → Fix before continuing │
└────────────────────────────────────────┘
```
### Step 5: Verify
After refactoring:
```bash
# Re-run refactor command to verify improvements
pnpm refactor-component <path>
# If complexity < 25 and lines < 200, you'll see:
# ✅ COMPONENT IS WELL-STRUCTURED
# For detailed metrics:
pnpm analyze-component <path> --json
# Target metrics:
# - complexity < 50
# - lineCount < 300
# - maxComplexity < 30
```
## Common Mistakes to Avoid
### ❌ Over-Engineering
```typescript
// ❌ Too many tiny hooks
const useButtonText = () => useState('Click')
const useButtonDisabled = () => useState(false)
const useButtonLoading = () => useState(false)
// ✅ Cohesive hook with related state
const useButtonState = () => {
const [text, setText] = useState('Click')
const [disabled, setDisabled] = useState(false)
const [loading, setLoading] = useState(false)
return { text, setText, disabled, setDisabled, loading, setLoading }
}
```
### ❌ Breaking Existing Patterns
- Follow existing directory structures
- Maintain naming conventions
- Preserve export patterns for compatibility
### ❌ Premature Abstraction
- Only extract when there's clear complexity benefit
- Don't create abstractions for single-use code
- Keep refactored code in the same domain area
## References
### Dify Codebase Examples
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
- **Component splitting**: `web/app/components/app/configuration/`
- **Service hooks**: `web/service/use-*.ts`
- **Workflow patterns**: `web/app/components/workflow/hooks/`
- **Form patterns**: `web/app/components/base/form/`
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification

View File

@ -1,493 +0,0 @@
# Complexity Reduction Patterns
This document provides patterns for reducing cognitive complexity in Dify React components.
## Understanding Complexity
### SonarJS Cognitive Complexity
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
- **Total Complexity**: Sum of all functions' complexity in the file
- **Max Complexity**: Highest single function complexity
### What Increases Complexity
| Pattern | Complexity Impact |
|---------|-------------------|
| `if/else` | +1 per branch |
| Nested conditions | +1 per nesting level |
| `switch/case` | +1 per case |
| `for/while/do` | +1 per loop |
| `&&`/`||` chains | +1 per operator |
| Nested callbacks | +1 per nesting level |
| `try/catch` | +1 per catch |
| Ternary expressions | +1 per nesting |
## Pattern 1: Replace Conditionals with Lookup Tables
**Before** (complexity: ~15):
```typescript
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateChatJa appDetail={appDetail} />
default:
return <TemplateChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateAdvancedChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateAdvancedChatJa appDetail={appDetail} />
default:
return <TemplateAdvancedChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
// Similar pattern...
}
return null
}, [appDetail, locale])
```
**After** (complexity: ~3):
```typescript
// Define lookup table outside component
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
default: TemplateAdvancedChatEn,
},
[AppModeEnum.WORKFLOW]: {
[LanguagesSupported[1]]: TemplateWorkflowZh,
[LanguagesSupported[7]]: TemplateWorkflowJa,
default: TemplateWorkflowEn,
},
// ...
}
// Clean component logic
const Template = useMemo(() => {
if (!appDetail?.mode) return null
const templates = TEMPLATE_MAP[appDetail.mode]
if (!templates) return null
const TemplateComponent = templates[locale] ?? templates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
## Pattern 2: Use Early Returns
**Before** (complexity: ~10):
```typescript
const handleSubmit = () => {
if (isValid) {
if (hasChanges) {
if (isConnected) {
submitData()
} else {
showConnectionError()
}
} else {
showNoChangesMessage()
}
} else {
showValidationError()
}
}
```
**After** (complexity: ~4):
```typescript
const handleSubmit = () => {
if (!isValid) {
showValidationError()
return
}
if (!hasChanges) {
showNoChangesMessage()
return
}
if (!isConnected) {
showConnectionError()
return
}
submitData()
}
```
## Pattern 3: Extract Complex Conditions
**Before** (complexity: high):
```typescript
const canPublish = (() => {
if (mode !== AppModeEnum.COMPLETION) {
if (!isAdvancedMode)
return true
if (modelModeType === ModelModeType.completion) {
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
return false
return true
}
return true
}
return !promptEmpty
})()
```
**After** (complexity: lower):
```typescript
// Extract to named functions
const canPublishInCompletionMode = () => !promptEmpty
const canPublishInChatMode = () => {
if (!isAdvancedMode) return true
if (modelModeType !== ModelModeType.completion) return true
return hasSetBlockStatus.history && hasSetBlockStatus.query
}
// Clean main logic
const canPublish = mode === AppModeEnum.COMPLETION
? canPublishInCompletionMode()
: canPublishInChatMode()
```
## Pattern 4: Replace Chained Ternaries
**Before** (complexity: ~5):
```typescript
const statusText = serverActivated
? t('status.running')
: serverPublished
? t('status.inactive')
: appUnpublished
? t('status.unpublished')
: t('status.notConfigured')
```
**After** (complexity: ~2):
```typescript
const getStatusText = () => {
if (serverActivated) return t('status.running')
if (serverPublished) return t('status.inactive')
if (appUnpublished) return t('status.unpublished')
return t('status.notConfigured')
}
const statusText = getStatusText()
```
Or use lookup:
```typescript
const STATUS_TEXT_MAP = {
running: 'status.running',
inactive: 'status.inactive',
unpublished: 'status.unpublished',
notConfigured: 'status.notConfigured',
} as const
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
if (serverActivated) return 'running'
if (serverPublished) return 'inactive'
if (appUnpublished) return 'unpublished'
return 'notConfigured'
}
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
```
## Pattern 5: Flatten Nested Loops
**Before** (complexity: high):
```typescript
const processData = (items: Item[]) => {
const results: ProcessedItem[] = []
for (const item of items) {
if (item.isValid) {
for (const child of item.children) {
if (child.isActive) {
for (const prop of child.properties) {
if (prop.value !== null) {
results.push({
itemId: item.id,
childId: child.id,
propValue: prop.value,
})
}
}
}
}
}
}
return results
}
```
**After** (complexity: lower):
```typescript
// Use functional approach
const processData = (items: Item[]) => {
return items
.filter(item => item.isValid)
.flatMap(item =>
item.children
.filter(child => child.isActive)
.flatMap(child =>
child.properties
.filter(prop => prop.value !== null)
.map(prop => ({
itemId: item.id,
childId: child.id,
propValue: prop.value,
}))
)
)
}
```
## Pattern 6: Extract Event Handler Logic
**Before** (complexity: high in component):
```typescript
const Component = () => {
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
let newDatasets = data
if (data.find(item => !item.name)) {
const newSelected = produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const newItem = dataSets.find(i => i.id === item.id)
if (newItem)
draft[index] = newItem
}
})
})
setDataSets(newSelected)
newDatasets = newSelected
}
else {
setDataSets(data)
}
hideSelectDataSet()
// 40 more lines of logic...
}
return <div>...</div>
}
```
**After** (complexity: lower):
```typescript
// Extract to hook or utility
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
const normalizeSelection = (data: DataSet[]) => {
const hasUnloadedItem = data.some(item => !item.name)
if (!hasUnloadedItem) return data
return produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const existing = dataSets.find(i => i.id === item.id)
if (existing) draft[index] = existing
}
})
})
}
const hasSelectionChanged = (newData: DataSet[]) => {
return !isEqual(
newData.map(item => item.id),
dataSets.map(item => item.id)
)
}
return { normalizeSelection, hasSelectionChanged }
}
// Component becomes cleaner
const Component = () => {
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
const handleSelect = (data: DataSet[]) => {
if (!hasSelectionChanged(data)) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
const normalized = normalizeSelection(data)
setDataSets(normalized)
hideSelectDataSet()
}
return <div>...</div>
}
```
## Pattern 7: Reduce Boolean Logic Complexity
**Before** (complexity: ~8):
```typescript
const toggleDisabled = hasInsufficientPermissions
|| appUnpublished
|| missingStartNode
|| triggerModeDisabled
|| (isAdvancedApp && !currentWorkflow?.graph)
|| (isBasicApp && !basicAppConfig.updated_at)
```
**After** (complexity: ~3):
```typescript
// Extract meaningful boolean functions
const isAppReady = () => {
if (isAdvancedApp) return !!currentWorkflow?.graph
return !!basicAppConfig.updated_at
}
const hasRequiredPermissions = () => {
return isCurrentWorkspaceEditor && !hasInsufficientPermissions
}
const canToggle = () => {
if (!hasRequiredPermissions()) return false
if (!isAppReady()) return false
if (missingStartNode) return false
if (triggerModeDisabled) return false
return true
}
const toggleDisabled = !canToggle()
```
## Pattern 8: Simplify useMemo/useCallback Dependencies
**Before** (complexity: multiple recalculations):
```typescript
const payload = useMemo(() => {
let parameters: Parameter[] = []
let outputParameters: OutputParameter[] = []
if (!published) {
parameters = (inputs || []).map((item) => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
outputParameters = (outputs || []).map((item) => ({
name: item.variable,
description: '',
type: item.value_type,
}))
}
else if (detail && detail.tool) {
parameters = (inputs || []).map((item) => ({
// Complex transformation...
}))
outputParameters = (outputs || []).map((item) => ({
// Complex transformation...
}))
}
return {
icon: detail?.icon || icon,
label: detail?.label || name,
// ...more fields
}
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
```
**After** (complexity: separated concerns):
```typescript
// Separate transformations
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
return useMemo(() => {
if (!published) {
return inputs.map(item => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
}
if (!detail?.tool) return []
return inputs.map(item => ({
name: item.variable,
required: item.required,
type: item.type === 'paragraph' ? 'string' : item.type,
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
}))
}, [inputs, detail, published])
}
// Component uses hook
const parameters = useParameterTransform(inputs, detail, published)
const outputParameters = useOutputTransform(outputs, detail, published)
const payload = useMemo(() => ({
icon: detail?.icon || icon,
label: detail?.label || name,
parameters,
outputParameters,
// ...
}), [detail, icon, name, parameters, outputParameters])
```
## Target Metrics After Refactoring
| Metric | Target |
|--------|--------|
| Total Complexity | < 50 |
| Max Function Complexity | < 30 |
| Function Length | < 30 lines |
| Nesting Depth | ≤ 3 levels |
| Conditional Chains | ≤ 3 conditions |

View File

@ -1,477 +0,0 @@
# Component Splitting Patterns
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
## When to Split Components
Split a component when you identify:
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
1. **Repeated patterns** - Similar UI structures used multiple times
1. **300+ lines** - Component exceeds manageable size
1. **Modal clusters** - Multiple modals rendered in one component
## Splitting Strategies
### Strategy 1: Section-Based Splitting
Identify visual sections and extract each as a component.
```typescript
// ❌ Before: Monolithic component (500+ lines)
const ConfigurationPage = () => {
return (
<div>
{/* Header Section - 50 lines */}
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher ... />
</div>
</div>
{/* Config Section - 200 lines */}
<div className="config">
<Config />
</div>
{/* Debug Section - 150 lines */}
<div className="debug">
<Debug ... />
</div>
{/* Modals Section - 100 lines */}
{showSelectDataSet && <SelectDataSet ... />}
{showHistoryModal && <EditHistoryModal ... />}
{showUseGPT4Confirm && <Confirm ... />}
</div>
)
}
// ✅ After: Split into focused components
// configuration/
// ├── index.tsx (orchestration)
// ├── configuration-header.tsx
// ├── configuration-content.tsx
// ├── configuration-debug.tsx
// └── configuration-modals.tsx
// configuration-header.tsx
interface ConfigurationHeaderProps {
isAdvancedMode: boolean
onPublish: () => void
}
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
isAdvancedMode,
onPublish,
}) => {
const { t } = useTranslation()
return (
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher onPublish={onPublish} />
</div>
</div>
)
}
// index.tsx (orchestration only)
const ConfigurationPage = () => {
const { modelConfig, setModelConfig } = useModelConfig()
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
<ConfigurationHeader
isAdvancedMode={isAdvancedMode}
onPublish={handlePublish}
/>
<ConfigurationContent
modelConfig={modelConfig}
onConfigChange={setModelConfig}
/>
{!isMobile && (
<ConfigurationDebug
inputs={inputs}
onSetting={handleSetting}
/>
)}
<ConfigurationModals
activeModal={activeModal}
onClose={closeModal}
/>
</div>
)
}
```
### Strategy 2: Conditional Block Extraction
Extract large conditional rendering blocks.
```typescript
// ❌ Before: Large conditional blocks
const AppInfo = () => {
return (
<div>
{expand ? (
<div className="expanded">
{/* 100 lines of expanded view */}
</div>
) : (
<div className="collapsed">
{/* 50 lines of collapsed view */}
</div>
)}
</div>
)
}
// ✅ After: Separate view components
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="expanded">
{/* Clean, focused expanded view */}
</div>
)
}
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="collapsed">
{/* Clean, focused collapsed view */}
</div>
)
}
const AppInfo = () => {
return (
<div>
{expand
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
}
</div>
)
}
```
### Strategy 3: Modal Extraction
Extract modals with their trigger logic.
```typescript
// ❌ Before: Multiple modals in one component
const AppInfo = () => {
const [showEdit, setShowEdit] = useState(false)
const [showDuplicate, setShowDuplicate] = useState(false)
const [showDelete, setShowDelete] = useState(false)
const [showSwitch, setShowSwitch] = useState(false)
const onEdit = async (data) => { /* 20 lines */ }
const onDuplicate = async (data) => { /* 20 lines */ }
const onDelete = async () => { /* 15 lines */ }
return (
<div>
{/* Main content */}
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
{showSwitch && <SwitchModal ... />}
</div>
)
}
// ✅ After: Modal manager component
// app-info-modals.tsx
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
interface AppInfoModalsProps {
appDetail: AppDetail
activeModal: ModalType
onClose: () => void
onSuccess: () => void
}
const AppInfoModals: FC<AppInfoModalsProps> = ({
appDetail,
activeModal,
onClose,
onSuccess,
}) => {
const handleEdit = async (data) => { /* logic */ }
const handleDuplicate = async (data) => { /* logic */ }
const handleDelete = async () => { /* logic */ }
return (
<>
{activeModal === 'edit' && (
<EditModal
appDetail={appDetail}
onConfirm={handleEdit}
onClose={onClose}
/>
)}
{activeModal === 'duplicate' && (
<DuplicateModal
appDetail={appDetail}
onConfirm={handleDuplicate}
onClose={onClose}
/>
)}
{activeModal === 'delete' && (
<DeleteConfirm
onConfirm={handleDelete}
onClose={onClose}
/>
)}
{activeModal === 'switch' && (
<SwitchModal
appDetail={appDetail}
onClose={onClose}
/>
)}
</>
)
}
// Parent component
const AppInfo = () => {
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
{/* Main content with openModal triggers */}
<Button onClick={() => openModal('edit')}>Edit</Button>
<AppInfoModals
appDetail={appDetail}
activeModal={activeModal}
onClose={closeModal}
onSuccess={handleSuccess}
/>
</div>
)
}
```
### Strategy 4: List Item Extraction
Extract repeated item rendering.
```typescript
// ❌ Before: Inline item rendering
const OperationsList = () => {
return (
<div>
{operations.map(op => (
<div key={op.id} className="operation-item">
<span className="icon">{op.icon}</span>
<span className="title">{op.title}</span>
<span className="description">{op.description}</span>
<button onClick={() => op.onClick()}>
{op.actionLabel}
</button>
{op.badge && <Badge>{op.badge}</Badge>}
{/* More complex rendering... */}
</div>
))}
</div>
)
}
// ✅ After: Extracted item component
interface OperationItemProps {
operation: Operation
onAction: (id: string) => void
}
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
return (
<div className="operation-item">
<span className="icon">{operation.icon}</span>
<span className="title">{operation.title}</span>
<span className="description">{operation.description}</span>
<button onClick={() => onAction(operation.id)}>
{operation.actionLabel}
</button>
{operation.badge && <Badge>{operation.badge}</Badge>}
</div>
)
}
const OperationsList = () => {
const handleAction = useCallback((id: string) => {
const op = operations.find(o => o.id === id)
op?.onClick()
}, [operations])
return (
<div>
{operations.map(op => (
<OperationItem
key={op.id}
operation={op}
onAction={handleAction}
/>
))}
</div>
)
}
```
## Directory Structure Patterns
### Pattern A: Flat Structure (Simple Components)
For components with 2-3 sub-components:
```
component-name/
├── index.tsx # Main component
├── sub-component-a.tsx
├── sub-component-b.tsx
└── types.ts # Shared types
```
### Pattern B: Nested Structure (Complex Components)
For components with many sub-components:
```
component-name/
├── index.tsx # Main orchestration
├── types.ts # Shared types
├── hooks/
│ ├── use-feature-a.ts
│ └── use-feature-b.ts
├── components/
│ ├── header/
│ │ └── index.tsx
│ ├── content/
│ │ └── index.tsx
│ └── modals/
│ └── index.tsx
└── utils/
└── helpers.ts
```
### Pattern C: Feature-Based Structure (Dify Standard)
Following Dify's existing patterns:
```
configuration/
├── index.tsx # Main page component
├── base/ # Base/shared components
│ ├── feature-panel/
│ ├── group-name/
│ └── operation-btn/
├── config/ # Config section
│ ├── index.tsx
│ ├── agent/
│ └── automatic/
├── dataset-config/ # Dataset section
│ ├── index.tsx
│ ├── card-item/
│ └── params-config/
├── debug/ # Debug section
│ ├── index.tsx
│ └── hooks.tsx
└── hooks/ # Shared hooks
└── use-advanced-prompt-config.ts
```
## Props Design
### Minimal Props Principle
Pass only what's needed:
```typescript
// ❌ Bad: Passing entire objects when only some fields needed
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
// ✅ Good: Destructure to minimum required
<ConfigHeader
appName={appDetail.name}
isAdvancedMode={modelConfig.isAdvanced}
onPublish={handlePublish}
/>
```
### Callback Props Pattern
Use callbacks for child-to-parent communication:
```typescript
// Parent
const Parent = () => {
const [value, setValue] = useState('')
return (
<Child
value={value}
onChange={setValue}
onSubmit={handleSubmit}
/>
)
}
// Child
interface ChildProps {
value: string
onChange: (value: string) => void
onSubmit: () => void
}
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
return (
<div>
<input value={value} onChange={e => onChange(e.target.value)} />
<button onClick={onSubmit}>Submit</button>
</div>
)
}
```
### Render Props for Flexibility
When sub-components need parent context:
```typescript
interface ListProps<T> {
items: T[]
renderItem: (item: T, index: number) => React.ReactNode
renderEmpty?: () => React.ReactNode
}
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
if (items.length === 0 && renderEmpty) {
return <>{renderEmpty()}</>
}
return (
<div>
{items.map((item, index) => renderItem(item, index))}
</div>
)
}
// Usage
<List
items={operations}
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
renderEmpty={() => <EmptyState message="No operations" />}
/>
```

View File

@ -1,317 +0,0 @@
# Hook Extraction Patterns
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
## When to Extract Hooks
Extract a custom hook when you identify:
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
1. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
1. **Business logic** - Data transformations, validations, or calculations
1. **Reusable patterns** - Logic that appears in multiple components
## Extraction Process
### Step 1: Identify State Groups
Look for state variables that are logically related:
```typescript
// ❌ These belong together - extract to hook
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
// These are model-related state that should be in useModelConfig()
```
### Step 2: Identify Related Effects
Find effects that modify the grouped state:
```typescript
// ❌ These effects belong with the state above
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties.mode
if (mode) {
const newModelConfig = produce(modelConfig, (draft) => {
draft.mode = mode
})
setModelConfig(newModelConfig)
}
}
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
```
### Step 3: Create the Hook
```typescript
// hooks/use-model-config.ts
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ModelConfig } from '@/models/debug'
import { produce } from 'immer'
import { useEffect, useState } from 'react'
import { ModelModeType } from '@/types/app'
interface UseModelConfigParams {
initialConfig?: Partial<ModelConfig>
currModel?: { model_properties?: { mode?: ModelModeType } }
hasFetchedDetail: boolean
}
interface UseModelConfigReturn {
modelConfig: ModelConfig
setModelConfig: (config: ModelConfig) => void
completionParams: FormValue
setCompletionParams: (params: FormValue) => void
modelModeType: ModelModeType
}
export const useModelConfig = ({
initialConfig,
currModel,
hasFetchedDetail,
}: UseModelConfigParams): UseModelConfigReturn => {
const [modelConfig, setModelConfig] = useState<ModelConfig>({
provider: 'langgenius/openai/openai',
model_id: 'gpt-3.5-turbo',
mode: ModelModeType.unset,
// ... default values
...initialConfig,
})
const [completionParams, setCompletionParams] = useState<FormValue>({})
const modelModeType = modelConfig.mode
// Fill old app data missing model mode
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties?.mode
if (mode) {
setModelConfig(produce(modelConfig, (draft) => {
draft.mode = mode
}))
}
}
}, [hasFetchedDetail, modelModeType, currModel])
return {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
}
}
```
### Step 4: Update Component
```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
// ... lots of related state and effects
}
// After: Clean component
const Configuration: FC = () => {
const {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
} = useModelConfig({
currModel,
hasFetchedDetail,
})
// Component now focuses on UI
}
```
## Naming Conventions
### Hook Names
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
- Include domain: `useWorkflowVariables`, `useMCPServer`
### File Names
- Kebab-case: `use-model-config.ts`
- Place in `hooks/` subdirectory when multiple hooks exist
- Place alongside component for single-use hooks
### Return Type Names
- Suffix with `Return`: `UseModelConfigReturn`
- Suffix params with `Params`: `UseModelConfigParams`
## Common Hook Patterns in Dify
### 1. Data Fetching Hook (React Query)
```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'
const NAME_SPACE = 'appConfig'
// Query keys for cache management
export const appConfigQueryKeys = {
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}
// Main data hook
export const useAppConfig = (appId: string) => {
return useQuery({
enabled: !!appId,
queryKey: appConfigQueryKeys.detail(appId),
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || null,
})
}
// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
return useInvalid([NAME_SPACE])
}
// Usage in component
const Component = () => {
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
const invalidAppConfig = useInvalidAppConfig()
const handleRefresh = () => {
invalidAppConfig() // Invalidates cache and triggers refetch
}
return <div>...</div>
}
```
### 2. Form State Hook
```typescript
// Pattern: Form state + validation + submission
export const useConfigForm = (initialValues: ConfigFormValues) => {
const [values, setValues] = useState(initialValues)
const [errors, setErrors] = useState<Record<string, string>>({})
const [isSubmitting, setIsSubmitting] = useState(false)
const validate = useCallback(() => {
const newErrors: Record<string, string> = {}
if (!values.name) newErrors.name = 'Name is required'
setErrors(newErrors)
return Object.keys(newErrors).length === 0
}, [values])
const handleChange = useCallback((field: string, value: any) => {
setValues(prev => ({ ...prev, [field]: value }))
}, [])
const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
if (!validate()) return
setIsSubmitting(true)
try {
await onSubmit(values)
} finally {
setIsSubmitting(false)
}
}, [values, validate])
return { values, errors, isSubmitting, handleChange, handleSubmit }
}
```
### 3. Modal State Hook
```typescript
// Pattern: Multiple modal management
type ModalType = 'edit' | 'delete' | 'duplicate' | null
export const useModalState = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const [modalData, setModalData] = useState<any>(null)
const openModal = useCallback((type: ModalType, data?: any) => {
setActiveModal(type)
setModalData(data)
}, [])
const closeModal = useCallback(() => {
setActiveModal(null)
setModalData(null)
}, [])
return {
activeModal,
modalData,
openModal,
closeModal,
isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
}
}
```
### 4. Toggle/Boolean Hook
```typescript
// Pattern: Boolean state with convenience methods
export const useToggle = (initialValue = false) => {
const [value, setValue] = useState(initialValue)
const toggle = useCallback(() => setValue(v => !v), [])
const setTrue = useCallback(() => setValue(true), [])
const setFalse = useCallback(() => setValue(false), [])
return [value, { toggle, setTrue, setFalse, set: setValue }] as const
}
// Usage
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
```
## Testing Extracted Hooks
After extraction, test hooks in isolation:
```typescript
// use-model-config.spec.ts
import { renderHook, act } from '@testing-library/react'
import { useModelConfig } from './use-model-config'
describe('useModelConfig', () => {
it('should initialize with default values', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: false,
}))
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
expect(result.current.modelModeType).toBe(ModelModeType.unset)
})
it('should update model config', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: true,
}))
act(() => {
result.current.setModelConfig({
...result.current.modelConfig,
model_id: 'gpt-4',
})
})
expect(result.current.modelConfig.model_id).toBe('gpt-4')
})
})
```

View File

@ -1,322 +0,0 @@
---
name: frontend-testing
description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
---
# Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
## When to Apply This Skill
Apply this skill when the user:
- Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness
- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
- Wants to understand **testing patterns** in the Dify codebase
**Do NOT apply** when:
- User is asking about backend/API tests (Python/pytest)
- User is asking about E2E tests (Playwright/Cypress)
- User is only asking conceptual questions without code context
## Quick Reference
### Tech Stack
| Tool | Version | Purpose |
|------|---------|---------|
| Vitest | 4.0.16 | Test runner |
| React Testing Library | 16.0 | Component testing |
| jsdom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |
### Key Commands
```bash
# Run all tests
pnpm test
# Watch mode
pnpm test:watch
# Run specific file
pnpm test path/to/file.spec.tsx
# Generate coverage report
pnpm test:coverage
# Analyze component complexity
pnpm analyze-component <path>
# Review existing test
pnpm analyze-component <path> --review
```
### File Naming
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Integration tests: `web/__tests__/` directory
## Test Structure Template
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import Component from './index'
// ✅ Import real project components (DO NOT mock these)
// import Loading from '@/app/components/base/loading'
// import { ChildComponent } from './child-component'
// ✅ Mock external dependencies only
vi.mock('@/service/api')
vi.mock('next/navigation', () => ({
useRouter: () => ({ push: vi.fn() }),
usePathname: () => '/test',
}))
// Shared state for mocks (if needed)
let mockSharedState = false
describe('ComponentName', () => {
beforeEach(() => {
vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
mockSharedState = false // ✅ Reset shared state
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange
const props = { title: 'Test' }
// Act
render(<Component {...props} />)
// Assert
expect(screen.getByText('Test')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should apply custom className', () => {
render(<Component className="custom" />)
expect(screen.getByRole('button')).toHaveClass('custom')
})
})
// User Interactions
describe('User Interactions', () => {
it('should handle click events', () => {
const handleClick = vi.fn()
render(<Component onClick={handleClick} />)
fireEvent.click(screen.getByRole('button'))
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle null data', () => {
render(<Component data={null} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
render(<Component items={[]} />)
expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
})
})
```
## Testing Workflow (CRITICAL)
### ⚠️ Incremental Approach Required
**NEVER generate all test files at once.** For complex components or multi-file directories:
1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
1. **Verify before proceeding**: Do NOT continue to next file until current passes
```
For each file:
┌────────────────────────────────────────┐
│ 1. Write test │
│ 2. Run: pnpm test <file>.spec.tsx │
│ 3. PASS? → Mark complete, next file │
│ FAIL? → Fix first, then continue │
└────────────────────────────────────────┘
```
### Complexity-Based Order
Process in this order for multi-file testing:
1. 🟢 Utility functions (simplest)
1. 🟢 Custom hooks
1. 🟡 Simple components (presentational)
1. 🟡 Medium components (state, effects)
1. 🔴 Complex components (API, routing)
1. 🔴 Integration tests (index files - last)
### When to Refactor First
- **Complexity > 50**: Break into smaller pieces before testing
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
## Testing Strategy
### Path-Level Testing (Directory Testing)
When assigned to test a directory/path, test **ALL content** within that path:
- Test all components, hooks, utilities in the directory (not just `index` file)
- Use incremental approach: one file at a time, verify each before proceeding
- Goal: 100% coverage of ALL files in the directory
### Integration Testing First
**Prefer integration testing** when writing tests for a directory:
- ✅ **Import real project components** directly (including base components and siblings)
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
- ❌ **DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
## Core Principles
### 1. AAA Pattern (Arrange-Act-Assert)
Every test should clearly separate:
- **Arrange**: Setup test data and render component
- **Act**: Perform user actions
- **Assert**: Verify expected outcomes
### 2. Black-Box Testing
- Test observable behavior, not implementation details
- Use semantic queries (getByRole, getByLabelText)
- Avoid testing internal state directly
- **Prefer pattern matching over hardcoded strings** in assertions:
```typescript
// ❌ Avoid: hardcoded text assertions
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Better: role-based queries
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Better: pattern matching
expect(screen.getByText(/loading/i)).toBeInTheDocument()
```
### 3. Single Behavior Per Test
Each test verifies ONE user-observable behavior:
```typescript
// ✅ Good: One behavior
it('should disable button when loading', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
})
// ❌ Bad: Multiple behaviors
it('should handle loading state', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByText('Loading...')).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveClass('loading')
})
```
### 4. Semantic Naming
Use `should <behavior> when <condition>`:
```typescript
it('should show error message when validation fails')
it('should call onSubmit when form is valid')
it('should disable input when isReadOnly is true')
```
## Required Test Scenarios
### Always Required (All Components)
1. **Rendering**: Component renders without crashing
1. **Props**: Required props, optional props, default values
1. **Edge Cases**: null, undefined, empty values, boundary conditions
### Conditional (When Present)
| Feature | Test Focus |
|---------|-----------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | All onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
## Coverage Goals (Per File)
For each test file generated, aim for:
- ✅ **100%** function coverage
- ✅ **100%** statement coverage
- ✅ **>95%** branch coverage
- ✅ **>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
## Detailed Guides
For more detailed information, refer to:
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `references/mocking.md` - Mock patterns and best practices
- `references/async-testing.md` - Async operations and API calls
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `references/common-patterns.md` - Frequently used testing patterns
- `references/checklist.md` - Test generation checklist and validation steps
## Authoritative References
### Primary Specification (MUST follow)
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
- `web/app/components/base/button/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
- `web/vitest.config.ts` - Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/scripts/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.

View File

@ -1,296 +0,0 @@
/**
* Test Template for React Components
*
* WHY THIS STRUCTURE?
* - Organized sections make tests easy to navigate and maintain
* - Mocks at top ensure consistent test isolation
* - Factory functions reduce duplication and improve readability
* - describe blocks group related scenarios for better debugging
*
* INSTRUCTIONS:
* 1. Replace `ComponentName` with your component name
* 2. Update import path
* 3. Add/remove test sections based on component features (use analyze-component)
* 4. Follow AAA pattern: Arrange Act Assert
*
* RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
*/
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
// import ComponentName from './index'
// ============================================================================
// Mocks
// ============================================================================
// WHY: Mocks must be hoisted to top of file (Vitest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (automatically mocked)
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// vi.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = vi.fn()
// vi.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))
// API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
// vi.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = vi.mocked(api)
// Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between
// parent and child mocks to correctly simulate open/close behavior
// let mockOpenState = false
// ============================================================================
// Test Data Factories
// ============================================================================
// WHY FACTORIES?
// - Avoid hard-coded test data scattered across tests
// - Easy to create variations with overrides
// - Type-safe when using actual types from source
// - Single source of truth for default test values
// const createMockProps = (overrides = {}) => ({
// // Default props that make component render successfully
// ...overrides,
// })
// const createMockItem = (overrides = {}) => ({
// id: 'item-1',
// name: 'Test Item',
// ...overrides,
// })
// ============================================================================
// Test Helpers
// ============================================================================
// const renderComponent = (props = {}) => {
// return render(<ComponentName {...createMockProps(props)} />)
// }
// ============================================================================
// Tests
// ============================================================================
describe('ComponentName', () => {
// WHY beforeEach with clearAllMocks?
// - Ensures each test starts with clean slate
// - Prevents mock call history from leaking between tests
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
beforeEach(() => {
vi.clearAllMocks()
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
// mockOpenState = false
})
// --------------------------------------------------------------------------
// Rendering Tests (REQUIRED - Every component MUST have these)
// --------------------------------------------------------------------------
// WHY: Catches import errors, missing providers, and basic render issues
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange - Setup data and mocks
// const props = createMockProps()
// Act - Render the component
// render(<ComponentName {...props} />)
// Assert - Verify expected output
// Prefer getByRole for accessibility; it's what users "see"
// expect(screen.getByRole('...')).toBeInTheDocument()
})
it('should render with default props', () => {
// WHY: Verifies component works without optional props
// render(<ComponentName />)
// expect(screen.getByText('...')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Props Tests (REQUIRED - Every component MUST test prop behavior)
// --------------------------------------------------------------------------
// WHY: Props are the component's API contract. Test them thoroughly.
describe('Props', () => {
it('should apply custom className', () => {
// WHY: Common pattern in Dify - components should merge custom classes
// render(<ComponentName className="custom-class" />)
// expect(screen.getByTestId('component')).toHaveClass('custom-class')
})
it('should use default values for optional props', () => {
// WHY: Verifies TypeScript defaults work at runtime
// render(<ComponentName />)
// expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
})
})
// --------------------------------------------------------------------------
// User Interactions (if component has event handlers - on*, handle*)
// --------------------------------------------------------------------------
// WHY: Event handlers are core functionality. Test from user's perspective.
describe('User Interactions', () => {
it('should call onClick when clicked', async () => {
// WHY userEvent over fireEvent?
// - userEvent simulates real user behavior (focus, hover, then click)
// - fireEvent is lower-level, doesn't trigger all browser events
// const user = userEvent.setup()
// const handleClick = vi.fn()
// render(<ComponentName onClick={handleClick} />)
//
// await user.click(screen.getByRole('button'))
//
// expect(handleClick).toHaveBeenCalledTimes(1)
})
it('should call onChange when value changes', async () => {
// const user = userEvent.setup()
// const handleChange = vi.fn()
// render(<ComponentName onChange={handleChange} />)
//
// await user.type(screen.getByRole('textbox'), 'new value')
//
// expect(handleChange).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// State Management (if component uses useState/useReducer)
// --------------------------------------------------------------------------
// WHY: Test state through observable UI changes, not internal state values
describe('State Management', () => {
it('should update state on interaction', async () => {
// WHY test via UI, not state?
// - State is implementation detail; UI is what users see
// - If UI works correctly, state must be correct
// const user = userEvent.setup()
// render(<ComponentName />)
//
// // Initial state - verify what user sees
// expect(screen.getByText('Initial')).toBeInTheDocument()
//
// // Trigger state change via user action
// await user.click(screen.getByRole('button'))
//
// // New state - verify UI updated
// expect(screen.getByText('Updated')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Async Operations (if component fetches data - useQuery, fetch)
// --------------------------------------------------------------------------
// WHY: Async operations have 3 states users experience: loading, success, error
describe('Async Operations', () => {
it('should show loading state', () => {
// WHY never-resolving promise?
// - Keeps component in loading state for assertion
// - Alternative: use fake timers
// mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
// render(<ComponentName />)
//
// expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
it('should show data on success', async () => {
// WHY waitFor?
// - Component updates asynchronously after fetch resolves
// - waitFor retries assertion until it passes or times out
// mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText('Item 1')).toBeInTheDocument()
// })
})
it('should show error on failure', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText(/error/i)).toBeInTheDocument()
// })
})
})
// --------------------------------------------------------------------------
// Edge Cases (REQUIRED - Every component MUST handle edge cases)
// --------------------------------------------------------------------------
// WHY: Real-world data is messy. Components must handle:
// - Null/undefined from API failures or optional fields
// - Empty arrays/strings from user clearing data
// - Boundary values (0, MAX_INT, special characters)
describe('Edge Cases', () => {
it('should handle null value', () => {
// WHY test null specifically?
// - API might return null for missing data
// - Prevents "Cannot read property of null" in production
// render(<ComponentName value={null} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle undefined value', () => {
// WHY test undefined separately from null?
// - TypeScript treats them differently
// - Optional props are undefined, not null
// render(<ComponentName value={undefined} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
// WHY: Empty state often needs special UI (e.g., "No items yet")
// render(<ComponentName items={[]} />)
// expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
it('should handle empty string', () => {
// WHY: Empty strings are truthy in JS but visually empty
// render(<ComponentName text="" />)
// expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Accessibility (optional but recommended for Dify's enterprise users)
// --------------------------------------------------------------------------
// WHY: Dify has enterprise customers who may require accessibility compliance
describe('Accessibility', () => {
it('should have accessible name', () => {
// WHY getByRole with name?
// - Tests that screen readers can identify the element
// - Enforces proper labeling practices
// render(<ComponentName label="Test Label" />)
// expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
})
it('should support keyboard navigation', async () => {
// WHY: Some users can't use a mouse
// const user = userEvent.setup()
// render(<ComponentName />)
//
// await user.tab()
// expect(screen.getByRole('button')).toHaveFocus()
})
})
})

View File

@ -1,207 +0,0 @@
/**
* Test Template for Custom Hooks
*
* Instructions:
* 1. Replace `useHookName` with your hook name
* 2. Update import path
* 3. Add/remove test sections based on hook features
*/
import { renderHook, act, waitFor } from '@testing-library/react'
// import { useHookName } from './use-hook-name'
// ============================================================================
// Mocks
// ============================================================================
// API services (if hook fetches data)
// vi.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = vi.mocked(api)
// ============================================================================
// Test Helpers
// ============================================================================
// Wrapper for hooks that need context
// const createWrapper = (contextValue = {}) => {
// return ({ children }: { children: React.ReactNode }) => (
// <SomeContext.Provider value={contextValue}>
// {children}
// </SomeContext.Provider>
// )
// }
// ============================================================================
// Tests
// ============================================================================
describe('useHookName', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// --------------------------------------------------------------------------
// Initial State
// --------------------------------------------------------------------------
describe('Initial State', () => {
it('should return initial state', () => {
// const { result } = renderHook(() => useHookName())
//
// expect(result.current.value).toBe(initialValue)
// expect(result.current.isLoading).toBe(false)
})
it('should accept initial value from props', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
//
// expect(result.current.value).toBe('custom')
})
})
// --------------------------------------------------------------------------
// State Updates
// --------------------------------------------------------------------------
describe('State Updates', () => {
it('should update value when setValue is called', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(result.current.value).toBe('new value')
})
it('should reset to initial value', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
//
// act(() => {
// result.current.setValue('changed')
// })
// expect(result.current.value).toBe('changed')
//
// act(() => {
// result.current.reset()
// })
// expect(result.current.value).toBe('initial')
})
})
// --------------------------------------------------------------------------
// Async Operations
// --------------------------------------------------------------------------
describe('Async Operations', () => {
it('should fetch data on mount', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result } = renderHook(() => useHookName())
//
// // Initially loading
// expect(result.current.isLoading).toBe(true)
//
// // Wait for data
// await waitFor(() => {
// expect(result.current.isLoading).toBe(false)
// })
//
// expect(result.current.data).toEqual({ data: 'test' })
})
it('should handle fetch error', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
//
// const { result } = renderHook(() => useHookName())
//
// await waitFor(() => {
// expect(result.current.error).toBeTruthy()
// })
//
// expect(result.current.error?.message).toBe('Network error')
})
it('should refetch when dependency changes', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result, rerender } = renderHook(
// ({ id }) => useHookName(id),
// { initialProps: { id: '1' } }
// )
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
// })
//
// rerender({ id: '2' })
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
// })
})
})
// --------------------------------------------------------------------------
// Side Effects
// --------------------------------------------------------------------------
describe('Side Effects', () => {
it('should call callback when value changes', () => {
// const callback = vi.fn()
// const { result } = renderHook(() => useHookName({ onChange: callback }))
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(callback).toHaveBeenCalledWith('new value')
})
it('should cleanup on unmount', () => {
// const cleanup = vi.fn()
// vi.spyOn(window, 'addEventListener')
// vi.spyOn(window, 'removeEventListener')
//
// const { unmount } = renderHook(() => useHookName())
//
// expect(window.addEventListener).toHaveBeenCalled()
//
// unmount()
//
// expect(window.removeEventListener).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle null input', () => {
// const { result } = renderHook(() => useHookName(null))
//
// expect(result.current.value).toBeNull()
})
it('should handle rapid updates', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('1')
// result.current.setValue('2')
// result.current.setValue('3')
// })
//
// expect(result.current.value).toBe('3')
})
})
// --------------------------------------------------------------------------
// With Context (if hook uses context)
// --------------------------------------------------------------------------
describe('With Context', () => {
it('should use context value', () => {
// const wrapper = createWrapper({ someValue: 'context-value' })
// const { result } = renderHook(() => useHookName(), { wrapper })
//
// expect(result.current.contextValue).toBe('context-value')
})
})
})

View File

@ -1,154 +0,0 @@
/**
* Test Template for Utility Functions
*
* Instructions:
* 1. Replace `utilityFunction` with your function name
* 2. Update import path
* 3. Use test.each for data-driven tests
*/
// import { utilityFunction } from './utility'
// ============================================================================
// Tests
// ============================================================================
describe('utilityFunction', () => {
// --------------------------------------------------------------------------
// Basic Functionality
// --------------------------------------------------------------------------
describe('Basic Functionality', () => {
it('should return expected result for valid input', () => {
// expect(utilityFunction('input')).toBe('expected-output')
})
it('should handle multiple arguments', () => {
// expect(utilityFunction('a', 'b', 'c')).toBe('abc')
})
})
// --------------------------------------------------------------------------
// Data-Driven Tests
// --------------------------------------------------------------------------
describe('Input/Output Mapping', () => {
test.each([
// [input, expected]
['input1', 'output1'],
['input2', 'output2'],
['input3', 'output3'],
])('should return %s for input %s', (input, expected) => {
// expect(utilityFunction(input)).toBe(expected)
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle empty string', () => {
// expect(utilityFunction('')).toBe('')
})
it('should handle null', () => {
// expect(utilityFunction(null)).toBe(null)
// or
// expect(() => utilityFunction(null)).toThrow()
})
it('should handle undefined', () => {
// expect(utilityFunction(undefined)).toBe(undefined)
// or
// expect(() => utilityFunction(undefined)).toThrow()
})
it('should handle empty array', () => {
// expect(utilityFunction([])).toEqual([])
})
it('should handle empty object', () => {
// expect(utilityFunction({})).toEqual({})
})
})
// --------------------------------------------------------------------------
// Boundary Conditions
// --------------------------------------------------------------------------
describe('Boundary Conditions', () => {
it('should handle minimum value', () => {
// expect(utilityFunction(0)).toBe(0)
})
it('should handle maximum value', () => {
// expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
})
it('should handle negative numbers', () => {
// expect(utilityFunction(-1)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Type Coercion (if applicable)
// --------------------------------------------------------------------------
describe('Type Handling', () => {
it('should handle numeric string', () => {
// expect(utilityFunction('123')).toBe(123)
})
it('should handle boolean', () => {
// expect(utilityFunction(true)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Error Cases
// --------------------------------------------------------------------------
describe('Error Handling', () => {
it('should throw for invalid input', () => {
// expect(() => utilityFunction('invalid')).toThrow('Error message')
})
it('should throw with specific error type', () => {
// expect(() => utilityFunction('invalid')).toThrow(ValidationError)
})
})
// --------------------------------------------------------------------------
// Complex Objects (if applicable)
// --------------------------------------------------------------------------
describe('Object Handling', () => {
it('should preserve object structure', () => {
// const input = { a: 1, b: 2 }
// expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
})
it('should handle nested objects', () => {
// const input = { nested: { deep: 'value' } }
// expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
})
it('should not mutate input', () => {
// const input = { a: 1 }
// const inputCopy = { ...input }
// utilityFunction(input)
// expect(input).toEqual(inputCopy)
})
})
// --------------------------------------------------------------------------
// Array Handling (if applicable)
// --------------------------------------------------------------------------
describe('Array Handling', () => {
it('should process all elements', () => {
// expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
})
it('should handle single element array', () => {
// expect(utilityFunction([1])).toEqual([2])
})
it('should preserve order', () => {
// expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
})
})
})

View File

@ -1,345 +0,0 @@
# Async Testing Guide
## Core Async Patterns
### 1. waitFor - Wait for Condition
```typescript
import { render, screen, waitFor } from '@testing-library/react'
it('should load and display data', async () => {
render(<DataComponent />)
// Wait for element to appear
await waitFor(() => {
expect(screen.getByText('Loaded Data')).toBeInTheDocument()
})
})
it('should hide loading spinner after load', async () => {
render(<DataComponent />)
// Wait for element to disappear
await waitFor(() => {
expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
})
})
```
### 2. findBy\* - Async Queries
```typescript
it('should show user name after fetch', async () => {
render(<UserProfile />)
// findBy returns a promise, auto-waits up to 1000ms
const userName = await screen.findByText('John Doe')
expect(userName).toBeInTheDocument()
// findByRole with options
const button = await screen.findByRole('button', { name: /submit/i })
expect(button).toBeEnabled()
})
```
### 3. userEvent for Async Interactions
```typescript
import userEvent from '@testing-library/user-event'
it('should submit form', async () => {
const user = userEvent.setup()
const onSubmit = vi.fn()
render(<Form onSubmit={onSubmit} />)
// userEvent methods are async
await user.type(screen.getByLabelText('Email'), 'test@example.com')
await user.click(screen.getByRole('button', { name: /submit/i }))
await waitFor(() => {
expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
})
})
```
## Fake Timers
### When to Use Fake Timers
- Testing components with `setTimeout`/`setInterval`
- Testing debounce/throttle behavior
- Testing animations or delayed transitions
- Testing polling or retry logic
### Basic Fake Timer Setup
```typescript
describe('Debounced Search', () => {
beforeEach(() => {
vi.useFakeTimers()
})
afterEach(() => {
vi.useRealTimers()
})
it('should debounce search input', async () => {
const onSearch = vi.fn()
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
// Type in the input
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })
// Search not called immediately
expect(onSearch).not.toHaveBeenCalled()
// Advance timers
vi.advanceTimersByTime(300)
// Now search is called
expect(onSearch).toHaveBeenCalledWith('query')
})
})
```
### Fake Timers with Async Code
```typescript
it('should retry on failure', async () => {
vi.useFakeTimers()
const fetchData = vi.fn()
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ data: 'success' })
render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)
// First call fails
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
// Advance timer for retry
vi.advanceTimersByTime(1000)
// Second call succeeds
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(2)
expect(screen.getByText('success')).toBeInTheDocument()
})
vi.useRealTimers()
})
```
### Common Fake Timer Utilities
```typescript
// Run all pending timers
vi.runAllTimers()
// Run only pending timers (not new ones created during execution)
vi.runOnlyPendingTimers()
// Advance by specific time
vi.advanceTimersByTime(1000)
// Get current fake time
Date.now()
// Clear all timers
vi.clearAllTimers()
```
## API Testing Patterns
### Loading → Success → Error States
```typescript
describe('DataFetcher', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should show loading state', () => {
mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves
render(<DataFetcher />)
expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })
render(<DataFetcher />)
// Use findBy* for multiple async elements (better error messages than waitFor with multiple assertions)
const item1 = await screen.findByText('Item 1')
const item2 = await screen.findByText('Item 2')
expect(item1).toBeInTheDocument()
expect(item2).toBeInTheDocument()
expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
})
})
it('should retry on error', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
fireEvent.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText('Item 1')).toBeInTheDocument()
})
})
})
```
### Testing Mutations
```typescript
it('should submit form and show success', async () => {
const user = userEvent.setup()
mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })
render(<CreateItemForm />)
await user.type(screen.getByLabelText('Name'), 'New Item')
await user.click(screen.getByRole('button', { name: /create/i }))
// Button should be disabled during submission
expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
})
expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
})
```
## useEffect Testing
### Testing Effect Execution
```typescript
it('should fetch data on mount', async () => {
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
render(<ComponentWithEffect fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
})
```
### Testing Effect Dependencies
```typescript
it('should refetch when id changes', async () => {
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('1')
})
rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('2')
expect(fetchData).toHaveBeenCalledTimes(2)
})
})
```
### Testing Effect Cleanup
```typescript
it('should cleanup subscription on unmount', () => {
const subscribe = vi.fn()
const unsubscribe = vi.fn()
subscribe.mockReturnValue(unsubscribe)
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
expect(subscribe).toHaveBeenCalledTimes(1)
unmount()
expect(unsubscribe).toHaveBeenCalledTimes(1)
})
```
## Common Async Pitfalls
### ❌ Don't: Forget to await
```typescript
// Bad - test may pass even if assertion fails
it('should load data', () => {
render(<Component />)
waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
// Good - properly awaited
it('should load data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### ❌ Don't: Use multiple assertions in single waitFor
```typescript
// Bad - if first assertion fails, won't know about second
await waitFor(() => {
expect(screen.getByText('Title')).toBeInTheDocument()
expect(screen.getByText('Description')).toBeInTheDocument()
})
// Good - separate waitFor or use findBy
const title = await screen.findByText('Title')
const description = await screen.findByText('Description')
expect(title).toBeInTheDocument()
expect(description).toBeInTheDocument()
```
### ❌ Don't: Mix fake timers with real async
```typescript
// Bad - fake timers don't work well with real Promises
vi.useFakeTimers()
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout!
// Good - use runAllTimers or advanceTimersByTime
vi.useFakeTimers()
render(<Component />)
vi.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument()
```

View File

@ -1,205 +0,0 @@
# Test Generation Checklist
Use this checklist when generating or reviewing tests for Dify frontend components.
## Pre-Generation
- [ ] Read the component source code completely
- [ ] Identify component type (component, hook, utility, page)
- [ ] Run `pnpm analyze-component <path>` if available
- [ ] Note complexity score and features detected
- [ ] Check for existing tests in the same directory
- [ ] **Identify ALL files in the directory** that need testing (not just index)
## Testing Strategy
### ⚠️ Incremental Workflow (CRITICAL for Multi-File)
- [ ] **NEVER generate all tests at once** - process one file at a time
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
- [ ] Create a todo list to track progress before starting
- [ ] For EACH file: write → run test → verify pass → then next
- [ ] **DO NOT proceed** to next file until current one passes
### Path-Level Coverage
- [ ] **Test ALL files** in the assigned directory/path
- [ ] List all components, hooks, utilities that need coverage
- [ ] Decide: single spec file (integration) or multiple spec files (unit)
### Complexity Assessment
- [ ] Run `pnpm analyze-component <path>` for complexity score
- [ ] **Complexity > 50**: Consider refactoring before testing
- [ ] **500+ lines**: Consider splitting before testing
- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure
### Integration vs Mocking
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
## Required Test Sections
### All Components MUST Have
- [ ] **Rendering tests** - Component renders without crashing
- [ ] **Props tests** - Required props, optional props, default values
- [ ] **Edge cases** - null, undefined, empty values, boundaries
### Conditional Sections (Add When Feature Present)
| Feature | Add Tests For |
|---------|---------------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
## Code Quality Checklist
### Structure
- [ ] Uses `describe` blocks to group related tests
- [ ] Test names follow `should <behavior> when <condition>` pattern
- [ ] AAA pattern (Arrange-Act-Assert) is clear
- [ ] Comments explain complex test scenarios
### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
### Queries
- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
- [ ] Use `queryBy*` for absence assertions
- [ ] Use `findBy*` for async elements
- [ ] `getByTestId` only as last resort
### Async
- [ ] All async tests use `async/await`
- [ ] `waitFor` wraps async assertions
- [ ] Fake timers properly setup/teardown
- [ ] No floating promises
### TypeScript
- [ ] No `any` types without justification
- [ ] Mock data uses actual types from source
- [ ] Factory functions have proper return types
## Coverage Goals (Per File)
For the current file being tested:
- [ ] 100% function coverage
- [ ] 100% statement coverage
- [ ] >95% branch coverage
- [ ] >95% line coverage
## Post-Generation (Per File)
**Run these checks after EACH test file, not just at the end:**
- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file
### After All Files Complete
- [ ] Run full directory test: `pnpm test path/to/directory/`
- [ ] Check coverage report: `pnpm test:coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`
## Common Issues to Watch
### False Positives
```typescript
// ❌ Mock doesn't match actual behavior
vi.mock('./Component', () => () => <div>Mocked</div>)
// ✅ Mock matches actual conditional logic
vi.mock('./Component', () => ({ isOpen }: any) =>
isOpen ? <div>Content</div> : null
)
```
### State Leakage
```typescript
// ❌ Shared state not reset
let mockState = false
vi.mock('./useHook', () => () => mockState)
// ✅ Reset in beforeEach
beforeEach(() => {
mockState = false
})
```
### Async Race Conditions
```typescript
// ❌ Not awaited
it('loads data', () => {
render(<Component />)
expect(screen.getByText('Data')).toBeInTheDocument()
})
// ✅ Properly awaited
it('loads data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### Missing Edge Cases
Always test these scenarios:
- `null` / `undefined` inputs
- Empty strings / arrays / objects
- Boundary values (0, -1, MAX_INT)
- Error states
- Loading states
- Disabled states
## Quick Commands
```bash
# Run specific test
pnpm test path/to/file.spec.tsx
# Run with coverage
pnpm test:coverage path/to/file.spec.tsx
# Watch mode
pnpm test:watch path/to/file.spec.tsx
# Update snapshots (use sparingly)
pnpm test -u path/to/file.spec.tsx
# Analyze component
pnpm analyze-component path/to/component.tsx
# Review existing test
pnpm analyze-component path/to/component.tsx --review
```

View File

@ -1,449 +0,0 @@
# Common Testing Patterns
## Query Priority
Use queries in this order (most to least preferred):
```typescript
// 1. getByRole - Most recommended (accessibility)
screen.getByRole('button', { name: /submit/i })
screen.getByRole('textbox', { name: /email/i })
screen.getByRole('heading', { level: 1 })
// 2. getByLabelText - Form fields
screen.getByLabelText('Email address')
screen.getByLabelText(/password/i)
// 3. getByPlaceholderText - When no label
screen.getByPlaceholderText('Search...')
// 4. getByText - Non-interactive elements
screen.getByText('Welcome to Dify')
screen.getByText(/loading/i)
// 5. getByDisplayValue - Current input value
screen.getByDisplayValue('current value')
// 6. getByAltText - Images
screen.getByAltText('Company logo')
// 7. getByTitle - Tooltip elements
screen.getByTitle('Close')
// 8. getByTestId - Last resort only!
screen.getByTestId('custom-element')
```
## Event Handling Patterns
### Click Events
```typescript
// Basic click
fireEvent.click(screen.getByRole('button'))
// With userEvent (preferred for realistic interaction)
const user = userEvent.setup()
await user.click(screen.getByRole('button'))
// Double click
await user.dblClick(screen.getByRole('button'))
// Right click
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
```
### Form Input
```typescript
const user = userEvent.setup()
// Type in input
await user.type(screen.getByRole('textbox'), 'Hello World')
// Clear and type
await user.clear(screen.getByRole('textbox'))
await user.type(screen.getByRole('textbox'), 'New value')
// Select option
await user.selectOptions(screen.getByRole('combobox'), 'option-value')
// Check checkbox
await user.click(screen.getByRole('checkbox'))
// Upload file
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
```
### Keyboard Events
```typescript
const user = userEvent.setup()
// Press Enter
await user.keyboard('{Enter}')
// Press Escape
await user.keyboard('{Escape}')
// Keyboard shortcut
await user.keyboard('{Control>}a{/Control}') // Ctrl+A
// Tab navigation
await user.tab()
// Arrow keys
await user.keyboard('{ArrowDown}')
await user.keyboard('{ArrowUp}')
```
## Component State Testing
### Testing State Transitions
```typescript
describe('Counter', () => {
it('should increment count', async () => {
const user = userEvent.setup()
render(<Counter initialCount={0} />)
// Initial state
expect(screen.getByText('Count: 0')).toBeInTheDocument()
// Trigger transition
await user.click(screen.getByRole('button', { name: /increment/i }))
// New state
expect(screen.getByText('Count: 1')).toBeInTheDocument()
})
})
```
### Testing Controlled Components
```typescript
describe('ControlledInput', () => {
it('should call onChange with new value', async () => {
const user = userEvent.setup()
const handleChange = vi.fn()
render(<ControlledInput value="" onChange={handleChange} />)
await user.type(screen.getByRole('textbox'), 'a')
expect(handleChange).toHaveBeenCalledWith('a')
})
it('should display controlled value', () => {
render(<ControlledInput value="controlled" onChange={vi.fn()} />)
expect(screen.getByRole('textbox')).toHaveValue('controlled')
})
})
```
## Conditional Rendering Testing
```typescript
describe('ConditionalComponent', () => {
it('should show loading state', () => {
render(<DataDisplay isLoading={true} data={null} />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
})
it('should show error state', () => {
render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)
expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
})
it('should show data when loaded', () => {
render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)
expect(screen.getByText('Test')).toBeInTheDocument()
})
it('should show empty state when no data', () => {
render(<DataDisplay isLoading={false} data={[]} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
})
```
## List Rendering Testing
```typescript
describe('ItemList', () => {
const items = [
{ id: '1', name: 'Item 1' },
{ id: '2', name: 'Item 2' },
{ id: '3', name: 'Item 3' },
]
it('should render all items', () => {
render(<ItemList items={items} />)
expect(screen.getAllByRole('listitem')).toHaveLength(3)
items.forEach(item => {
expect(screen.getByText(item.name)).toBeInTheDocument()
})
})
it('should handle item selection', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
render(<ItemList items={items} onSelect={onSelect} />)
await user.click(screen.getByText('Item 2'))
expect(onSelect).toHaveBeenCalledWith(items[1])
})
it('should handle empty list', () => {
render(<ItemList items={[]} />)
expect(screen.getByText(/no items/i)).toBeInTheDocument()
})
})
```
## Modal/Dialog Testing
```typescript
describe('Modal', () => {
it('should not render when closed', () => {
render(<Modal isOpen={false} onClose={vi.fn()} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render when open', () => {
render(<Modal isOpen={true} onClose={vi.fn()} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should call onClose when clicking overlay', async () => {
const user = userEvent.setup()
const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.click(screen.getByTestId('modal-overlay'))
expect(handleClose).toHaveBeenCalled()
})
it('should call onClose when pressing Escape', async () => {
const user = userEvent.setup()
const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.keyboard('{Escape}')
expect(handleClose).toHaveBeenCalled()
})
it('should trap focus inside modal', async () => {
const user = userEvent.setup()
render(
<Modal isOpen={true} onClose={vi.fn()}>
<button>First</button>
<button>Second</button>
</Modal>
)
// Focus should cycle within modal
await user.tab()
expect(screen.getByText('First')).toHaveFocus()
await user.tab()
expect(screen.getByText('Second')).toHaveFocus()
await user.tab()
expect(screen.getByText('First')).toHaveFocus() // Cycles back
})
})
```
## Form Testing
```typescript
describe('LoginForm', () => {
it('should submit valid form', async () => {
const user = userEvent.setup()
const onSubmit = vi.fn()
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(onSubmit).toHaveBeenCalledWith({
email: 'test@example.com',
password: 'password123',
})
})
it('should show validation errors', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={vi.fn()} />)
// Submit empty form
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/email is required/i)).toBeInTheDocument()
expect(screen.getByText(/password is required/i)).toBeInTheDocument()
})
it('should validate email format', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={vi.fn()} />)
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
})
it('should disable submit button while submitting', async () => {
const user = userEvent.setup()
const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
})
})
})
```
## Data-Driven Tests with test.each
```typescript
describe('StatusBadge', () => {
test.each([
['success', 'bg-green-500'],
['warning', 'bg-yellow-500'],
['error', 'bg-red-500'],
['info', 'bg-blue-500'],
])('should apply correct class for %s status', (status, expectedClass) => {
render(<StatusBadge status={status} />)
expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
})
test.each([
{ input: null, expected: 'Unknown' },
{ input: undefined, expected: 'Unknown' },
{ input: '', expected: 'Unknown' },
{ input: 'invalid', expected: 'Unknown' },
])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
render(<StatusBadge status={input} />)
expect(screen.getByText(expected)).toBeInTheDocument()
})
})
```
## Debugging Tips
```typescript
// Print entire DOM
screen.debug()
// Print specific element
screen.debug(screen.getByRole('button'))
// Log testing playground URL
screen.logTestingPlaygroundURL()
// Pretty print DOM
import { prettyDOM } from '@testing-library/react'
console.log(prettyDOM(screen.getByRole('dialog')))
// Check available roles
import { getRoles } from '@testing-library/react'
console.log(getRoles(container))
```
## Common Mistakes to Avoid
### ❌ Don't Use Implementation Details
```typescript
// Bad - testing implementation
expect(component.state.isOpen).toBe(true)
expect(wrapper.find('.internal-class').length).toBe(1)
// Good - testing behavior
expect(screen.getByRole('dialog')).toBeInTheDocument()
```
### ❌ Don't Forget Cleanup
```typescript
// Bad - may leak state between tests
it('test 1', () => {
render(<Component />)
})
// Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => {
vi.clearAllMocks()
})
```
### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)
```typescript
// ❌ Bad - hardcoded strings are brittle
expect(screen.getByText('Submit Form')).toBeInTheDocument()
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Good - role-based queries (most semantic)
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Good - pattern matching (flexible)
expect(screen.getByText(/submit/i)).toBeInTheDocument()
expect(screen.getByText(/loading/i)).toBeInTheDocument()
// ✅ Good - test behavior, not exact UI text
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByRole('alert')).toBeInTheDocument()
```
**Why prefer black-box assertions?**
- Text content may change (i18n, copy updates)
- Role-based queries test accessibility
- Pattern matching is resilient to minor changes
- Tests focus on behavior, not implementation details
### ❌ Don't Assert on Absence Without Query
```typescript
// Bad - throws if not found
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!
// Good - use queryBy for absence assertions
expect(screen.queryByText('Error')).not.toBeInTheDocument()
```

View File

@ -1,523 +0,0 @@
# Domain-Specific Component Testing
This guide covers testing patterns for Dify's domain-specific components.
## Workflow Components (`workflow/`)
Workflow components handle node configuration, data flow, and graph operations.
### Key Test Areas
1. **Node Configuration**
1. **Data Validation**
1. **Variable Passing**
1. **Edge Connections**
1. **Error Handling**
### Example: Node Configuration Panel
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import NodeConfigPanel from './node-config-panel'
import { createMockNode, createMockWorkflowContext } from '@/__mocks__/workflow'
// Mock workflow context
vi.mock('@/app/components/workflow/hooks', () => ({
useWorkflowStore: () => mockWorkflowStore,
useNodesInteractions: () => mockNodesInteractions,
}))
let mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: vi.fn(),
}
let mockNodesInteractions = {
handleNodeSelect: vi.fn(),
handleNodeDelete: vi.fn(),
}
describe('NodeConfigPanel', () => {
beforeEach(() => {
vi.clearAllMocks()
mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: vi.fn(),
}
})
describe('Node Configuration', () => {
it('should render node type selector', () => {
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
})
it('should update node config on change', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')
expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
node.id,
expect.objectContaining({ model: 'gpt-4' })
)
})
})
describe('Data Validation', () => {
it('should show error for invalid input', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'code' })
render(<NodeConfigPanel node={node} />)
// Enter invalid code
const codeInput = screen.getByLabelText(/code/i)
await user.clear(codeInput)
await user.type(codeInput, 'invalid syntax {{{')
await waitFor(() => {
expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
})
})
it('should validate required fields', async () => {
const node = createMockNode({ type: 'http', data: { url: '' } })
render(<NodeConfigPanel node={node} />)
fireEvent.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/url is required/i)).toBeInTheDocument()
})
})
})
describe('Variable Passing', () => {
it('should display available variables from upstream nodes', () => {
const upstreamNode = createMockNode({
id: 'node-1',
type: 'start',
data: { outputs: [{ name: 'user_input', type: 'string' }] },
})
const currentNode = createMockNode({
id: 'node-2',
type: 'llm',
})
mockWorkflowStore.nodes = [upstreamNode, currentNode]
mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]
render(<NodeConfigPanel node={currentNode} />)
// Variable selector should show upstream variables
fireEvent.click(screen.getByRole('button', { name: /add variable/i }))
expect(screen.getByText('user_input')).toBeInTheDocument()
})
it('should insert variable into prompt template', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
// Click variable button
await user.click(screen.getByRole('button', { name: /insert variable/i }))
await user.click(screen.getByText('user_input'))
const promptInput = screen.getByLabelText(/prompt/i)
expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
})
})
})
```
## Dataset Components (`dataset/`)
Dataset components handle file uploads, data display, and search/filter operations.
### Key Test Areas
1. **File Upload**
1. **File Type Validation**
1. **Pagination**
1. **Search & Filtering**
1. **Data Format Handling**
### Example: Document Uploader
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader'
vi.mock('@/service/datasets', () => ({
uploadDocument: vi.fn(),
parseDocument: vi.fn(),
}))
import * as datasetService from '@/service/datasets'
const mockedService = vi.mocked(datasetService)
describe('DocumentUploader', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('File Upload', () => {
it('should accept valid file types', async () => {
const user = userEvent.setup()
const onUpload = vi.fn()
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
render(<DocumentUploader onUpload={onUpload} />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
await waitFor(() => {
expect(mockedService.uploadDocument).toHaveBeenCalledWith(
expect.any(FormData)
)
})
})
it('should reject invalid file types', async () => {
const user = userEvent.setup()
render(<DocumentUploader />)
const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
expect(mockedService.uploadDocument).not.toHaveBeenCalled()
})
it('should show upload progress', async () => {
const user = userEvent.setup()
// Mock upload with progress
mockedService.uploadDocument.mockImplementation(() => {
return new Promise((resolve) => {
setTimeout(() => resolve({ id: 'doc-1' }), 100)
})
})
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
expect(screen.getByRole('progressbar')).toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
})
})
})
describe('Error Handling', () => {
it('should handle upload failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
})
})
it('should allow retry after failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ id: 'doc-1' })
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
await user.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
})
})
})
})
```
### Example: Document List with Pagination
```typescript
describe('DocumentList', () => {
describe('Pagination', () => {
it('should load first page on mount', async () => {
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
})
it('should navigate to next page', async () => {
const user = userEvent.setup()
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '11', name: 'Doc 11' }],
total: 50,
page: 2,
pageSize: 10,
})
await user.click(screen.getByRole('button', { name: /next/i }))
await waitFor(() => {
expect(screen.getByText('Doc 11')).toBeInTheDocument()
})
})
})
describe('Search & Filtering', () => {
it('should filter by search query', async () => {
const user = userEvent.setup()
vi.useFakeTimers()
render(<DocumentList datasetId="ds-1" />)
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
// Debounce
vi.advanceTimersByTime(300)
await waitFor(() => {
expect(mockedService.getDocuments).toHaveBeenCalledWith(
'ds-1',
expect.objectContaining({ search: 'test query' })
)
})
vi.useRealTimers()
})
})
})
```
## Configuration Components (`app/configuration/`, `config/`)
Configuration components handle forms, validation, and data persistence.
### Key Test Areas
1. **Form Validation**
1. **Save/Reset**
1. **Required vs Optional Fields**
1. **Configuration Persistence**
1. **Error Feedback**
### Example: App Configuration Form
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form'
vi.mock('@/service/apps', () => ({
updateAppConfig: vi.fn(),
getAppConfig: vi.fn(),
}))
import * as appService from '@/service/apps'
const mockedService = vi.mocked(appService)
describe('AppConfigForm', () => {
const defaultConfig = {
name: 'My App',
description: '',
icon: 'default',
openingStatement: '',
}
beforeEach(() => {
vi.clearAllMocks()
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
})
describe('Form Validation', () => {
it('should require app name', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Clear name field
await user.clear(screen.getByLabelText(/name/i))
await user.click(screen.getByRole('button', { name: /save/i }))
expect(screen.getByText(/name is required/i)).toBeInTheDocument()
expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
})
it('should validate name length', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
})
// Enter very long name
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))
expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
})
it('should allow empty optional fields', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Leave description empty (optional)
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalled()
})
})
})
describe('Save/Reset Functionality', () => {
it('should save configuration', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Updated App')
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
'app-1',
expect.objectContaining({ name: 'Updated App' })
)
})
expect(screen.getByText(/saved successfully/i)).toBeInTheDocument()
})
it('should reset to default values', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Changed Name')
// Reset
await user.click(screen.getByRole('button', { name: /reset/i }))
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
it('should show unsaved changes warning', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.type(screen.getByLabelText(/name/i), ' Updated')
expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
})
})
describe('Error Handling', () => {
it('should show error on save failure', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
})
})
})
})
```

View File

@ -1,343 +0,0 @@
# Mocking Guide for Dify Frontend Tests
## ⚠️ Important: What NOT to Mock
### DO NOT Mock Base Components
**Never mock components from `@/app/components/base/`** such as:
- `Loading`, `Spinner`
- `Button`, `Input`, `Select`
- `Tooltip`, `Modal`, `Dropdown`
- `Icon`, `Badge`, `Tag`
**Why?**
- Base components will have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
// ❌ WRONG: Don't mock base components
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
import Button from '@/app/components/base/button'
// They will render normally in tests
```
### What TO Mock
Only mock these categories:
1. **API services** (`@/service/*`) - Network calls
1. **Complex context providers** - When setup is too difficult
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys
## Mock Placement
| Location | Purpose |
|----------|---------|
| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
## Essential Mocks
### 1. i18n (Auto-loaded via Global Mock)
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
```typescript
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
}))
```
### 2. Next.js Router
```typescript
const mockPush = vi.fn()
const mockReplace = vi.fn()
vi.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
replace: mockReplace,
back: vi.fn(),
prefetch: vi.fn(),
}),
usePathname: () => '/current-path',
useSearchParams: () => new URLSearchParams('?key=value'),
}))
describe('Component', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should navigate on click', () => {
render(<Component />)
fireEvent.click(screen.getByRole('button'))
expect(mockPush).toHaveBeenCalledWith('/expected-path')
})
})
```
### 3. Portal Components (with Shared State)
```typescript
// ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, ...props }: any) => {
mockPortalOpenState = open || false // Update shared state
return <div data-testid="portal" data-open={open}>{children}</div>
},
PortalToFollowElemContent: ({ children }: any) => {
// ✅ Matches actual: returns null when portal is closed
if (!mockPortalOpenState) return null
return <div data-testid="portal-content">{children}</div>
},
PortalToFollowElemTrigger: ({ children }: any) => (
<div data-testid="portal-trigger">{children}</div>
),
}))
describe('Component', () => {
beforeEach(() => {
vi.clearAllMocks()
mockPortalOpenState = false // ✅ Reset shared state
})
})
```
### 4. API Service Mocks
```typescript
import * as api from '@/service/api'
vi.mock('@/service/api')
const mockedApi = vi.mocked(api)
describe('Component', () => {
beforeEach(() => {
vi.clearAllMocks()
// Setup default mock implementation
mockedApi.fetchData.mockResolvedValue({ data: [] })
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
render(<Component />)
await waitFor(() => {
expect(screen.getByText('1')).toBeInTheDocument()
})
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<Component />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 5. HTTP Mocking with Nock
```typescript
import nock from 'nock'
const GITHUB_HOST = 'https://api.github.com'
const GITHUB_PATH = '/repos/owner/repo'
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
return nock(GITHUB_HOST)
.get(GITHUB_PATH)
.delay(delayMs)
.reply(status, body)
}
describe('GithubComponent', () => {
afterEach(() => {
nock.cleanAll()
})
it('should display repo info', async () => {
mockGithubApi(200, { name: 'dify', stars: 1000 })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText('dify')).toBeInTheDocument()
})
})
it('should handle API error', async () => {
mockGithubApi(500, { message: 'Server error' })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 6. Context Providers
```typescript
import { ProviderContext } from '@/context/provider-context'
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
describe('Component with Context', () => {
it('should render for free plan', () => {
const mockContext = createMockPlan('sandbox')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.getByText('Upgrade')).toBeInTheDocument()
})
it('should render for pro plan', () => {
const mockContext = createMockPlan('professional')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
})
})
```
### 7. React Query
```typescript
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false },
mutations: { retry: false },
},
})
const renderWithQueryClient = (ui: React.ReactElement) => {
const queryClient = createTestQueryClient()
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>
)
}
```
## Mock Best Practices
### ✅ DO
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
1. **Import actual types** for type safety
1. **Reset shared mock state** in `beforeEach`
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
1. Don't use `any` types in mocks without necessity
### Mock Decision Tree
```
Need to use a component in test?
├─ Is it from @/app/components/base/*?
│ └─ YES → Import real component, DO NOT mock
├─ Is it a project component?
│ └─ YES → Prefer importing real component
│ Only mock if setup is extremely complex
├─ Is it an API service (@/service/*)?
│ └─ YES → Mock it
├─ Is it a third-party lib with side effects?
│ └─ YES → Mock it (next/navigation, external SDKs)
└─ Is it i18n?
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
## Factory Function Pattern
```typescript
// __mocks__/data-factories.ts
import type { User, Project } from '@/types'
export const createMockUser = (overrides: Partial<User> = {}): User => ({
id: 'user-1',
name: 'Test User',
email: 'test@example.com',
role: 'member',
createdAt: new Date().toISOString(),
...overrides,
})
export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
id: 'project-1',
name: 'Test Project',
description: 'A test project',
owner: createMockUser(),
members: [],
createdAt: new Date().toISOString(),
...overrides,
})
// Usage in tests
it('should display project owner', () => {
const project = createMockProject({
owner: createMockUser({ name: 'John Doe' }),
})
render(<ProjectCard project={project} />)
expect(screen.getByText('John Doe')).toBeInTheDocument()
})
```

View File

@ -1,269 +0,0 @@
# Testing Workflow Guide
This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
| Scope | Rule |
|-------|------|
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
| **Multi-file directory** | Process one file at a time, verify each before proceeding |
## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
### Why Incremental?
| Batch Approach (❌) | Incremental Approach (✅) |
|---------------------|---------------------------|
| Generate 5+ tests at once | Generate 1 test at a time |
| Run tests only at the end | Run test immediately after each file |
| Multiple failures compound | Single point of failure, easy to debug |
| Hard to identify root cause | Clear cause-effect relationship |
| Mock issues affect many files | Mock issues caught early |
| Messy git history | Clean, atomic commits possible |
## Single File Workflow
When testing a **single component, hook, or utility**:
```
1. Read source code completely
2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected
4. Write the test file
5. Run test: `pnpm test <file>.spec.tsx`
6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch)
```
## Directory/Multi-File Workflow (MUST FOLLOW)
When testing a **directory or multiple files**, follow this strict workflow:
### Step 1: Analyze and Plan
1. **List all files** that need tests in the directory
1. **Categorize by complexity**:
- 🟢 **Simple**: Utility functions, simple hooks, presentational components
- 🟡 **Medium**: Components with state, effects, or event handlers
- 🔴 **Complex**: Components with API calls, routing, or many dependencies
1. **Order by dependency**: Test dependencies before dependents
1. **Create a todo list** to track progress
### Step 2: Determine Processing Order
Process files in this recommended order:
```
1. Utility functions (simplest, no React)
2. Custom hooks (isolated logic)
3. Simple presentational components (few/no props)
4. Medium complexity components (state, effects)
5. Complex components (API, routing, many deps)
6. Container/index components (integration tests - last)
```
**Rationale**:
- Simpler files help establish mock patterns
- Hooks used by components should be tested first
- Integration tests (index files) depend on child components working
### Step 3: Process Each File Incrementally
**For EACH file in the ordered list:**
```
┌─────────────────────────────────────────────┐
│ 1. Write test file │
│ 2. Run: pnpm test <file>.spec.tsx │
│ 3. If FAIL → Fix immediately, re-run │
│ 4. If PASS → Mark complete in todo list │
│ 5. ONLY THEN proceed to next file │
└─────────────────────────────────────────────┘
```
**DO NOT proceed to the next file until the current one passes.**
### Step 4: Final Verification
After all individual tests pass:
```bash
# Run all tests in the directory together
pnpm test path/to/directory/
# Check coverage
pnpm test:coverage path/to/directory/
```
## Component Complexity Guidelines
Use `pnpm analyze-component <path>` to assess complexity before testing.
### 🔴 Very Complex Components (Complexity > 50)
**Consider refactoring BEFORE testing:**
- Break component into smaller, testable pieces
- Extract complex logic into custom hooks
- Separate container and presentational layers
**If testing as-is:**
- Use integration tests for complex workflows
- Use `test.each()` for data-driven testing
- Multiple `describe` blocks for organization
- Consider testing major sections separately
### 🟡 Medium Complexity (Complexity 30-50)
- Group related tests in `describe` blocks
- Test integration scenarios between internal parts
- Focus on state transitions and side effects
- Use helper functions to reduce test complexity
### 🟢 Simple Components (Complexity < 30)
- Standard test structure
- Focus on props, rendering, and edge cases
- Usually straightforward to test
### 📏 Large Files (500+ lines)
Regardless of complexity score:
- **Strongly consider refactoring** before testing
- If testing as-is, test major sections separately
- Create helper functions for test setup
- May need multiple test files
## Todo List Format
When testing multiple files, use a todo list like this:
```
Testing: path/to/directory/
Ordered by complexity (simple → complex):
☐ utils/helper.ts [utility, simple]
☐ hooks/use-custom-hook.ts [hook, simple]
☐ empty-state.tsx [component, simple]
☐ item-card.tsx [component, medium]
☐ list.tsx [component, complex]
☐ index.tsx [integration]
Progress: 0/6 complete
```
Update status as you complete each:
- ☐ → ⏳ (in progress)
- ⏳ → ✅ (complete and verified)
- ⏳ → ❌ (blocked, needs attention)
## When to Stop and Verify
**Always run tests after:**
- Completing a test file
- Making changes to fix a failure
- Modifying shared mocks
- Updating test utilities or helpers
**Signs you should pause:**
- More than 2 consecutive test failures
- Mock-related errors appearing
- Unclear why a test is failing
- Test passing but coverage unexpectedly low
## Common Pitfalls to Avoid
### ❌ Don't: Generate Everything First
```
# BAD: Writing all files then testing
Write component-a.spec.tsx
Write component-b.spec.tsx
Write component-c.spec.tsx
Write component-d.spec.tsx
Run pnpm test ← Multiple failures, hard to debug
```
### ✅ Do: Verify Each Step
```
# GOOD: Incremental with verification
Write component-a.spec.tsx
Run pnpm test component-a.spec.tsx ✅
Write component-b.spec.tsx
Run pnpm test component-b.spec.tsx ✅
...continue...
```
### ❌ Don't: Skip Verification for "Simple" Components
Even simple components can have:
- Import errors
- Missing mock setup
- Incorrect assumptions about props
**Always verify, regardless of perceived simplicity.**
### ❌ Don't: Continue When Tests Fail
Failing tests compound:
- A mock issue in file A affects files B, C, D
- Fixing A later requires revisiting all dependent tests
- Time wasted on debugging cascading failures
**Fix failures immediately before proceeding.**
## Integration with Claude's Todo Feature
When using Claude for multi-file testing:
1. **Ask Claude to create a todo list** before starting
1. **Request one file at a time** or ensure Claude processes incrementally
1. **Verify each test passes** before asking for the next
1. **Mark todos complete** as you progress
Example prompt:
```
Test all components in `path/to/directory/`.
First, analyze the directory and create a todo list ordered by complexity.
Then, process ONE file at a time, waiting for my confirmation that tests pass
before proceeding to the next.
```
## Summary Checklist
Before starting multi-file testing:
- [ ] Listed all files needing tests
- [ ] Ordered by complexity (simple → complex)
- [ ] Created todo list for tracking
- [ ] Understand dependencies between files
During testing:
- [ ] Processing ONE file at a time
- [ ] Running tests after EACH file
- [ ] Fixing failures BEFORE proceeding
- [ ] Updating todo list progress
After completion:
- [ ] All individual tests pass
- [ ] Full directory test run passes
- [ ] Coverage goals met
- [ ] Todo list shows all complete

View File

@ -1 +0,0 @@
../.claude/skills

View File

@ -1,5 +0,0 @@
[run]
omit =
api/tests/*
api/migrations/*
api/core/rag/datasource/vdb/*

View File

@ -1,5 +1,6 @@
# Cursor Rules for Dify Project
## Automated Test Generation
- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
- When proposing or saving tests, re-read that document and follow every requirement.
- All frontend tests MUST also comply with the `frontend-testing` skill. Treat the skill as a mandatory constraint, not optional guidance.

View File

@ -6,9 +6,6 @@
"context": "..",
"dockerfile": "Dockerfile"
},
"mounts": [
"source=dify-dev-tmp,target=/tmp,type=volume"
],
"features": {
"ghcr.io/devcontainers/features/node:1": {
"nodeGypDependencies": true,
@ -37,13 +34,19 @@
},
"postStartCommand": "./.devcontainer/post_start_command.sh",
"postCreateCommand": "./.devcontainer/post_create_command.sh"
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "python --version",
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
}
// "remoteUser": "root"
}

View File

@ -1,13 +1,12 @@
#!/bin/bash
WORKSPACE_ROOT=$(pwd)
export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
corepack enable
cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

249
.github/CODEOWNERS vendored
View File

@ -1,249 +0,0 @@
# CODEOWNERS
# This file defines code ownership for the Dify project.
# Each line is a file pattern followed by one or more owners.
# Owners can be @username, @org/team-name, or email addresses.
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
* @crazywoola @laipz8200 @Yeuoly
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola
# Docs
/docs/ @crazywoola
# Backend (default owner, more specific rules below will override)
/api/ @QuantumGhost
# Backend - MCP
/api/core/mcp/ @Nov1c444
/api/core/entities/mcp_provider.py @Nov1c444
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
/api/controllers/mcp/ @Nov1c444
/api/controllers/console/app/mcp_server.py @Nov1c444
/api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
/api/core/workflow/graph/ @laipz8200 @QuantumGhost
/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
/api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
/api/core/workflow/nodes/agent/ @Nov1c444
/api/core/workflow/nodes/iteration/ @Nov1c444
/api/core/workflow/nodes/loop/ @Nov1c444
/api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
/api/core/rag/ @JohnJyong
/api/services/rag_pipeline/ @JohnJyong
/api/services/dataset_service.py @JohnJyong
/api/services/knowledge_service.py @JohnJyong
/api/services/external_knowledge_service.py @JohnJyong
/api/services/hit_testing_service.py @JohnJyong
/api/services/metadata_service.py @JohnJyong
/api/services/vector_service.py @JohnJyong
/api/services/entities/knowledge_entities/ @JohnJyong
/api/services/entities/external_knowledge_entities/ @JohnJyong
/api/controllers/console/datasets/ @JohnJyong
/api/controllers/service_api/dataset/ @JohnJyong
/api/models/dataset.py @JohnJyong
/api/tasks/rag_pipeline/ @JohnJyong
/api/tasks/add_document_to_index_task.py @JohnJyong
/api/tasks/batch_clean_document_task.py @JohnJyong
/api/tasks/clean_document_task.py @JohnJyong
/api/tasks/clean_notion_document_task.py @JohnJyong
/api/tasks/document_indexing_task.py @JohnJyong
/api/tasks/document_indexing_sync_task.py @JohnJyong
/api/tasks/document_indexing_update_task.py @JohnJyong
/api/tasks/duplicate_document_indexing_task.py @JohnJyong
/api/tasks/recover_document_indexing_task.py @JohnJyong
/api/tasks/remove_document_from_index_task.py @JohnJyong
/api/tasks/retry_document_indexing_task.py @JohnJyong
/api/tasks/sync_website_document_indexing_task.py @JohnJyong
/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
/api/tasks/create_segment_to_index_task.py @JohnJyong
/api/tasks/delete_segment_from_index_task.py @JohnJyong
/api/tasks/disable_segment_from_index_task.py @JohnJyong
/api/tasks/disable_segments_from_index_task.py @JohnJyong
/api/tasks/enable_segment_to_index_task.py @JohnJyong
/api/tasks/enable_segments_to_index_task.py @JohnJyong
/api/tasks/clean_dataset_task.py @JohnJyong
/api/tasks/deal_dataset_index_update_task.py @JohnJyong
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook
/api/controllers/trigger/ @Mairuis @Yeuoly
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
/api/core/trigger/ @Mairuis @Yeuoly
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
/api/services/trigger/ @Mairuis @Yeuoly
/api/models/trigger.py @Mairuis @Yeuoly
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
/api/libs/schedule_utils.py @Mairuis @Yeuoly
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow
/api/services/async_workflow_service.py @Mairuis @Yeuoly
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing
/api/services/billing_service.py @hj24 @zyssyz123
/api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise
/api/configs/enterprise/ @GarfieldDai @GareArc
/api/services/enterprise/ @GarfieldDai @GareArc
/api/services/feature_service.py @GarfieldDai @GareArc
/api/controllers/console/feature.py @GarfieldDai @GareArc
/api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
/api/migrations/ @snakevash @laipz8200 @MRZHUH
# Backend - Vector DB Middleware
/api/configs/middleware/vdb/* @JohnJyong
# Frontend
/web/ @iamjoel
# Frontend - Web Tests
/.github/workflows/web-tests.yml @iamjoel
# Frontend - App - Orchestration
/web/app/components/workflow/ @iamjoel @zxhlyh
/web/app/components/workflow-app/ @iamjoel @zxhlyh
/web/app/components/app/configuration/ @iamjoel @zxhlyh
/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat
/web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion
/web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation
/web/app/components/apps/ @JzoNgKVO @iamjoel
/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation
/web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
/web/app/components/app/log/ @JzoNgKVO @iamjoel
/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring
/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
/web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings
/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing
/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation
/web/app/components/datasets/list/ @iamjoel @WTW0313
/web/app/components/datasets/create/ @iamjoel @WTW0313
/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
/web/app/components/rag-pipeline/ @iamjoel @WTW0313
/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List
/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List
/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings
/web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins
/web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools
/web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - MarketPlace
/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration
/web/app/signin/ @douxc @iamjoel
/web/app/signup/ @douxc @iamjoel
/web/app/reset-password/ @douxc @iamjoel
/web/app/install/ @douxc @iamjoel
/web/app/init/ @douxc @iamjoel
/web/app/forgot-password/ @douxc @iamjoel
/web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication
/web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control
/web/app/(shareLayout)/components/ @douxc @iamjoel
/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
/web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page
/web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings
/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics
/web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components
/web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks
/web/utils/classnames.ts @iamjoel @zxhlyh
/web/utils/time.ts @iamjoel @zxhlyh
/web/utils/format.ts @iamjoel @zxhlyh
/web/utils/clipboard.ts @iamjoel @zxhlyh
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education
/web/app/components/billing/ @iamjoel @zxhlyh
/web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
# Docker
/docker/* @laipz8200

View File

@ -1,6 +1,8 @@
name: "✨ Refactor or Chore"
description: Refactor existing code or perform maintenance chores to improve readability and reliability.
title: "[Refactor/Chore] "
name: "✨ Refactor"
description: Refactor existing code for improved readability and maintainability.
title: "[Chore/Refactor] "
labels:
- refactor
body:
- type: checkboxes
attributes:
@ -9,7 +11,7 @@ body:
options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
@ -23,14 +25,14 @@ body:
id: description
attributes:
label: Description
placeholder: "Describe the refactor or chore you are proposing."
placeholder: "Describe the refactor you are proposing."
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation
placeholder: "Explain why this refactor or chore is necessary."
placeholder: "Explain why this refactor is necessary."
validations:
required: false
- type: textarea

13
.github/ISSUE_TEMPLATE/tracker.yml vendored Normal file
View File

@ -0,0 +1,13 @@
name: "👾 Tracker"
description: For inner usages, please do not use this template.
title: "[Tracker] "
labels:
- tracker
body:
- type: textarea
id: content
attributes:
label: Blockers
placeholder: "- [ ] ..."
validations:
required: true

12
.github/copilot-instructions.md vendored Normal file
View File

@ -0,0 +1,12 @@
# Copilot Instructions
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
Key reminders:
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
- Target >95% line and branch coverage and 100% function/statement coverage.
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.

View File

@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -57,7 +57,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@v2
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
docker/docker-compose.middleware.yaml
@ -71,18 +71,18 @@ jobs:
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Run API Tests
env:
STORAGE_TYPE: opendal
OPENDAL_SCHEME: fs
OPENDAL_FS_ROOT: /tmp/dify-storage
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh
- name: Run Tool
run: uv run --project api bash dev/pytest/pytest_tools.sh
- name: Run TestContainers
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
- name: Run Unit tests
run: |
uv run --project api pytest \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \
api/tests/test_containers_integration_tests \
api/tests/unit_tests
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Coverage Summary
run: |
@ -93,12 +93,5 @@ jobs:
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
{
echo ""
echo "<details><summary>File-level coverage (click to expand)</summary>"
echo ""
echo '```'
uv run --project api coverage report -m
echo '```'
echo "</details>"
} >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY

View File

@ -12,29 +12,12 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- name: Check Docker Compose inputs
id: docker-compose-changes
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- uses: actions/setup-python@v5
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v7
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- run: |
cd api
uv sync --dev
@ -52,11 +35,10 @@ jobs:
- name: ast-grep
run: |
# ast-grep exits 1 if no matches are found; allow idempotent runs.
uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
# Convert Optional[T] to T | None (ignoring quoted types)
cat > /tmp/optional-rule.yml << 'EOF'
id: convert-optional-to-union
@ -74,14 +56,35 @@ jobs:
pattern: $T
fix: $T | None
EOF
uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
# Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
uvx mdformat .
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v4
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/package.json
- name: Web dependencies
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: oxlint
working-directory: ./web
run: |
pnpx oxlint --fix
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@ -90,7 +90,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*

View File

@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"
@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"

View File

@ -27,7 +27,7 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: changes
with:
@ -38,7 +38,6 @@ jobs:
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
- '.github/workflows/web-tests.yml'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'

View File

@ -1,21 +0,0 @@
name: Semantic Pull Request
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
jobs:
lint:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@v46
with:
files: |
api/**
@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: false
python-version: "3.12"
@ -68,17 +68,15 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@v46
with:
files: |
web/**
.github/workflows/style.yml
files: web/**
- name: Install pnpm
uses: pnpm/action-setup@v4
@ -87,12 +85,12 @@ jobs:
run_install: false
- name: Setup NodeJS
uses: actions/setup-node@v6
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
cache-dependency-path: ./web/package.json
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'
@ -108,7 +106,37 @@ jobs:
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run type-check:tsgo
run: pnpm run type-check
docker-compose-template:
name: Docker Compose Template
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- name: Generate Docker Compose
if: steps.changed-files.outputs.any_changed == 'true'
run: |
cd docker
./generate_docker_compose
- name: Check for changes
if: steps.changed-files.outputs.any_changed == 'true'
run: git diff --exit-code
superlinter:
name: SuperLinter
@ -116,14 +144,14 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@v46
with:
files: |
**.sh

View File

@ -25,12 +25,12 @@ jobs:
working-directory: sdks/nodejs-client
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
with:
persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }}
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version }}
cache: ''

View File

@ -1,10 +1,10 @@
name: Translate i18n Files Based on English
name: Check i18n Files and Create PR
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
- 'web/i18n/en-US/*.ts'
permissions:
contents: write
@ -18,7 +18,7 @@ jobs:
run:
working-directory: web
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@ -28,13 +28,13 @@ jobs:
run: |
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
filename=$(basename "$file" .ts)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
@ -51,11 +51,11 @@ jobs:
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v6
uses: actions/setup-node@v4
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
cache-dependency-path: ./web/package.json
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
@ -65,21 +65,24 @@ jobs:
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
- name: Generate i18n type definitions
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run gen:i18n-types
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files based on en-US changes'
commit-message: Update i18n files and type definitions based on en-US changes
title: 'chore: translate i18n files and update type definitions'
body: |
This PR was automatically created to update i18n translation files based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
**Changes included:**
- Updated translation files for all locales
branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true
- Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates

View File

@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@v3
uses: endersonmenezes/free-disk-space@v2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -13,356 +13,46 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: web/**
- name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
cache-dependency-path: ./web/package.json
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests
run: pnpm test:coverage
- name: Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Vitest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v6
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm test

13
.gitignore vendored
View File

@ -139,6 +139,7 @@ pyrightconfig.json
.idea/'
.DS_Store
web/.vscode/settings.json
# Intellij IDEA Files
.idea/*
@ -188,14 +189,12 @@ docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d
docker/volumes/iris/*
docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
!docker/nginx/ssl/.gitkeep
docker/middleware.env
docker/docker-compose.override.yaml
docker/env-backup/*
sdks/python-client/build
sdks/python-client/dist
@ -205,6 +204,7 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template
!.vscode/README.md
api/.vscode
web/.vscode
# vscode Code History Extension
.history
@ -219,6 +219,15 @@ plugins.jsonl
# mise
mise.toml
# Next.js build output
.next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant
.roo/

34
.mcp.json Normal file
View File

@ -0,0 +1,34 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

1
.nvmrc
View File

@ -1 +0,0 @@
22.11.0

View File

@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
"--loglevel",
"INFO"
],

View File

@ -0,0 +1,5 @@
# Windsurf Testing Rules
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
- Honor every requirement in that document when generating or accepting tests.
- When proposing or saving tests, re-read that document and follow every requirement.

View File

@ -24,8 +24,8 @@ The codebase is split into:
```bash
cd web
pnpm lint
pnpm lint:fix
pnpm type-check:tsgo
pnpm test
```
@ -39,7 +39,7 @@ pnpm test
## Language Style
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
## General Practices

View File

@ -36,12 +36,6 @@
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
</p>
<p align="center">
@ -139,19 +133,6 @@ Star Dify on GitHub and be instantly notified of new releases.
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
#### Customizing Suggested Questions
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
```bash
# In your .env file
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
SUGGESTED_QUESTIONS_MAX_TOKENS=512
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
```
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.

View File

@ -116,7 +116,6 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
@ -128,14 +127,12 @@ TENCENT_COS_SECRET_KEY=your-secret-key
TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
TENCENT_COS_CUSTOM_DOMAIN=your-custom-domain
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
HUAWEI_OBS_PATH_STYLE=false
# Baidu OBS Storage Configuration
BAIDU_OBS_BUCKET_NAME=your-bucket-name
@ -543,28 +540,8 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0
# Aliyun SLS Logstore Configuration
# Aliyun Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365 days 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
@ -655,45 +632,8 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Suggested Questions After Answer Configuration
# These environment variables allow customization of the suggested questions feature
#
# Custom prompt for generating suggested questions (optional)
# If not set, uses the default prompt that generates 3 questions under 20 characters each
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
# SUGGESTED_QUESTIONS_PROMPT=
# Maximum number of tokens for suggested questions generation (default: 256)
# Adjust this value for longer questions or more questions
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
# Temperature for suggested questions generation (default: 0.0)
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
# SUGGESTED_QUESTIONS_TEMPERATURE=0
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0
# Multimodal knowledgebase limit
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
# Maximum allowed CSV file size for annotation import in megabytes
ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
#Maximum number of annotation records allowed in a single import
ANNOTATION_IMPORT_MAX_RECORDS=10000
# Minimum number of annotation records required in a single import
ANNOTATION_IMPORT_MIN_RECORDS=1
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@ -16,7 +16,6 @@ layers =
graph
nodes
node_events
runtime
entities
containers =
core.workflow

View File

@ -36,20 +36,17 @@ select = [
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum,
"S110", # disallow the try-except-pass pattern.
# security related linting rules
# RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec`
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage,
"S311", # suspicious-non-cryptographic-random-usage
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum
]
ignore = [
@ -94,16 +91,18 @@ ignore = [
"configs/*" = [
"N802", # invalid-function-name
]
"core/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"core/model_runtime/callbacks/base_callback.py" = [
"T201",
]
"core/workflow/callbacks/workflow_logging_callback.py" = [
"T201",
]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
"T201", # allow print in tests
]
[lint.pyflakes]

View File

@ -48,12 +48,6 @@ ENV PYTHONIOENCODING=utf-8
WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \
apt-get update \
# Install dependencies
@ -75,7 +69,7 @@ RUN \
# Copy Python environment and packages
ENV VIRTUAL_ENV=/app/api/.venv
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
@ -84,20 +78,24 @@ RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
# Copy source code
COPY --chown=dify:dify . /app/api/
COPY . /app/api/
# Prepare entrypoint script
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
# Copy entrypoint
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Create non-root user and set permissions
RUN groupadd -r -g 1001 dify && \
useradd -r -u 1001 -g 1001 -s /bin/bash dify && \
mkdir -p /home/dify && \
chown -R 1001:1001 /app /home/dify ${TIKTOKEN_CACHE_DIR} /entrypoint.sh
ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
ENV NLTK_DATA=/usr/local/share/nltk_data
USER dify
USER 1001
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

View File

@ -84,7 +84,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -1,8 +1,6 @@
import logging
import time
from opentelemetry.trace import get_current_span
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
@ -28,25 +26,8 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
@dify_app.after_request
def add_trace_id_header(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header
return dify_app
@ -70,12 +51,10 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_forward_refs,
ext_hosting_provider,
ext_import_modules,
ext_logging,
ext_login,
ext_logstore,
ext_mail,
ext_migrate,
ext_orjson,
@ -84,7 +63,6 @@ def initialize_extensions(app: DifyApp):
ext_redis,
ext_request_logging,
ext_sentry,
ext_session_factory,
ext_set_secretkey,
ext_storage,
ext_timezone,
@ -97,7 +75,6 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
ext_import_modules,
ext_orjson,
ext_forward_refs,
ext_set_secretkey,
ext_compress,
ext_code_based_extension,
@ -106,7 +83,6 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_logstore, # Initialize logstore after storage, before celery
ext_celery,
ext_login,
ext_mail,
@ -117,7 +93,6 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_otel,
ext_request_logging,
ext_session_factory,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

View File

@ -1139,7 +1139,6 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
return
all_files_on_storage = []
for storage_path in storage_paths:

View File

@ -73,10 +73,6 @@ class AppExecutionConfig(BaseSettings):
description="Maximum allowed execution time for the application in seconds",
default=1200,
)
APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Default number of concurrent active requests per app (0 for unlimited)",
default=0,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0,
@ -218,7 +214,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=600.0,
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
@ -360,57 +356,6 @@ class FileUploadConfig(BaseSettings):
default=10,
)
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a image batch upload operation",
default=10,
)
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a single chunk attachment",
default=10,
)
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for attachments in megabytes",
default=2,
)
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
description="Timeout for downloading image attachments in seconds",
default=60,
)
# Annotation Import Security Configurations
ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed CSV file size for annotation import in megabytes",
default=2,
)
ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field(
description="Maximum number of annotation records allowed in a single import",
default=10000,
)
ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field(
description="Minimum number of annotation records required in a single import",
default=1,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of annotation import requests per minute per tenant",
default=5,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
description="Maximum number of annotation import requests per hour per tenant",
default=20,
)
ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
description="Maximum number of concurrent annotation import tasks per tenant",
default=2,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=(
"Comma-separated list of file extensions that are blocked from upload. "
@ -604,10 +549,7 @@ class LoggingConfig(BaseSettings):
LOG_FORMAT: str = Field(
description="Format string for log messages",
default=(
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
),
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
)
LOG_DATEFORMAT: str | None = Field(
@ -1270,21 +1212,6 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
)
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1310,7 +1237,6 @@ class FeatureConfig(
PositionConfig,
RagEtlConfig,
RepositoryConfig,
SandboxExpiredRecordsCleanConfig,
SecurityConfig,
TenantIsolatedTaskQueueConfig,
ToolConfig,

View File

@ -26,7 +26,6 @@ from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig
@ -107,7 +106,7 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)
@ -337,7 +336,6 @@ class MiddlewareConfig(
ChromaConfig,
ClickzettaConfig,
HuaweiCloudConfig,
IrisVectorConfig,
MilvusConfig,
AlibabaCloudMySQLConfig,
MyScaleConfig,

View File

@ -41,8 +41,3 @@ class AliyunOSSStorageConfig(BaseSettings):
description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
default=None,
)
ALIYUN_CLOUDBOX_ID: str | None = Field(
description="Cloudbox id for aliyun cloudbox service",
default=None,
)

View File

@ -26,8 +26,3 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
default=None,
)
HUAWEI_OBS_PATH_STYLE: bool = Field(
description="Flag to indicate whether to use path-style URLs for OBS requests",
default=False,
)

View File

@ -31,8 +31,3 @@ class TencentCloudCOSStorageConfig(BaseSettings):
description="Protocol scheme for COS requests: 'https' (recommended) or 'http'",
default=None,
)
TENCENT_COS_CUSTOM_DOMAIN: str | None = Field(
description="Tencent Cloud COS custom domain setting",
default=None,
)

View File

@ -1,91 +0,0 @@
"""Configuration for InterSystems IRIS vector database."""
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
class IrisVectorConfig(BaseSettings):
"""Configuration settings for IRIS vector database connection and pooling."""
IRIS_HOST: str | None = Field(
description="Hostname or IP address of the IRIS server.",
default="localhost",
)
IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
description="Port number for IRIS connection.",
default=1972,
)
IRIS_USER: str | None = Field(
description="Username for IRIS authentication.",
default="_SYSTEM",
)
IRIS_PASSWORD: str | None = Field(
description="Password for IRIS authentication.",
default="Dify@1234",
)
IRIS_SCHEMA: str | None = Field(
description="Schema name for IRIS tables.",
default="dify",
)
IRIS_DATABASE: str | None = Field(
description="Database namespace for IRIS connection.",
default="USER",
)
IRIS_CONNECTION_URL: str | None = Field(
description="Full connection URL for IRIS (overrides individual fields if provided).",
default=None,
)
IRIS_MIN_CONNECTION: PositiveInt = Field(
description="Minimum number of connections in the pool.",
default=1,
)
IRIS_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the pool.",
default=3,
)
IRIS_TEXT_INDEX: bool = Field(
description="Enable full-text search index using %iFind.Index.Basic.",
default=True,
)
IRIS_TEXT_INDEX_LANGUAGE: str = Field(
description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
default="en",
)
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""Validate IRIS configuration values.
Args:
values: Configuration dictionary
Returns:
Validated configuration dictionary
Raises:
ValueError: If required fields are missing or pool settings are invalid
"""
# Only validate required fields if IRIS is being used as the vector store
# This allows the config to be loaded even when IRIS is not in use
# vector_store = os.environ.get("VECTOR_STORE", "")
# We rely on Pydantic defaults for required fields if they are missing from env.
# Strict existence check is removed to allow defaults to work.
min_conn = values.get("IRIS_MIN_CONNECTION", 1)
max_conn = values.get("IRIS_MAX_CONNECTION", 3)
if min_conn > max_conn:
raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
return values

View File

@ -20,7 +20,6 @@ language_timezone_mapping = {
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
"ar-TN": "Africa/Tunis",
}
languages = list(language_timezone_mapping.keys())

View File

@ -1,57 +0,0 @@
import os
from email.message import Message
from urllib.parse import quote
from flask import Response
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
HTML_EXTENSIONS = frozenset({"html", "htm"})
def _normalize_mime_type(mime_type: str | None) -> str:
if not mime_type:
return ""
message = Message()
message["Content-Type"] = mime_type
return message.get_content_type().strip().lower()
def _is_html_extension(extension: str | None) -> bool:
if not extension:
return False
return extension.lstrip(".").lower() in HTML_EXTENSIONS
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
normalized_mime_type = _normalize_mime_type(mime_type)
if normalized_mime_type in HTML_MIME_TYPES:
return True
if _is_html_extension(extension):
return True
if filename:
return _is_html_extension(os.path.splitext(filename)[1])
return False
def enforce_download_for_html(
response: Response,
*,
mime_type: str | None,
filename: str | None,
extension: str | None = None,
) -> bool:
if not is_html_content(mime_type, filename, extension):
return False
if filename:
encoded_filename = quote(filename)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
else:
response.headers["Content-Disposition"] = "attachment"
response.headers["Content-Type"] = "application/octet-stream"
response.headers["X-Content-Type-Options"] = "nosniff"
return True

View File

@ -1,26 +0,0 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
from flask_restx import Namespace
from pydantic import BaseModel
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a single BaseModel with a namespace for Swagger documentation."""
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
"""Register multiple BaseModels with a namespace."""
for model in models:
register_schema_model(namespace, model)
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"register_schema_model",
"register_schema_models",
]

View File

@ -3,47 +3,21 @@ from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InsertExploreAppPayload(BaseModel):
app_id: str = Field(...)
desc: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]):
@wraps(view)
@ -66,34 +40,59 @@ def admin_required(view: Callable[P, R]):
class InsertExploreAppListApi(Resource):
@console_ns.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list")
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
@console_ns.expect(
console_ns.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@console_ns.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully")
@console_ns.response(404, "App not found")
@only_edition_cloud
@admin_required
def post(self):
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
.add_argument("desc", type=str, location="json")
.add_argument("copyright", type=str, location="json")
.add_argument("privacy_policy", type=str, location="json")
.add_argument("custom_disclaimer", type=str, location="json")
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{payload.app_id}' is not found")
raise NotFound(f"App '{args['app_id']}' is not found")
site = app.site
if not site:
desc = payload.desc or ""
copy_right = payload.copyright or ""
privacy_policy = payload.privacy_policy or ""
custom_disclaimer = payload.custom_disclaimer or ""
desc = args["desc"] or ""
copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] or ""
else:
desc = site.description or payload.desc or ""
copy_right = site.copyright or payload.copyright or ""
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
desc = site.description or args["desc"] or ""
copy_right = site.copyright or args["copyright"] or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
with session_factory.create_session() as session:
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app:
@ -103,9 +102,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=payload.language,
category=payload.category,
position=payload.position,
language=args["language"],
category=args["category"],
position=args["position"],
)
db.session.add(recommended_app)
@ -119,9 +118,9 @@ class InsertExploreAppListApi(Resource):
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = payload.language
recommended_app.category = payload.category
recommended_app.position = payload.position
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
app.is_public = True
@ -139,7 +138,7 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud
@admin_required
def delete(self, app_id):
with session_factory.create_session() as session:
with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
@ -147,13 +146,13 @@ class InsertExploreAppApi(Resource):
if not recommended_app:
return {"result": "success"}, 204
with session_factory.create_session() as session:
with Session(db.engine) as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with session_factory.create_session() as session:
with Session(db.engine) as session:
installed_apps = (
session.execute(
select(InstalledApp).where(

View File

@ -1,23 +1,16 @@
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
app_mode: str = Field(..., description="Application mode")
model_mode: str = Field(..., description="Model mode")
has_context: str = Field(default="true", description="Whether has context")
model_name: str = Field(..., description="Model name")
console_ns.schema_model(
AdvancedPromptTemplateQuery.__name__,
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
parser = (
reqparse.RequestParser()
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
)
@ -25,7 +18,7 @@ console_ns.schema_model(
class AdvancedPromptTemplateList(Resource):
@console_ns.doc("get_advanced_prompt_templates")
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
)
@ -34,6 +27,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
return AdvancedPromptTemplateService.get_prompt(args)

View File

@ -1,6 +1,4 @@
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -10,21 +8,10 @@ from libs.login import login_required
from models.model import AppMode
from services.agent_service import AgentService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AgentLogQuery(BaseModel):
message_id: str = Field(..., description="Message UUID")
conversation_id: str = Field(..., description="Conversation UUID")
@field_validator("message_id", "conversation_id")
@classmethod
def validate_uuid(cls, value: str) -> str:
return uuid_value(value)
console_ns.schema_model(
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
)
@ -33,7 +20,7 @@ class AgentLogApi(Resource):
@console_ns.doc("get_agent_logs")
@console_ns.doc(description="Get agent execution logs for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AgentLogQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
)
@ -44,6 +31,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model):
"""Get agent logs"""
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])

View File

@ -1,15 +1,12 @@
from typing import Any, Literal
from typing import Literal
from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
annotation_import_concurrency_limit,
annotation_import_rate_limit,
cloud_edition_billing_resource_check,
edit_permission_required,
setup_required,
@ -24,79 +21,22 @@ from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AnnotationReplyPayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold for annotation matching")
embedding_provider_name: str = Field(..., description="Embedding provider name")
embedding_model_name: str = Field(..., description="Embedding model name")
class AnnotationSettingUpdatePayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Page size")
keyword: str = Field(default="", description="Search keyword")
class CreateAnnotationPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
question: str | None = Field(default=None, description="Question text")
answer: str | None = Field(default=None, description="Answer text")
content: str | None = Field(default=None, description="Content text")
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class UpdateAnnotationPayload(BaseModel):
question: str | None = None
answer: str | None = None
content: str | None = None
annotation_reply: dict[str, Any] | None = None
class AnnotationReplyStatusQuery(BaseModel):
action: Literal["enable", "disable"]
class AnnotationFilePayload(BaseModel):
message_id: str = Field(..., description="Message ID")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
def reg(model: type[BaseModel]) -> None:
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AnnotationReplyPayload)
reg(AnnotationSettingUpdatePayload)
reg(AnnotationListQuery)
reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource):
@console_ns.doc("annotation_reply_action")
@console_ns.doc(description="Enable or disable annotation reply for an app")
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
@console_ns.expect(
console_ns.model(
"AnnotationReplyActionRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
},
)
)
@console_ns.response(200, "Action completed successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -106,9 +46,15 @@ class AnnotationReplyActionApi(Resource):
@edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
args = AnnotationReplyPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("score_threshold", required=True, type=float, location="json")
.add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@ -136,7 +82,16 @@ class AppAnnotationSettingUpdateApi(Resource):
@console_ns.doc("update_annotation_setting")
@console_ns.doc(description="Update annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
@console_ns.expect(
console_ns.model(
"AnnotationSettingUpdateRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
"embedding_model_name": fields.String(required=True, description="Embedding model"),
},
)
)
@console_ns.response(200, "Settings updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -147,9 +102,10 @@ class AppAnnotationSettingUpdateApi(Resource):
app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id)
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
return result, 200
@ -186,7 +142,12 @@ class AnnotationApi(Resource):
@console_ns.doc("list_annotations")
@console_ns.doc(description="Get annotations for an app with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
)
@console_ns.response(200, "Annotations retrieved successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -194,10 +155,9 @@ class AnnotationApi(Resource):
@account_initialization_required
@edit_permission_required
def get(self, app_id):
args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
keyword = args.keyword
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
@ -213,7 +173,18 @@ class AnnotationApi(Resource):
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.expect(
console_ns.model(
"CreateAnnotationRequest",
{
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
)
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -224,9 +195,16 @@ class AnnotationApi(Resource):
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=False, type=uuid_value, location="json")
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation
@setup_required
@ -259,7 +237,7 @@ class AnnotationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource):
@console_ns.doc("export_annotations")
@console_ns.doc(description="Export all annotations for an app with CSV injection protection")
@console_ns.doc(description="Export all annotations for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(
200,
@ -274,14 +252,15 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response_data = {"data": marshal(annotation_list, annotation_fields)}
response = {"data": marshal(annotation_list, annotation_fields)}
return response, 200
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
response.headers["Content-Type"] = "application/json; charset=utf-8"
response.headers["X-Content-Type-Options"] = "nosniff"
return response
parser = (
reqparse.RequestParser()
.add_argument("question", required=True, type=str, location="json")
.add_argument("answer", required=True, type=str, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
@ -292,7 +271,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@console_ns.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -302,10 +281,8 @@ class AnnotationUpdateDeleteApi(Resource):
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation
@setup_required
@ -322,25 +299,18 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource):
@console_ns.doc("batch_import_annotations")
@console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
@console_ns.doc(description="Batch import annotations from CSV file")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Batch import started successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "No file uploaded or too many files")
@console_ns.response(413, "File too large")
@console_ns.response(429, "Too many requests or concurrent imports")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@annotation_import_rate_limit
@annotation_import_concurrency_limit
@edit_permission_required
def post(self, app_id):
from configs import dify_config
app_id = str(app_id)
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -350,27 +320,9 @@ class AnnotationBatchImportApi(Resource):
# get file from request
file = request.files["file"]
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
# Check file size before processing
file.seek(0, 2) # Seek to end of file
file_size = file.tell()
file.seek(0) # Reset to beginning
max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
if file_size > max_size_bytes:
abort(
413,
f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
f"Please reduce the file size and try again.",
)
if file_size == 0:
raise ValueError("The uploaded file is empty")
return AppAnnotationService.batch_import_app_annotations(app_id, file)

View File

@ -1,12 +1,9 @@
import uuid
from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from werkzeug.exceptions import BadRequest, abort
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -31,6 +28,7 @@ from fields.app_fields import (
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import App, Workflow
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
@ -38,116 +36,6 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not items:
return None
try:
return [str(uuid.UUID(item)) for item in items]
except ValueError as exc:
raise ValueError("Invalid UUID format in tag_ids.") from exc
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
class AppNamePayload(BaseModel):
name: str = Field(..., min_length=1, description="Name to check")
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
icon_background: str | None = Field(default=None, description="Icon background color")
class AppSiteStatusPayload(BaseModel):
enable_site: bool = Field(..., description="Enable or disable site")
class AppApiStatusPayload(BaseModel):
enable_api: bool = Field(..., description="Enable or disable API")
class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str | None = Field(default=None, description="Tracing provider")
@field_validator("tracing_provider")
@classmethod
def validate_tracing_provider(cls, value: str | None, info) -> str | None:
if info.data.get("enabled") and not value:
raise ValueError("tracing_provider is required when enabled is True")
return value
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AppListQuery)
reg(CreateAppPayload)
reg(UpdateAppPayload)
reg(CopyAppPayload)
reg(AppExportQuery)
reg(AppNamePayload)
reg(AppIconPayload)
reg(AppSiteStatusPayload)
reg(AppApiStatusPayload)
reg(AppTracePayload)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base models first
@ -259,7 +147,22 @@ app_pagination_model = console_ns.model(
class AppListApi(Resource):
@console_ns.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
.add_argument(
"mode",
type=str,
location="args",
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
default="all",
help="App mode filter",
)
.add_argument("name", type=str, location="args", help="Filter by app name")
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
)
@console_ns.response(200, "Success", app_pagination_model)
@setup_required
@login_required
@ -269,12 +172,42 @@ class AppListApi(Resource):
"""Get app list"""
current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args_dict = args.model_dump()
def uuid_list(value):
try:
return [str(uuid.UUID(v)) for v in value.split(",")]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
)
args = parser.parse_args()
# get app list
app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -324,7 +257,19 @@ class AppListApi(Resource):
@console_ns.doc("create_app")
@console_ns.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
@console_ns.expect(
console_ns.model(
"CreateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(201, "App created successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@ -337,10 +282,22 @@ class AppListApi(Resource):
def post(self):
"""Create app"""
current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
app = app_service.create_app(current_tenant_id, args, current_user)
return app, 201
@ -372,7 +329,20 @@ class AppApi(Resource):
@console_ns.doc("update_app")
@console_ns.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
@console_ns.expect(
console_ns.model(
"UpdateAppRequest",
{
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
"max_active_requests": fields.Integer(description="Maximum active requests"),
},
)
)
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(400, "Invalid request parameters")
@ -384,18 +354,28 @@ class AppApi(Resource):
@marshal_with(app_detail_with_site_model)
def put(self, app_model):
"""Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
.add_argument("max_active_requests", type=int, location="json")
)
args = parser.parse_args()
app_service = AppService()
args_dict: AppService.ArgsDict = {
"name": args.name,
"description": args.description or "",
"icon_type": args.icon_type or "",
"icon": args.icon or "",
"icon_background": args.icon_background or "",
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
"max_active_requests": args.max_active_requests or 0,
"name": args["name"],
"description": args.get("description", ""),
"icon_type": args.get("icon_type", ""),
"icon": args.get("icon", ""),
"icon_background": args.get("icon_background", ""),
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
"max_active_requests": args.get("max_active_requests", 0),
}
app_model = app_service.update_app(app_model, args_dict)
@ -424,7 +404,18 @@ class AppCopyApi(Resource):
@console_ns.doc("copy_app")
@console_ns.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
@console_ns.expect(
console_ns.model(
"CopyAppRequest",
{
"name": fields.String(description="Name for the copied app"),
"description": fields.String(description="Description for the copied app"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -438,7 +429,15 @@ class AppCopyApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
args = CopyAppPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("name", type=str, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
with Session(db.engine) as session:
import_service = AppDslService(session)
@ -447,11 +446,11 @@ class AppCopyApi(Resource):
account=current_user,
import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content,
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
)
session.commit()
@ -466,7 +465,11 @@ class AppExportApi(Resource):
@console_ns.doc("export_app")
@console_ns.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
)
@console_ns.response(
200,
"App exported successfully",
@ -480,23 +483,30 @@ class AppExportApi(Resource):
@edit_permission_required
def get(self, app_model):
"""Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Add include_secret params
parser = (
reqparse.RequestParser()
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args()
return {
"data": AppDslService.export_dsl(
app_model=app_model,
include_secret=args.include_secret,
workflow_id=args.workflow_id,
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")
@console_ns.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
@console_ns.expect(parser)
@console_ns.response(200, "Name availability checked")
@setup_required
@login_required
@ -505,10 +515,10 @@ class AppNameApi(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppNamePayload.model_validate(console_ns.payload)
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_name(app_model, args.name)
app_model = app_service.update_app_name(app_model, args["name"])
return app_model
@ -518,7 +528,16 @@ class AppIconApi(Resource):
@console_ns.doc("update_app_icon")
@console_ns.doc(description="Update application icon")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
@console_ns.expect(
console_ns.model(
"AppIconRequest",
{
"icon": fields.String(required=True, description="Icon data"),
"icon_type": fields.String(description="Icon type"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@console_ns.response(200, "Icon updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -528,10 +547,15 @@ class AppIconApi(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppIconPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
return app_model
@ -541,7 +565,11 @@ class AppSiteStatus(Resource):
@console_ns.doc("update_app_site_status")
@console_ns.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
@console_ns.expect(
console_ns.model(
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
)
)
@console_ns.response(200, "Site status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -551,10 +579,11 @@ class AppSiteStatus(Resource):
@marshal_with(app_detail_model)
@edit_permission_required
def post(self, app_model):
args = AppSiteStatusPayload.model_validate(console_ns.payload)
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.enable_site)
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
return app_model
@ -564,7 +593,11 @@ class AppApiStatus(Resource):
@console_ns.doc("update_app_api_status")
@console_ns.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
@console_ns.expect(
console_ns.model(
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
)
)
@console_ns.response(200, "API status updated successfully", app_detail_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -574,10 +607,11 @@ class AppApiStatus(Resource):
@get_app_model
@marshal_with(app_detail_model)
def post(self, app_model):
args = AppApiStatusPayload.model_validate(console_ns.payload)
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.enable_api)
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
return app_model
@ -600,7 +634,15 @@ class AppTraceApi(Resource):
@console_ns.doc("update_app_trace")
@console_ns.doc(description="Update app tracing configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
@console_ns.expect(
console_ns.model(
"AppTraceRequest",
{
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
"tracing_provider": fields.String(required=True, description="Tracing provider"),
},
)
)
@console_ns.response(200, "Trace configuration updated successfully")
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -609,12 +651,17 @@ class AppTraceApi(Resource):
@edit_permission_required
def post(self, app_id):
# add app trace
args = AppTracePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("enabled", type=bool, required=True, location="json")
.add_argument("tracing_provider", type=str, required=True, location="json")
)
args = parser.parse_args()
OpsTraceManager.update_app_tracing_config(
app_id=app_id,
enabled=args.enabled,
tracing_provider=args.tracing_provider,
enabled=args["enabled"],
tracing_provider=args["tracing_provider"],
)
return {"result": "success"}

View File

@ -1,5 +1,4 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy.orm import Session
from controllers.console.app.wraps import get_app_model
@ -36,29 +35,23 @@ app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
yaml_content: str | None = None
yaml_url: str | None = None
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
app_id: str | None = None
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
parser = (
reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json")
.add_argument("yaml_content", type=str, location="json")
.add_argument("yaml_url", type=str, location="json")
.add_argument("name", type=str, location="json")
.add_argument("description", type=str, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("app_id", type=str, location="json")
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
@console_ns.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -68,7 +61,7 @@ class AppImportApi(Resource):
def post(self):
# Check user role first
current_user, _ = current_account_with_tenant()
args = AppImportPayload.model_validate(console_ns.payload)
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
@ -77,15 +70,15 @@ class AppImportApi(Resource):
account = current_user
result = import_service.import_app(
account=account,
import_mode=args.mode,
yaml_content=args.yaml_content,
yaml_url=args.yaml_url,
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
app_id=args.app_id,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,8 +1,7 @@
import logging
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import InternalServerError
import services
@ -33,27 +32,6 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
text: str = Field(..., description="Text to convert")
voice: str | None = Field(default=None, description="Voice name")
streaming: bool | None = Field(default=None, description="Whether to stream audio")
class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@ -114,7 +92,17 @@ class ChatMessageTextApi(Resource):
@console_ns.doc("chat_message_text_to_speech")
@console_ns.doc(description="Convert text to speech for chat messages")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
@console_ns.expect(
console_ns.model(
"TextToSpeechRequest",
{
"message_id": fields.String(description="Message ID"),
"text": fields.String(required=True, description="Text to convert to speech"),
"voice": fields.String(description="Voice to use for TTS"),
"streaming": fields.Boolean(description="Whether to stream the audio"),
},
)
)
@console_ns.response(200, "Text to speech conversion successful")
@console_ns.response(400, "Bad request - Invalid parameters")
@get_app_model
@ -123,14 +111,21 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
def post(self, app_model: App):
try:
payload = TextToSpeechPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
response = AudioService.transcript_tts(
app_model=app_model,
text=payload.text,
voice=payload.voice,
message_id=payload.message_id,
is_draft=True,
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
@ -164,7 +159,9 @@ class TextModesApi(Resource):
@console_ns.doc("get_text_to_speech_voices")
@console_ns.doc(description="Get available TTS voices for a specific language")
@console_ns.doc(params={"app_id": "App ID"})
@console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
@console_ns.expect(
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
)
@console_ns.response(
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
)
@ -175,11 +172,12 @@ class TextModesApi(Resource):
@account_initialization_required
def get(self, app_model):
try:
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
args = parser.parse_args()
response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id,
language=args.language,
language=args["language"],
)
return response

View File

@ -1,9 +1,7 @@
import logging
from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
@ -37,41 +35,6 @@ from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
model_config_data: dict[str, Any] = Field(..., alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
class CompletionMessagePayload(BaseMessagePayload):
query: str = Field(default="", description="Query text")
class ChatMessagePayload(BaseMessagePayload):
query: str = Field(..., description="User query")
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
console_ns.schema_model(
CompletionMessagePayload.__name__,
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# define completion message api for user
@ -80,7 +43,19 @@ class CompletionMessageApi(Resource):
@console_ns.doc("create_completion_message")
@console_ns.doc(description="Generate completion message for debugging")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
@console_ns.expect(
console_ns.model(
"CompletionMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(description="Query text", default=""),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@console_ns.response(200, "Completion generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App not found")
@ -89,10 +64,18 @@ class CompletionMessageApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model):
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args_model.response_mode != "blocking"
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
try:
@ -154,7 +137,21 @@ class ChatMessageApi(Resource):
@console_ns.doc("create_chat_message")
@console_ns.doc(description="Generate chat message for debugging")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
@console_ns.expect(
console_ns.model(
"ChatMessageRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(required=True, description="User query"),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"conversation_id": fields.String(description="Conversation ID"),
"parent_message_id": fields.String(description="Parent message ID"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@console_ns.response(200, "Chat message generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(404, "App or conversation not found")
@ -164,10 +161,20 @@ class ChatMessageApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required
def post(self, app_model):
args_model = ChatMessagePayload.model_validate(console_ns.payload)
args = args_model.model_dump(exclude_none=True, by_alias=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args_model.response_mode != "blocking"
streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request)

View File

@ -1,9 +1,7 @@
from typing import Literal
import sqlalchemy as sa
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask import abort
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound
@ -16,53 +14,13 @@ from extensions.ext_database import db
from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField
from libs.helper import DatetimeString, TimestampField
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseConversationQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword")
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
default="all", description="Annotation status filter"
)
page: int = Field(default=1, ge=1, le=99999, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
class CompletionConversationQuery(BaseConversationQuery):
pass
class ChatConversationQuery(BaseConversationQuery):
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction"
)
console_ns.schema_model(
CompletionConversationQuery.__name__,
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatConversationQuery.__name__,
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -325,7 +283,22 @@ class CompletionConversationApi(Resource):
@console_ns.doc("list_completion_conversations")
@console_ns.doc(description="Get completion conversations with pagination and filtering")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
)
@console_ns.response(200, "Success", conversation_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -336,17 +309,32 @@ class CompletionConversationApi(Resource):
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)
if args.keyword:
if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
Message.query.ilike(f"%{args['keyword']}%"),
Message.answer.ilike(f"%{args['keyword']}%"),
)
)
@ -354,7 +342,7 @@ class CompletionConversationApi(Resource):
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -366,11 +354,11 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
if args.annotation_status == "annotated":
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
@ -379,7 +367,7 @@ class CompletionConversationApi(Resource):
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations
@ -431,7 +419,31 @@ class ChatConversationApi(Resource):
@console_ns.doc("list_chat_conversations")
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
.add_argument(
"sort_by",
type=str,
location="args",
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
default="-updated_at",
help="Sort field and direction",
)
)
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@ -442,7 +454,31 @@ class ChatConversationApi(Resource):
@edit_permission_required
def get(self, app_model):
current_user, _ = current_account_with_tenant()
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
subquery = (
db.session.query(
@ -454,8 +490,8 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword:
keyword_filter = f"%{args.keyword}%"
if args["keyword"]:
keyword_filter = f"%{args['keyword']}%"
query = (
query.join(
Message,
@ -478,12 +514,12 @@ class ChatConversationApi(Resource):
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
if start_datetime_utc:
match args.sort_by:
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
@ -491,27 +527,35 @@ class ChatConversationApi(Resource):
if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59)
match args.sort_by:
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args.annotation_status == "annotated":
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
elif args["annotation_status"] == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args.sort_by:
match args["sort_by"]:
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
@ -523,7 +567,7 @@ class ChatConversationApi(Resource):
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations

View File

@ -1,6 +1,4 @@
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -16,18 +14,6 @@ from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
@ -47,7 +33,11 @@ class ConversationVariablesApi(Resource):
@console_ns.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.expect(
console_ns.parser().add_argument(
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
)
)
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@setup_required
@login_required
@ -55,14 +45,18 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
args = parser.parse_args()
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
if args["conversation_id"]:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
else:
raise ValueError("conversation_id is required")
# NOTE: This is a temporary solution to avoid performance issues.
page = 1

View File

@ -1,8 +1,6 @@
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns
from controllers.console.app.error import (
@ -23,54 +21,21 @@ from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(RuleGeneratePayload)
reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
@console_ns.route("/rule-generate")
class RuleGenerateApi(Resource):
@console_ns.doc("generate_rule_config")
@console_ns.doc(description="Generate rule configuration using LLM")
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
@console_ns.expect(
console_ns.model(
"RuleGenerateRequest",
{
"instruction": fields.String(required=True, description="Rule generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
},
)
)
@console_ns.response(200, "Rule configuration generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -78,15 +43,21 @@ class RuleGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = RuleGeneratePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=args["no_variable"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -104,7 +75,19 @@ class RuleGenerateApi(Resource):
class RuleCodeGenerateApi(Resource):
@console_ns.doc("generate_rule_code")
@console_ns.doc(description="Generate code rules using LLM")
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
@console_ns.expect(
console_ns.model(
"RuleCodeGenerateRequest",
{
"instruction": fields.String(required=True, description="Code generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
"code_language": fields.String(
default="javascript", description="Programming language for code generation"
),
},
)
)
@console_ns.response(200, "Code rules generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -112,15 +95,22 @@ class RuleCodeGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["code_language"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -138,7 +128,15 @@ class RuleCodeGenerateApi(Resource):
class RuleStructuredOutputGenerateApi(Resource):
@console_ns.doc("generate_structured_output")
@console_ns.doc(description="Generate structured output rules using LLM")
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
@console_ns.expect(
console_ns.model(
"StructuredOutputGenerateRequest",
{
"instruction": fields.String(required=True, description="Structured output generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
},
)
)
@console_ns.response(200, "Structured output generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@ -146,14 +144,19 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -171,7 +174,20 @@ class RuleStructuredOutputGenerateApi(Resource):
class InstructionGenerateApi(Resource):
@console_ns.doc("generate_instruction")
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
@console_ns.expect(
console_ns.model(
"InstructionGenerateRequest",
{
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
"node_id": fields.String(description="Node ID for workflow context"),
"current": fields.String(description="Current instruction text"),
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
"instruction": fields.String(required=True, description="Instruction for generation"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.response(200, "Instruction generated successfully")
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@ -179,69 +195,79 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("flow_id", type=str, required=True, default="", location="json")
.add_argument("node_id", type=str, required=False, default="", location="json")
.add_argument("current", type=str, required=False, default="", location="json")
.add_argument("language", type=str, required=False, default="javascript", location="json")
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None
(p for p in providers if p.is_accept_language(args["language"])), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app:
return {"error": f"app {args.flow_id} not found"}, 400
return {"error": f"app {args['flow_id']} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"error": f"workflow {args.flow_id} not found"}, 400
return {"error": f"workflow {args['flow_id']} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args.node_id]
node = [node for node in nodes if node["id"] == args["node_id"]]
if len(node) == 0:
return {"error": f"node {args.node_id} not found"}, 400
return {"error": f"node {args['node_id']} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
instruction=args["instruction"],
model_config=args["model_config"],
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
instruction=args["instruction"],
model_config=args["model_config"],
code_language=args["language"],
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args.node_id == "" and args.current != "": # For legacy app without a workflow
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
flow_id=args["flow_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
)
if args.node_id != "" and args.current != "": # For workflow node
if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
node_id=args.node_id,
current=args.current,
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
flow_id=args["flow_id"],
node_id=args["node_id"],
current=args["current"],
instruction=args["instruction"],
model_config=args["model_config"],
ideal_output=args["ideal_output"],
workflow_service=WorkflowService(),
)
return {"error": "incompatible parameters"}, 400
@ -259,15 +285,24 @@ class InstructionGenerateApi(Resource):
class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template")
@console_ns.doc(description="Get instruction generation template")
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
@console_ns.expect(
console_ns.model(
"InstructionTemplateRequest",
{
"instruction": fields.String(required=True, description="Template instruction"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@console_ns.response(200, "Template retrieved successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
args = InstructionTemplatePayload.model_validate(console_ns.payload)
match args.type:
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()
match args["type"]:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
@ -277,4 +312,4 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args.type}")
raise ValueError(f"Invalid type: {args['type']}")

View File

@ -1,8 +1,7 @@
import json
from enum import StrEnum
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -13,8 +12,6 @@ from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
@ -24,22 +21,6 @@ class AppMCPServerStatus(StrEnum):
INACTIVE = "inactive"
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@ -58,7 +39,15 @@ class AppMCPServerController(Resource):
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.expect(
console_ns.model(
"MCPServerCreateRequest",
{
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
},
)
)
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@ -69,16 +58,21 @@ class AppMCPServerController(Resource):
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
)
args = parser.parse_args()
description = payload.description
description = args.get("description")
if not description:
description = app_model.description or ""
server = AppMCPServer(
name=app_model.name,
description=description,
parameters=json.dumps(payload.parameters, ensure_ascii=False),
parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_tenant_id,
@ -91,7 +85,17 @@ class AppMCPServerController(Resource):
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.expect(
console_ns.model(
"MCPServerUpdateRequest",
{
"id": fields.String(required=True, description="Server ID"),
"description": fields.String(description="Server description"),
"parameters": fields.Raw(required=True, description="Server parameters configuration"),
"status": fields.String(description="Server status"),
},
)
)
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@ -102,12 +106,19 @@ class AppMCPServerController(Resource):
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
parser = (
reqparse.RequestParser()
.add_argument("id", type=str, required=True, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("parameters", type=dict, required=True, location="json")
.add_argument("status", type=str, required=False, location="json")
)
args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server:
raise NotFound()
description = payload.description
description = args.get("description")
if description is None:
pass
elif not description:
@ -115,11 +126,11 @@ class AppMCPServerController(Resource):
else:
server.description = description
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
if payload.status not in [status.value for status in AppMCPServerStatus]:
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
if args["status"]:
if args["status"] not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
server.status = payload.status
server.status = args["status"]
db.session.commit()
return server

View File

@ -1,9 +1,7 @@
import logging
from typing import Literal
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
@ -35,68 +33,6 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("first_id", mode="before")
@classmethod
def empty_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
class FeedbackExportQuery(BaseModel):
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
@field_validator("has_comment", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool | None:
if isinstance(value, bool) or value is None:
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -221,7 +157,12 @@ class ChatMessageListApi(Resource):
@console_ns.doc("list_chat_messages")
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found")
@login_required
@ -231,21 +172,27 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
conversation = (
db.session.query(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args.first_id:
if args["first_id"]:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)
@ -260,7 +207,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args.limit)
.limit(args["limit"])
.all()
)
else:
@ -268,12 +215,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args.limit)
.limit(args["limit"])
.all()
)
# Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit:
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
@ -291,7 +238,7 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@ -299,7 +246,15 @@ class MessageFeedbackApi(Resource):
@console_ns.doc("create_message_feedback")
@console_ns.doc(description="Create or update message feedback (like/dislike)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.expect(
console_ns.model(
"MessageFeedbackRequest",
{
"message_id": fields.String(required=True, description="Message ID"),
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
},
)
)
@console_ns.response(200, "Feedback updated successfully")
@console_ns.response(404, "Message not found")
@console_ns.response(403, "Insufficient permissions")
@ -310,9 +265,14 @@ class MessageFeedbackApi(Resource):
def post(self, app_model):
current_user, _ = current_account_with_tenant()
args = MessageFeedbackPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args()
message_id = str(args.message_id)
message_id = str(args["message_id"])
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@ -321,23 +281,18 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback
if not args.rating and feedback:
if not args["rating"] and feedback:
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.content = args.content
elif not args.rating and not feedback:
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
rating_value = args.rating
if rating_value is None:
raise ValueError("rating is required to create feedback")
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
content=args.content,
rating=args["rating"],
from_source="admin",
from_account_id=current_user.id,
)
@ -414,12 +369,24 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
# Shared parser for feedback export (used for both documentation and runtime parsing)
feedback_export_parser = (
console_ns.parser()
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
)
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.expect(feedback_export_parser)
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@ -428,7 +395,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = feedback_export_parser.parse_args()
# Import the service function
from services.feedback_service import FeedbackService
@ -436,12 +403,12 @@ class MessageFeedbackExportApi(Resource):
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
from_source=args.from_source,
rating=args.rating,
has_comment=args.has_comment,
start_date=args.start_date,
end_date=args.end_date,
format_type=args.format,
from_source=args.get("from_source"),
rating=args.get("rating"),
has_comment=args.get("has_comment"),
start_date=args.get("start_date"),
end_date=args.get("end_date"),
format_type=args.get("format", "csv"),
)
return export_data

View File

@ -1,8 +1,4 @@
from typing import Any
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -11,26 +7,6 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required
from services.ops_service import OpsService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TraceProviderQuery(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
class TraceConfigPayload(BaseModel):
tracing_provider: str = Field(..., description="Tracing provider name")
tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
console_ns.schema_model(
TraceProviderQuery.__name__,
TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource):
@ -41,7 +17,11 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("get_trace_app_config")
@console_ns.doc(description="Get tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.expect(
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
)
@ -50,10 +30,11 @@ class TraceAppConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
if not trace_config:
return {"has_not_configured": True}
return trace_config
@ -63,7 +44,15 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("create_trace_app_config")
@console_ns.doc(description="Create a new tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.expect(
console_ns.model(
"TraceConfigCreateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
},
)
)
@console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
)
@ -73,11 +62,16 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def post(self, app_id):
"""Create a new trace app configuration"""
args = TraceConfigPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
)
if not result:
raise TracingConfigIsExist()
@ -90,7 +84,15 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("update_trace_app_config")
@console_ns.doc(description="Update an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.expect(
console_ns.model(
"TraceConfigUpdateRequest",
{
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
},
)
)
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@ -98,11 +100,16 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def patch(self, app_id):
"""Update an existing trace app configuration"""
args = TraceConfigPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("tracing_provider", type=str, required=True, location="json")
.add_argument("tracing_config", type=dict, required=True, location="json")
)
args = parser.parse_args()
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
)
if not result:
raise TracingConfigNotExist()
@ -113,7 +120,11 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("delete_trace_app_config")
@console_ns.doc(description="Delete an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.expect(
console_ns.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
)
)
@console_ns.response(204, "Tracing configuration deleted successfully")
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@ -121,10 +132,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def delete(self, app_id):
"""Delete an existing trace app configuration"""
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
args = parser.parse_args()
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204

View File

@ -1,7 +1,4 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
@ -19,50 +16,69 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
icon_type: str | None = Field(default=None)
icon: str | None = Field(default=None)
icon_background: str | None = Field(default=None)
description: str | None = Field(default=None)
default_language: str | None = Field(default=None)
chat_color_theme: str | None = Field(default=None)
chat_color_theme_inverted: bool | None = Field(default=None)
customize_domain: str | None = Field(default=None)
copyright: str | None = Field(default=None)
privacy_policy: str | None = Field(default=None)
custom_disclaimer: str | None = Field(default=None)
customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
prompt_public: bool | None = Field(default=None)
show_workflow_steps: bool | None = Field(default=None)
use_icon_as_answer_icon: bool | None = Field(default=None)
@field_validator("default_language")
@classmethod
def validate_language(cls, value: str | None) -> str | None:
if value is None:
return value
return supported_language(value)
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
def parse_app_site_args():
parser = (
reqparse.RequestParser()
.add_argument("title", type=str, required=False, location="json")
.add_argument("icon_type", type=str, required=False, location="json")
.add_argument("icon", type=str, required=False, location="json")
.add_argument("icon_background", type=str, required=False, location="json")
.add_argument("description", type=str, required=False, location="json")
.add_argument("default_language", type=supported_language, required=False, location="json")
.add_argument("chat_color_theme", type=str, required=False, location="json")
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
.add_argument("customize_domain", type=str, required=False, location="json")
.add_argument("copyright", type=str, required=False, location="json")
.add_argument("privacy_policy", type=str, required=False, location="json")
.add_argument("custom_disclaimer", type=str, required=False, location="json")
.add_argument(
"customize_token_strategy",
type=str,
choices=["must", "allow", "not_allow"],
required=False,
location="json",
)
.add_argument("prompt_public", type=bool, required=False, location="json")
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource):
@console_ns.doc("update_app_site")
@console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
@console_ns.expect(
console_ns.model(
"AppSiteRequest",
{
"title": fields.String(description="Site title"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"description": fields.String(description="Site description"),
"default_language": fields.String(description="Default language"),
"chat_color_theme": fields.String(description="Chat color theme"),
"chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
"customize_domain": fields.String(description="Custom domain"),
"copyright": fields.String(description="Copyright text"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"customize_token_strategy": fields.String(
enum=["must", "allow", "not_allow"], description="Token strategy"
),
"prompt_public": fields.Boolean(description="Make prompt public"),
"show_workflow_steps": fields.Boolean(description="Show workflow steps"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
},
)
)
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found")
@ -73,7 +89,7 @@ class AppSite(Resource):
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
args = parse_app_site_args()
current_user, _ = current_account_with_tenant()
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
@ -97,7 +113,7 @@ class AppSite(Resource):
"show_workflow_steps",
"use_icon_as_answer_icon",
]:
value = getattr(args, attr_name)
value = args.get(attr_name)
if value is not None:
setattr(site, attr_name, value)

View File

@ -1,9 +1,8 @@
from decimal import Decimal
import sqlalchemy as sa
from flask import abort, jsonify, request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -11,37 +10,21 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date
from libs.helper import DatetimeString, convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required
from models import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def empty_string_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
StatisticTimeRangeQuery.__name__,
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource):
@console_ns.doc("get_daily_message_statistics")
@console_ns.doc(description="Get daily message statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(
console_ns.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.response(
200,
"Daily message statistics retrieved successfully",
@ -54,7 +37,12 @@ class DailyMessageStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -69,7 +57,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -93,12 +81,19 @@ WHERE
return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource):
@console_ns.doc("get_daily_conversation_statistics")
@console_ns.doc(description="Get daily conversation statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Daily conversation statistics retrieved successfully",
@ -111,7 +106,7 @@ class DailyConversationStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -126,7 +121,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -154,7 +149,7 @@ class DailyTerminalsStatistic(Resource):
@console_ns.doc("get_daily_terminals_statistics")
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Daily terminal statistics retrieved successfully",
@ -167,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -182,7 +177,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -211,7 +206,7 @@ class DailyTokenCostStatistic(Resource):
@console_ns.doc("get_daily_token_cost_statistics")
@console_ns.doc(description="Get daily token cost statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Daily token cost statistics retrieved successfully",
@ -224,7 +219,7 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -240,7 +235,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -271,7 +266,7 @@ class AverageSessionInteractionStatistic(Resource):
@console_ns.doc("get_average_session_interaction_statistics")
@console_ns.doc(description="Get average session interaction statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Average session interaction statistics retrieved successfully",
@ -284,7 +279,7 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT
@ -307,7 +302,7 @@ FROM
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -347,7 +342,7 @@ class UserSatisfactionRateStatistic(Resource):
@console_ns.doc("get_user_satisfaction_rate_statistics")
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"User satisfaction rate statistics retrieved successfully",
@ -360,7 +355,7 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT
@ -379,7 +374,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -413,7 +408,7 @@ class AverageResponseTimeStatistic(Resource):
@console_ns.doc("get_average_response_time_statistics")
@console_ns.doc(description="Get average response time statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Average response time statistics retrieved successfully",
@ -426,7 +421,7 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -441,7 +436,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -470,7 +465,7 @@ class TokensPerSecondStatistic(Resource):
@console_ns.doc("get_tokens_per_second_statistics")
@console_ns.doc(description="Get tokens per second statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.expect(parser)
@console_ns.response(
200,
"Tokens per second statistics retrieved successfully",
@ -482,7 +477,7 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required
def get(self, app_model):
account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
@ -500,7 +495,7 @@ WHERE
assert account.timezone is not None
try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))

View File

@ -1,11 +1,10 @@
import json
import logging
from collections.abc import Sequence
from typing import Any
from typing import cast
from flask import abort, request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, inputs, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@ -50,7 +49,6 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@ -109,104 +107,6 @@ if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
features: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
class BaseWorkflowRunPayload(BaseModel):
files: list[dict[str, Any]] | None = None
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any] | None = None
query: str = ""
conversation_id: str | None = None
parent_message_id: str | None = None
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class IterationNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class LoopNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
query: str = ""
class PublishWorkflowPayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(SyncDraftWorkflowPayload)
reg(AdvancedChatWorkflowRunPayload)
reg(IterationNodeRunPayload)
reg(LoopNodeRunPayload)
reg(DraftWorkflowRunPayload)
reg(DraftWorkflowNodeRunPayload)
reg(PublishWorkflowPayload)
reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable.
@ -258,7 +158,18 @@ class DraftWorkflowApi(Resource):
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@console_ns.doc("sync_draft_workflow")
@console_ns.doc(description="Sync draft workflow configuration")
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
@console_ns.expect(
console_ns.model(
"SyncDraftWorkflowRequest",
{
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
"features": fields.Raw(required=True, description="Workflow features configuration"),
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.response(
200,
"Draft workflow synced successfully",
@ -282,23 +193,36 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "")
payload_data: dict[str, Any] | None = None
if "application/json" in content_type:
payload_data = request.get_json(silent=True)
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
parser = (
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("features", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type:
try:
payload_data = json.loads(request.data.decode("utf-8"))
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
raise ValueError("graph or features is not a dict")
args = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService()
try:
@ -334,7 +258,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.doc("run_advanced_chat_draft_workflow")
@console_ns.doc(description="Run draft workflow for advanced chat application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"AdvancedChatWorkflowRunRequest",
{
"query": fields.String(required=True, description="User query"),
"inputs": fields.Raw(description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
"conversation_id": fields.String(description="Conversation ID"),
},
)
)
@console_ns.response(200, "Workflow run started successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(403, "Permission denied")
@ -349,8 +283,16 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
.add_argument("query", type=str, required=True, location="json", default="")
.add_argument("files", type=list, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@ -380,7 +322,15 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"IterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Iteration node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -394,7 +344,8 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
@ -418,7 +369,15 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_workflow_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"WorkflowIterationNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Workflow iteration node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -432,7 +391,8 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
@ -456,7 +416,15 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"LoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Loop node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -470,7 +438,8 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_loop(
@ -494,7 +463,15 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_workflow_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"WorkflowLoopNodeRunRequest",
{
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Workflow loop node run started successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -508,7 +485,8 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_loop(
@ -532,7 +510,15 @@ class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow")
@console_ns.doc(description="Run draft workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"DraftWorkflowRunRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
},
)
)
@console_ns.response(200, "Draft workflow run started successfully")
@console_ns.response(403, "Permission denied")
@setup_required
@ -545,7 +531,12 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
@ -597,7 +588,14 @@ class DraftWorkflowNodeRunApi(Resource):
@console_ns.doc("run_draft_workflow_node")
@console_ns.doc(description="Run draft workflow node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
@console_ns.expect(
console_ns.model(
"DraftWorkflowNodeRunRequest",
{
"inputs": fields.Raw(description="Input variables"),
},
)
)
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Node not found")
@ -612,10 +610,15 @@ class DraftWorkflowNodeRunApi(Resource):
Run draft workflow node
"""
current_user, _ = current_account_with_tenant()
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("query", type=str, required=False, location="json", default="")
.add_argument("files", type=list, location="json", default=[])
)
args = parser.parse_args()
user_inputs = args_model.inputs
user_inputs = args.get("inputs")
if user_inputs is None:
raise ValueError("missing inputs")
@ -640,6 +643,13 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource):
@console_ns.doc("get_published_workflow")
@ -664,7 +674,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None
return workflow
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
@console_ns.expect(parser_publish)
@setup_required
@login_required
@account_initialization_required
@ -676,7 +686,13 @@ class PublishedWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
args = parser_publish.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
workflow_service = WorkflowService()
with Session(db.engine) as session:
@ -725,6 +741,9 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource):
@console_ns.doc("get_default_block_config")
@ -732,7 +751,7 @@ class DefaultBlockConfigApi(Resource):
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@console_ns.response(200, "Default block configuration retrieved successfully")
@console_ns.response(404, "Block type not found")
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
@console_ns.expect(parser_block)
@setup_required
@login_required
@account_initialization_required
@ -742,12 +761,14 @@ class DefaultBlockConfigApi(Resource):
"""
Get default block config
"""
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser_block.parse_args()
q = args.get("q")
filters = None
if args.q:
if q:
try:
filters = json.loads(args.q)
filters = json.loads(args.get("q", ""))
except json.JSONDecodeError:
raise ValueError("Invalid filters")
@ -756,9 +777,18 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource):
@console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
@console_ns.expect(parser_convert)
@console_ns.doc("convert_to_workflow")
@console_ns.doc(description="Convert application to workflow mode")
@console_ns.doc(params={"app_id": "Application ID"})
@ -778,8 +808,10 @@ class ConvertToWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
if request.data:
args = parser_convert.parse_args()
else:
args = {}
# convert to workflow mode
workflow_service = WorkflowService()
@ -791,9 +823,18 @@ class ConvertToWorkflowApi(Resource):
}
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@console_ns.expect(parser_workflows)
@console_ns.doc("get_all_published_workflows")
@console_ns.doc(description="Get all published workflows for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@ -810,15 +851,16 @@ class PublishedAllWorkflowApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
page = args.page
limit = args.limit
user_id = args.user_id
named_only = args.named_only
args = parser_workflows.parse_args()
page = args["page"]
limit = args["limit"]
user_id = args.get("user_id")
named_only = args.get("named_only", False)
if user_id:
if user_id != current_user.id:
raise Forbidden()
user_id = cast(str, user_id)
workflow_service = WorkflowService()
with Session(db.engine) as session:
@ -844,7 +886,15 @@ class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id")
@console_ns.doc(description="Update workflow by ID")
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
@console_ns.expect(
console_ns.model(
"UpdateWorkflowRequest",
{
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@console_ns.response(200, "Workflow updated successfully", workflow_model)
@console_ns.response(404, "Workflow not found")
@console_ns.response(403, "Permission denied")
@ -859,14 +909,25 @@ class WorkflowByIdApi(Resource):
Update workflow attributes
"""
current_user, _ = current_account_with_tenant()
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
# Prepare update data
update_data = {}
if args.marked_name is not None:
update_data["marked_name"] = args.marked_name
if args.marked_comment is not None:
update_data["marked_comment"] = args.marked_comment
if args.get("marked_name") is not None:
update_data["marked_name"] = args["marked_name"]
if args.get("marked_comment") is not None:
update_data["marked_comment"] = args["marked_comment"]
if not update_data:
return {"message": "No valid fields to update"}, 400
@ -979,8 +1040,11 @@ class DraftWorkflowTriggerRunApi(Resource):
Poll for trigger events and execute full workflow when event arrives
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
node_id = args.node_id
parser = reqparse.RequestParser().add_argument(
"node_id", type=str, required=True, location="json", nullable=False
)
args = parser.parse_args()
node_id = args["node_id"]
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:
@ -1108,7 +1172,14 @@ class DraftWorkflowTriggerRunAllApi(Resource):
@console_ns.doc("draft_workflow_trigger_run_all")
@console_ns.doc(description="Full workflow debug when the start node is a trigger")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
@console_ns.expect(
console_ns.model(
"DraftWorkflowTriggerRunAllRequest",
{
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
},
)
)
@console_ns.response(200, "Workflow executed successfully")
@console_ns.response(403, "Permission denied")
@console_ns.response(500, "Internal server error")
@ -1123,8 +1194,11 @@ class DraftWorkflowTriggerRunAllApi(Resource):
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
node_ids = args.node_ids
parser = reqparse.RequestParser().add_argument(
"node_ids", type=list, required=True, location="json", nullable=False
)
args = parser.parse_args()
node_ids = args["node_ids"]
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow:

View File

@ -1,9 +1,6 @@
from datetime import datetime
from dateutil.parser import isoparse
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -17,48 +14,6 @@ from models import App
from models.model import AppMode
from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
status: WorkflowExecutionStatus | None = Field(
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
)
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
created_by_account: str | None = Field(default=None, description="Filter by account")
detail: bool = Field(default=False, description="Whether to return detailed logs")
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
@field_validator("created_at__before", "created_at__after", mode="before")
@classmethod
def parse_datetime(cls, value: str | None) -> datetime | None:
if value in (None, ""):
return None
return isoparse(value) # type: ignore
@field_validator("detail", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
lowered = value.lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
@ -68,7 +23,19 @@ class WorkflowAppLogApi(Resource):
@console_ns.doc("get_workflow_app_logs")
@console_ns.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
@console_ns.doc(
params={
"keyword": "Search keyword for filtering logs",
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
"created_at__before": "Filter logs created before this timestamp",
"created_at__after": "Filter logs created after this timestamp",
"created_by_end_user_session_id": "Filter by end user session ID",
"created_by_account": "Filter by account",
"detail": "Whether to return detailed logs",
"page": "Page number (1-99999)",
"limit": "Number of items per page (1-100)",
}
)
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
@setup_required
@login_required
@ -79,7 +46,44 @@ class WorkflowAppLogApi(Resource):
"""
Get workflow app logs
"""
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("detail", type=bool, location="args", required=False, default=False)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()

View File

@ -1,11 +1,10 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, NoReturn, ParamSpec, TypeVar
from typing import NoReturn, ParamSpec, TypeVar
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from controllers.console import console_ns
@ -30,27 +29,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowDraftVariableListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=100_000, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Items per page")
class WorkflowDraftVariableUpdatePayload(BaseModel):
name: str | None = Field(default=None, description="Variable name")
value: Any | None = Field(default=None, description="Variable value")
console_ns.schema_model(
WorkflowDraftVariableListQuery.__name__,
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
WorkflowDraftVariableUpdatePayload.__name__,
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _convert_values_to_json_serializable_object(value: Segment):
@ -79,6 +57,22 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = (
reqparse.RequestParser()
.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
@ -207,7 +201,7 @@ def _api_prerequisite(f: Callable[P, R]):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource):
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
@console_ns.expect(_create_pagination_parser())
@console_ns.doc("get_workflow_variables")
@console_ns.doc(description="Get draft workflow variables")
@console_ns.doc(params={"app_id": "Application ID"})
@ -221,7 +215,8 @@ class WorkflowVariableCollectionApi(Resource):
"""
Get draft workflow
"""
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = _create_pagination_parser()
args = parser.parse_args()
# fetch draft workflow by app_model
workflow_service = WorkflowService()
@ -328,7 +323,15 @@ class VariableApi(Resource):
@console_ns.doc("update_variable")
@console_ns.doc(description="Update a workflow variable")
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
@console_ns.expect(
console_ns.model(
"UpdateVariableRequest",
{
"name": fields.String(description="Variable name"),
"value": fields.Raw(description="Variable value"),
},
)
)
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@ -355,10 +358,16 @@ class VariableApi(Resource):
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = (
reqparse.RequestParser()
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
)
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args(strict=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
@ -366,8 +375,8 @@ class VariableApi(Resource):
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args_model.name
raw_value = args_model.value
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
if new_name is None and raw_value is None:
return variable

View File

@ -1,8 +1,7 @@
from typing import Literal, cast
from typing import cast
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -93,51 +92,70 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
class WorkflowRunListQuery(BaseModel):
last_id: str | None = Field(default=None, description="Last run ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
return parser.parse_args()
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
)
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class WorkflowRunCountQuery(BaseModel):
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
@field_validator("time_range")
@classmethod
def validate_time_range(cls, value: str | None) -> str | None:
if value is None:
return value
return time_duration(value)
console_ns.schema_model(
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
WorkflowRunCountQuery.__name__,
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
return parser.parse_args()
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
@ -152,7 +170,6 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
@setup_required
@login_required
@ -163,13 +180,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
"""
Get advanced chat app workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args = _parse_workflow_run_list_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -201,7 +217,6 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -211,13 +226,12 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
"""
Get advanced chat workflow runs count statistics
"""
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args = _parse_workflow_run_count_args()
# Default to DEBUGGING if not specified
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -245,7 +259,6 @@ class WorkflowRunListApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -255,13 +268,12 @@ class WorkflowRunListApi(Resource):
"""
Get workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args = _parse_workflow_run_list_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)
@ -293,7 +305,6 @@ class WorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required
@login_required
@account_initialization_required
@ -303,13 +314,12 @@ class WorkflowRunCountApi(Resource):
"""
Get workflow runs count statistics
"""
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
args = _parse_workflow_run_count_args()
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from)
if args_model.triggered_from
WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING
)

View File

@ -1,6 +1,5 @@
from flask import abort, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker
from controllers.console import console_ns
@ -8,31 +7,12 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowStatisticQuery(BaseModel):
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
WorkflowStatisticQuery.__name__,
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
@ -44,7 +24,9 @@ class WorkflowDailyRunsStatistic(Resource):
@console_ns.doc("get_workflow_daily_runs_statistic")
@console_ns.doc(description="Get workflow daily runs statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily runs statistics retrieved successfully")
@get_app_model
@setup_required
@ -53,12 +35,17 @@ class WorkflowDailyRunsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -84,7 +71,9 @@ class WorkflowDailyTerminalsStatistic(Resource):
@console_ns.doc("get_workflow_daily_terminals_statistic")
@console_ns.doc(description="Get workflow daily terminals statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
@get_app_model
@setup_required
@ -93,12 +82,17 @@ class WorkflowDailyTerminalsStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -124,7 +118,9 @@ class WorkflowDailyTokenCostStatistic(Resource):
@console_ns.doc("get_workflow_daily_token_cost_statistic")
@console_ns.doc(description="Get workflow daily token cost statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
@get_app_model
@setup_required
@ -133,12 +129,17 @@ class WorkflowDailyTokenCostStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))
@ -164,7 +165,9 @@ class WorkflowAverageAppInteractionStatistic(Resource):
@console_ns.doc("get_workflow_average_app_interaction_statistic")
@console_ns.doc(description="Get workflow average app interaction statistics")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
@console_ns.doc(
params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
)
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
@setup_required
@login_required
@ -173,12 +176,17 @@ class WorkflowAverageAppInteractionStatistic(Resource):
def get(self, app_model):
account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None
try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e:
abort(400, description=str(e))

View File

@ -1,8 +1,6 @@
import logging
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -20,30 +18,16 @@ from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class Parser(BaseModel):
node_id: str
class ParserEnable(BaseModel):
trigger_id: str
enable_trigger: bool
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, help="Node ID is required")
@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
class WebhookTriggerApi(Resource):
"""Webhook Trigger API"""
@console_ns.expect(console_ns.models[Parser.__name__])
@console_ns.expect(parser)
@setup_required
@login_required
@account_initialization_required
@ -51,9 +35,9 @@ class WebhookTriggerApi(Resource):
@marshal_with(webhook_trigger_fields)
def get(self, app_model: App):
"""Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = parser.parse_args()
node_id = args.node_id
node_id = str(args["node_id"])
with Session(db.engine) as session:
# Get webhook trigger for this app and node
@ -112,9 +96,16 @@ class AppTriggersApi(Resource):
return {"data": triggers}
parser_enable = (
reqparse.RequestParser()
.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__])
@console_ns.expect(parser_enable)
@setup_required
@login_required
@account_initialization_required
@ -123,11 +114,12 @@ class AppTriggerEnableApi(Resource):
@marshal_with(trigger_fields)
def post(self, app_model: App):
"""Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload)
args = parser_enable.parse_args()
assert current_user.current_tenant_id is not None
trigger_id = args.trigger_id
trigger_id = args["trigger_id"]
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
@ -142,7 +134,7 @@ class AppTriggerEnableApi(Resource):
raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)

View File

@ -1,53 +1,28 @@
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, reqparse
from constants.languages import supported_language
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, timezone
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models import AccountStatus
from services.account_service import RegisterService
from services.account_service import AccountService, RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
class ActivatePayload(BaseModel):
workspace_id: str | None = Field(default=None)
email: EmailStr | None = Field(default=None)
token: str
name: str = Field(..., max_length=30)
interface_language: str = Field(...)
timezone: str = Field(...)
@field_validator("interface_language")
@classmethod
def validate_lang(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_tz(cls, value: str) -> str:
return timezone(value)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
active_check_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
)
@console_ns.route("/activate/check")
class ActivateCheckApi(Resource):
@console_ns.doc("check_activation_token")
@console_ns.doc(description="Check if activation token is valid")
@console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
@console_ns.expect(active_check_parser)
@console_ns.response(
200,
"Success",
@ -60,11 +35,11 @@ class ActivateCheckApi(Resource):
),
)
def get(self):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = active_check_parser.parse_args()
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
workspaceId = args["workspace_id"]
reg_email = args["email"]
token = args["token"]
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
@ -81,11 +56,22 @@ class ActivateCheckApi(Resource):
return {"is_valid": False}
active_parser = (
reqparse.RequestParser()
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
.add_argument("email", type=email, required=False, nullable=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
)
@console_ns.route("/activate")
class ActivateApi(Resource):
@console_ns.doc("activate_account")
@console_ns.doc(description="Activate account with invitation token")
@console_ns.expect(console_ns.models[ActivatePayload.__name__])
@console_ns.expect(active_parser)
@console_ns.response(
200,
"Account activated successfully",
@ -93,27 +79,30 @@ class ActivateApi(Resource):
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.Raw(description="Login token data"),
},
),
)
@console_ns.response(400, "Already activated or invalid token")
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
args = active_parser.parse_args()
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
account = invitation["account"]
account.name = args.name
account.name = args["name"]
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_language = args["interface_language"]
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
db.session.commit()
return {"result": "success"}
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()}

View File

@ -1,26 +1,12 @@
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, reqparse
from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ApiKeyAuthBindingPayload(BaseModel):
category: str = Field(...)
provider: str = Field(...)
credentials: dict = Field(...)
console_ns.schema_model(
ApiKeyAuthBindingPayload.__name__,
ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
from ..wraps import account_initialization_required, setup_required
@console_ns.route("/api-key-auth/data-source")
@ -54,15 +40,19 @@ class ApiKeyAuthDataSourceBinding(Resource):
@login_required
@account_initialization_required
@is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
def post(self):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
data = payload.model_dump()
ApiKeyAuthService.validate_api_key_auth_args(data)
parser = (
reqparse.RequestParser()
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("provider", type=str, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args)
try:
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200

View File

@ -5,11 +5,12 @@ from flask import current_app, redirect, request
from flask_restx import Resource, fields
from configs import dify_config
from controllers.console import console_ns
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from ..wraps import account_initialization_required, setup_required
logger = logging.getLogger(__name__)

View File

@ -1,6 +1,5 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -15,45 +14,16 @@ from controllers.console.auth.error import (
InvalidTokenError,
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
from models import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from ..error import AccountInFreezeError, EmailSendIpLimitError
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EmailRegisterSendPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
language: str | None = Field(default=None, description="Language code")
class EmailRegisterValidityPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class EmailRegisterResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/email-register/send-email")
class EmailRegisterSendEmailApi(Resource):
@ -61,22 +31,27 @@ class EmailRegisterSendEmailApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
language = "en-US"
if args.language in languages:
language = args.language
if args["language"] in languages:
language = args["language"]
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
return {"result": "success", "data": token}
@ -86,34 +61,40 @@ class EmailRegisterCheckApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args.email
user_email = args["email"]
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
token_data = AccountService.get_email_register_data(args.token)
token_data = AccountService.get_email_register_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args.email)
if args["code"] != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args["email"])
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_email_register_token(args.token)
AccountService.revoke_email_register_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_email_register_token(
user_email, code=args.code, additional_data={"phase": "register"}
user_email, code=args["code"], additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(args.email)
AccountService.reset_email_register_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -123,14 +104,20 @@ class EmailRegisterResetApi(Resource):
@email_password_login_enabled
@email_register_enabled
def post(self):
args = EmailRegisterResetPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match
if args.new_password != args.password_confirm:
if args["new_password"] != args["password_confirm"]:
raise PasswordMismatchError()
# Validate token and get register data
register_data = AccountService.get_email_register_data(args.token)
register_data = AccountService.get_email_register_data(args["token"])
if not register_data:
raise InvalidTokenError()
# Must use token in reset phase
@ -138,7 +125,7 @@ class EmailRegisterResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_email_register_token(args.token)
AccountService.revoke_email_register_token(args["token"])
email = register_data.get("email", "")
@ -148,7 +135,7 @@ class EmailRegisterResetApi(Resource):
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(email, args.password_confirm)
account = self._create_new_account(email, args["password_confirm"])
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))

View File

@ -2,8 +2,7 @@ import base64
import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -19,46 +18,26 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource):
@console_ns.doc("send_forgot_password_email")
@console_ns.doc(description="Send password reset email")
@console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
@console_ns.expect(
console_ns.model(
"ForgotPasswordEmailRequest",
{
"email": fields.String(required=True, description="Email address"),
"language": fields.String(description="Language for email (zh-Hans/en-US)"),
},
)
)
@console_ns.response(
200,
"Email sent successfully",
@ -75,23 +54,28 @@ class ForgotPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args.language is not None and args.language == "zh-Hans":
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = AccountService.send_reset_password_email(
account=account,
email=args.email,
email=args["email"],
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@ -103,7 +87,16 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
@console_ns.doc("check_forgot_password_code")
@console_ns.doc(description="Verify password reset code")
@console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
@console_ns.expect(
console_ns.model(
"ForgotPasswordCheckRequest",
{
"email": fields.String(required=True, description="Email address"),
"code": fields.String(required=True, description="Verification code"),
"token": fields.String(required=True, description="Reset token"),
},
)
)
@console_ns.response(
200,
"Code verified successfully",
@ -120,34 +113,40 @@ class ForgotPasswordCheckApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
user_email = args.email
user_email = args["email"]
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
token_data = AccountService.get_reset_password_data(args.token)
token_data = AccountService.get_reset_password_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args.email)
if args["code"] != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args.token)
AccountService.revoke_reset_password_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args.code, additional_data={"phase": "reset"}
user_email, code=args["code"], additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args.email)
AccountService.reset_forgot_password_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@ -155,7 +154,16 @@ class ForgotPasswordCheckApi(Resource):
class ForgotPasswordResetApi(Resource):
@console_ns.doc("reset_password")
@console_ns.doc(description="Reset password with verification token")
@console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
@console_ns.expect(
console_ns.model(
"ForgotPasswordResetRequest",
{
"token": fields.String(required=True, description="Verification token"),
"new_password": fields.String(required=True, description="New password"),
"password_confirm": fields.String(required=True, description="Password confirmation"),
},
)
)
@console_ns.response(
200,
"Password reset successfully",
@ -165,14 +173,20 @@ class ForgotPasswordResetApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
args = ForgotPasswordResetPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, nullable=False, location="json")
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
)
args = parser.parse_args()
# Validate passwords match
if args.new_password != args.password_confirm:
if args["new_password"] != args["password_confirm"]:
raise PasswordMismatchError()
# Validate token and get reset data
reset_data = AccountService.get_reset_password_data(args.token)
reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
raise InvalidTokenError()
# Must use token in reset phase
@ -180,11 +194,11 @@ class ForgotPasswordResetApi(Resource):
raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args.token)
AccountService.revoke_reset_password_token(args["token"])
# Generate secure salt and hash password
salt = secrets.token_bytes(16)
password_hashed = hash_password(args.new_password, salt)
password_hashed = hash_password(args["new_password"], salt)
email = reset_data.get("email", "")

View File

@ -1,7 +1,6 @@
import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, reqparse
import services
from configs import dify_config
@ -22,14 +21,9 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
WorkspacesLimitExceeded,
)
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
email_password_login_enabled,
setup_required,
)
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import email, extract_remote_ip
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
@ -46,36 +40,6 @@ from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
class EmailPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class EmailCodeLoginPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
language: str | None = Field(default=None)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(LoginPayload)
reg(EmailPayload)
reg(EmailCodeLoginPayload)
@console_ns.route("/login")
class LoginApi(Resource):
@ -83,37 +47,41 @@ class LoginApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[LoginPayload.__name__])
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("password", type=str, required=True, location="json")
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
.add_argument("invite_token", type=str, required=False, default=None, location="json")
)
args = parser.parse_args()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
# TODO: why invitation is re-assigned with different type?
invitation = args.invite_token # type: ignore
invitation = args["invite_token"]
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation) # type: ignore
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
try:
if invitation:
data = invitation.get("data", {}) # type: ignore
data = invitation.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args.email:
if invitee_email != args["email"]:
raise InvalidEmailError()
account = AccountService.authenticate(args.email, args.password, args.invite_token)
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
else:
account = AccountService.authenticate(args.email, args.password)
account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args.email)
AccountService.add_login_error_rate_limit(args["email"])
raise AuthenticationFailedError()
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@ -129,7 +97,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email)
AccountService.reset_login_error_rate_limit(args["email"])
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@ -166,21 +134,25 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
if args.language is not None and args.language == "zh-Hans":
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args.email)
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=args.email,
email=args["email"],
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@ -192,26 +164,30 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.route("/email-code-login")
class EmailCodeLoginSendEmailApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if args.language is not None and args.language == "zh-Hans":
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args.email)
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args.email, language=language)
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
else:
raise AccountNotFound()
else:
@ -223,25 +199,30 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.route("/email-code-login/validity")
class EmailCodeLoginApi(Resource):
@setup_required
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
@decrypt_code_field
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
parser = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
)
args = parser.parse_args()
user_email = args.email
language = args.language
user_email = args["email"]
language = args["language"]
token_data = AccountService.get_email_code_login_data(args.token)
token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args.email:
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args.code:
if token_data["code"] != args["code"]:
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args.token)
AccountService.revoke_email_code_login_token(args["token"])
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError:
@ -274,7 +255,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email)
AccountService.reset_login_error_rate_limit(args["email"])
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})

View File

@ -3,8 +3,7 @@ from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
@ -21,34 +20,15 @@ R = TypeVar("R")
T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
class OAuthProviderRequest(BaseModel):
client_id: str
redirect_uri: str
class OAuthTokenRequest(BaseModel):
client_id: str
grant_type: str
code: str | None = None
client_secret: str | None = None
redirect_uri: str | None = None
refresh_token: str | None = None
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
json_data = request.get_json()
if json_data is None:
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id")
if not client_id:
raise BadRequest("client_id is required")
payload = OAuthClientPayload.model_validate(json_data)
client_id = payload.client_id
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app:
raise NotFound("client_id is invalid")
@ -109,8 +89,9 @@ class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
payload = OAuthProviderRequest.model_validate(request.get_json())
redirect_uri = payload.redirect_uri
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri")
# check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris:
@ -149,25 +130,33 @@ class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
payload = OAuthTokenRequest.model_validate(request.get_json())
parser = (
reqparse.RequestParser()
.add_argument("grant_type", type=str, required=True, location="json")
.add_argument("code", type=str, required=False, location="json")
.add_argument("client_secret", type=str, required=False, location="json")
.add_argument("redirect_uri", type=str, required=False, location="json")
.add_argument("refresh_token", type=str, required=False, location="json")
)
parsed_args = parser.parse_args()
try:
grant_type = OAuthGrantType(payload.grant_type)
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
if not parsed_args["code"]:
raise BadRequest("code is required")
if payload.client_secret != oauth_provider_app.client_secret:
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
@ -178,11 +167,11 @@ class OAuthServerUserTokenApi(Resource):
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
if not parsed_args["refresh_token"]:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{

View File

@ -1,9 +1,6 @@
import base64
from typing import Literal
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@ -12,21 +9,6 @@ from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SubscriptionQuery(BaseModel):
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
interval: Literal["month", "year"] = Field(..., description="Billing interval")
class PartnerTenantsPayload(BaseModel):
click_id: str = Field(..., description="Click Id from partner referral link")
for model in (SubscriptionQuery, PartnerTenantsPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/billing/subscription")
class Subscription(Resource):
@ -36,9 +18,20 @@ class Subscription(Resource):
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = (
reqparse.RequestParser()
.add_argument(
"plan",
type=str,
required=True,
location="args",
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
)
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
)
args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
@console_ns.route("/billing/invoices")
@ -72,10 +65,11 @@ class PartnerTenants(Resource):
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
args = parser.parse_args()
try:
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id
click_id = args["click_id"]
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception:
raise BadRequest("Invalid partner_key")

View File

@ -1,6 +1,5 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
@ -10,28 +9,16 @@ from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
class ComplianceDownloadQuery(BaseModel):
doc_name: str = Field(..., description="Compliance document name")
console_ns.schema_model(
ComplianceDownloadQuery.__name__,
ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/compliance/download")
class ComplianceApi(Resource):
@console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
@console_ns.doc("download_compliance_document")
@console_ns.doc(description="Get compliance document download link")
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")

View File

@ -1,15 +1,15 @@
import json
from collections.abc import Generator
from typing import Any, cast
from typing import cast
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
@ -25,19 +25,6 @@ from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
from .. import console_ns
from ..wraps import account_initialization_required, setup_required
class NotionEstimatePayload(BaseModel):
notion_info_list: list[dict[str, Any]]
process_rule: dict[str, Any]
doc_form: str = Field(default="text_model")
doc_language: str = Field(default="English")
register_schema_model(console_ns, NotionEstimatePayload)
@console_ns.route(
"/data-source/integrates",
@ -140,18 +127,6 @@ class DataSourceNotionListApi(Resource):
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
datasource_parameters = {}
if datasource_parameters_str:
try:
datasource_parameters = json.loads(datasource_parameters_str)
if not isinstance(datasource_parameters, dict):
raise ValueError("datasource_parameters must be a JSON object.")
except json.JSONDecodeError:
raise ValueError("Invalid datasource_parameters JSON format.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -199,7 +174,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters=datasource_parameters,
datasource_parameters={},
provider_type=datasource_runtime.datasource_provider_type(),
)
)
@ -230,14 +205,14 @@ class DataSourceNotionListApi(Resource):
@console_ns.route(
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
"/datasets/notion-indexing-estimate",
)
class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, page_id, page_type):
def get(self, workspace_id, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
@ -251,10 +226,11 @@ class DataSourceNotionApi(Resource):
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
extractor = NotionExtractor(
notion_workspace_id="",
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
@ -267,15 +243,20 @@ class DataSourceNotionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
def post(self):
_, current_tenant_id = current_account_with_tenant()
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
parser = (
reqparse.RequestParser()
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
notion_info_list = payload.notion_info_list
notion_info_list = args["notion_info_list"]
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]

View File

@ -1,14 +1,12 @@
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import (
api_key_item_model,
@ -50,6 +48,7 @@ from fields.dataset_fields import (
)
from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
@ -108,75 +107,10 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:
if value is None:
return value
if value not in Dataset.INDEXING_TECHNIQUE_LIST:
raise ValueError("Invalid indexing technique.")
return value
class DatasetCreatePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field("", max_length=400)
indexing_technique: str | None = None
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
provider: str = "vendor"
external_knowledge_api_id: str | None = None
external_knowledge_id: str | None = None
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str | None) -> str | None:
return _validate_indexing_technique(value)
@field_validator("provider")
@classmethod
def validate_provider(cls, value: str) -> str:
if value not in Dataset.PROVIDER_LIST:
raise ValueError("Invalid provider.")
return value
class DatasetUpdatePayload(BaseModel):
name: str | None = Field(None, min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
permission: DatasetPermissionEnum | None = None
indexing_technique: str | None = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
icon_info: dict[str, Any] | None = None
is_multimodal: bool | None = False
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str | None) -> str | None:
return _validate_indexing_technique(value)
class IndexingEstimatePayload(BaseModel):
info_list: dict[str, Any]
process_rule: dict[str, Any]
indexing_technique: str
doc_form: str = "text_model"
dataset_id: str | None = None
doc_language: str = "English"
@field_validator("indexing_technique")
@classmethod
def validate_indexing(cls, value: str) -> str:
result = _validate_indexing_technique(value)
if result is None:
raise ValueError("indexing_technique is required.")
return result
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -223,7 +157,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.COUCHBASE,
VectorType.OPENGAUSS,
VectorType.OCEANBASE,
VectorType.SEEKDB,
VectorType.TABLESTORE,
VectorType.HUAWEI_CLOUD,
VectorType.TENCENT,
@ -231,7 +164,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.CLICKZETTA,
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
VectorType.IRIS,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
@ -323,7 +255,20 @@ class DatasetListApi(Resource):
@console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset")
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
@console_ns.expect(
console_ns.model(
"CreateDatasetRequest",
{
"name": fields.String(required=True, description="Dataset name (1-40 characters)"),
"description": fields.String(description="Dataset description (max 400 characters)"),
"indexing_technique": fields.String(description="Indexing technique"),
"permission": fields.String(description="Dataset permission"),
"provider": fields.String(description="Provider"),
"external_knowledge_api_id": fields.String(description="External knowledge API ID"),
"external_knowledge_id": fields.String(description="External knowledge ID"),
},
)
)
@console_ns.response(201, "Dataset created successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -331,7 +276,52 @@ class DatasetListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
nullable=True,
required=False,
)
.add_argument(
"provider",
type=str,
nullable=True,
choices=Dataset.PROVIDER_LIST,
required=False,
default="vendor",
)
.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -341,14 +331,14 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
indexing_technique=payload.indexing_technique,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=current_user,
permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
provider=payload.provider,
external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=payload.external_knowledge_id,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"],
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -409,7 +399,18 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details")
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
@console_ns.expect(
console_ns.model(
"UpdateDatasetRequest",
{
"name": fields.String(description="Dataset name"),
"description": fields.String(description="Dataset description"),
"permission": fields.String(description="Dataset permission"),
"indexing_technique": fields.String(description="Indexing technique"),
"external_retrieval_model": fields.Raw(description="External retrieval model settings"),
},
)
)
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@ -423,25 +424,93 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
.add_argument(
"permission",
type=str,
location="json",
choices=(
DatasetPermissionEnum.ONLY_ME,
DatasetPermissionEnum.ALL_TEAM,
DatasetPermissionEnum.PARTIAL_TEAM,
),
help="Invalid permission.",
)
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
)
args = parser.parse_args()
data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
payload.indexing_technique == "high_quality"
and payload.embedding_model_provider is not None
and payload.embedding_model is not None
data.get("indexing_technique") == "high_quality"
and data.get("embedding_model_provider") is not None
and data.get("embedding_model") is not None
):
is_multimodal = DatasetService.check_is_multimodal_model(
dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
payload.is_multimodal = is_multimodal
payload_data = payload.model_dump(exclude_unset=True)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, payload.permission, payload.partial_member_list
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)
dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
@ -449,10 +518,15 @@ class DatasetApi(Resource):
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_tenant_id
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@ -541,10 +615,24 @@ class DatasetIndexingEstimateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
def post(self):
payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
parser = (
reqparse.RequestParser()
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument(
"indexing_technique",
type=str,
required=True,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
location="json",
)
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)

View File

@ -6,14 +6,31 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
@ -38,30 +55,10 @@ from fields.document_fields import (
)
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from ..datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
logger = logging.getLogger(__name__)
@ -96,24 +93,6 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
document_ids: list[str]
class DocumentRenamePayload(BaseModel):
name: str
register_schema_models(
console_ns,
KnowledgeConfig,
ProcessRule,
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
)
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
@ -222,9 +201,8 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
def get(self, dataset_id: str):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
@ -332,7 +310,6 @@ class DatasetDocumentListApi(Resource):
@marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
@ -351,7 +328,23 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("data_source", type=dict, required=False, location="json")
.add_argument("process_rule", type=dict, required=False, location="json")
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
@ -397,7 +390,17 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource):
@console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.expect(
console_ns.model(
"DatasetInitRequest",
{
"upload_file_id": fields.String(required=True, description="Upload file ID"),
"indexing_technique": fields.String(description="Indexing technique"),
"process_rule": fields.Raw(description="Processing rules"),
"data_source": fields.Raw(description="Data source configuration"),
},
)
)
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
@console_ns.response(400, "Invalid request parameters")
@setup_required
@ -412,7 +415,27 @@ class DatasetInitApi(Resource):
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument(
"indexing_technique",
type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
required=True,
nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -420,14 +443,10 @@ class DatasetInitApi(Resource):
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_tenant_id,
provider=knowledge_config.embedding_model_provider,
provider=args["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_config.embedding_model,
model=args["embedding_model"],
)
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@ -572,7 +591,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
@ -1057,16 +1076,19 @@ class DocumentRetryApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
def post(self, dataset_id):
"""retry document."""
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
parser = reqparse.RequestParser().add_argument(
"document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
raise NotFound("Dataset not found.")
for document_id in payload.document_ids:
for document_id in args["document_ids"]:
try:
document_id = str(document_id)
@ -1099,7 +1121,6 @@ class DocumentRenameApi(DocumentResource):
@login_required
@account_initialization_required
@marshal_with(document_fields)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
@ -1109,10 +1130,11 @@ class DocumentRenameApi(DocumentResource):
if not dataset:
raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(current_user, dataset)
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
document = DocumentService.rename_document(dataset_id, document_id, args["name"])
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")

View File

@ -1,13 +1,11 @@
import uuid
from flask import request
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from flask_restx import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import (
@ -38,58 +36,6 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
hit_count_gte: int | None = None
enabled: str = Field(default="all")
keyword: str | None = None
page: int = Field(default=1, ge=1)
class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
class BatchImportPayload(BaseModel):
upload_file_id: str
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]
register_schema_models(
console_ns,
SegmentListQuery,
SegmentCreatePayload,
SegmentUpdatePayload,
BatchImportPayload,
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource):
@setup_required
@ -114,18 +60,23 @@ class DatasetDocumentSegmentListApi(Resource):
if not document:
raise NotFound("Document not found.")
args = SegmentListQuery.model_validate(
{
**request.args.to_dict(),
"status": request.args.getlist("status"),
}
parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("hit_count_gte", type=int, default=None, location="args")
.add_argument("enabled", type=str, default="all", location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
page = args.page
limit = min(args.limit, 100)
status_list = args.status
hit_count_gte = args.hit_count_gte
keyword = args.keyword
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
status_list = args["status"]
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
query = (
select(DocumentSegment)
@ -145,10 +96,10 @@ class DatasetDocumentSegmentListApi(Resource):
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args.enabled.lower() != "all":
if args.enabled.lower() == "true":
if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true":
query = query.where(DocumentSegment.enabled == True)
elif args.enabled.lower() == "false":
elif args["enabled"].lower() == "false":
query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -259,7 +210,6 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -296,10 +246,15 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
parser = (
reqparse.RequestParser()
.add_argument("content", type=str, required=True, nullable=False, location="json")
.add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@ -310,7 +265,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -359,12 +313,18 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
parser = (
reqparse.RequestParser()
.add_argument("content", type=str, required=True, nullable=False, location="json")
.add_argument("answer", type=str, required=False, nullable=True, location="json")
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@ -417,7 +377,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -432,8 +391,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document:
raise NotFound("Document not found.")
payload = BatchImportPayload.model_validate(console_ns.payload or {})
upload_file_id = payload.upload_file_id
parser = reqparse.RequestParser().add_argument(
"upload_file_id", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
upload_file_id = args["upload_file_id"]
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
@ -484,7 +446,6 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -530,9 +491,13 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
content = args["content"]
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@ -564,17 +529,18 @@ class ChildChunkAddApi(Resource):
)
if not segment:
raise NotFound("Segment not found.")
args = SegmentListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
page = args.page
limit = min(args.limit, 100)
keyword = args.keyword
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
@ -622,9 +588,14 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
parser = reqparse.RequestParser().add_argument(
"chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@ -694,7 +665,6 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
@ -741,9 +711,13 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try:
payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
content = args["content"]
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@ -1,10 +1,8 @@
from flask import request
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field
from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -73,38 +71,10 @@ except KeyError:
dataset_detail_model = _build_dataset_detail_model()
class ExternalKnowledgeApiPayload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
settings: dict[str, object]
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
name: str = Field(..., min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
class ExternalHitTestingPayload(BaseModel):
query: str
external_retrieval_model: dict[str, object] | None = None
metadata_filtering_conditions: dict[str, object] | None = None
class BedrockRetrievalPayload(BaseModel):
retrieval_setting: dict[str, object]
query: str
knowledge_id: str
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.")
return name
@console_ns.route("/datasets/external-knowledge-api")
@ -143,12 +113,28 @@ class ExternalApiTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(payload.settings)
ExternalDatasetService.validate_api_list(args["settings"])
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -156,7 +142,7 @@ class ExternalApiTemplateListApi(Resource):
try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
tenant_id=current_tenant_id, user_id=current_user.id, args=args
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -185,19 +171,35 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(payload.settings)
parser = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id,
args=payload.model_dump(),
args=args,
)
return external_knowledge_api.to_dict(), 200
@ -238,7 +240,17 @@ class ExternalApiUseCheckApi(Resource):
class ExternalDatasetCreateApi(Resource):
@console_ns.doc("create_external_dataset")
@console_ns.doc(description="Create external knowledge dataset")
@console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
@console_ns.expect(
console_ns.model(
"CreateExternalDatasetRequest",
{
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
"name": fields.String(required=True, description="Dataset name"),
"description": fields.String(description="Dataset description"),
},
)
)
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
@console_ns.response(400, "Invalid parameters")
@console_ns.response(403, "Permission denied")
@ -249,8 +261,22 @@ class ExternalDatasetCreateApi(Resource):
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant()
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@ -273,7 +299,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.doc("test_external_knowledge_retrieval")
@console_ns.doc(description="Test external knowledge retrieval for dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
@console_ns.expect(
console_ns.model(
"ExternalHitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
},
)
)
@console_ns.response(200, "External hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@ -292,16 +327,23 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
HitTestingService.hit_testing_args_check(payload.model_dump())
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try:
response = HitTestingService.external_retrieve(
dataset=dataset,
query=payload.query,
query=args["query"],
account=current_user,
external_retrieval_model=payload.external_retrieval_model,
metadata_filtering_conditions=payload.metadata_filtering_conditions,
external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
)
return response
@ -314,13 +356,33 @@ class BedrockRetrievalApi(Resource):
# this api is only for internal testing
@console_ns.doc("bedrock_retrieval_test")
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
@console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
@console_ns.expect(
console_ns.model(
"BedrockRetrievalTestRequest",
{
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
"query": fields.String(required=True, description="Query text"),
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
},
)
)
@console_ns.response(200, "Bedrock retrieval test completed")
def post(self):
payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
)
args = parser.parse_args()
# Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval(
payload.retrieval_setting, payload.query, payload.knowledge_id
args["retrieval_setting"], args["query"], args["knowledge_id"]
)
return result, 200

View File

@ -1,17 +1,13 @@
from flask_restx import Resource
from flask_restx import Resource, fields
from controllers.common.schema import register_schema_model
from libs.login import login_required
from .. import console_ns
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from ..wraps import (
from controllers.console import console_ns
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
register_schema_model(console_ns, HitTestingPayload)
from libs.login import login_required
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@ -19,7 +15,17 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.expect(
console_ns.model(
"HitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"top_k": fields.Integer(description="Number of top results to return"),
"score_threshold": fields.Float(description="Score threshold for filtering results"),
},
)
)
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@ -31,8 +37,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
payload = HitTestingPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
args = self.parse_args()
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)

View File

@ -1,8 +1,6 @@
import logging
from typing import Any
from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@ -29,13 +27,6 @@ from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
@ -52,15 +43,14 @@ class DatasetsHitTestingBase:
return dataset
@staticmethod
def hit_testing_args_check(args: dict[str, Any]):
def hit_testing_args_check(args):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = (
reqparse.RequestParser()
.add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("query", type=str, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
@ -72,11 +62,10 @@ class DatasetsHitTestingBase:
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args.get("query"),
query=args["query"],
account=current_user,
retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}

View File

@ -1,10 +1,8 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
@ -17,14 +15,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource):
@setup_required
@ -32,10 +22,15 @@ class DatasetMetadataCreateApi(Resource):
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -65,11 +60,11 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
name = args["name"]
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -136,7 +131,6 @@ class DocumentMetadataEditApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -145,7 +139,11 @@ class DocumentMetadataEditApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
parser = reqparse.RequestParser().add_argument(
"operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -1,63 +1,20 @@
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
class DatasourceCredentialPayload(BaseModel):
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any]
class DatasourceCredentialDeletePayload(BaseModel):
credential_id: str
class DatasourceCredentialUpdatePayload(BaseModel):
credential_id: str
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any] | None = None
class DatasourceCustomClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enable_oauth_custom_client: bool | None = None
class DatasourceDefaultPayload(BaseModel):
id: str
class DatasourceUpdateNamePayload(BaseModel):
credential_id: str
name: str = Field(max_length=100)
register_schema_models(
console_ns,
DatasourceCredentialPayload,
DatasourceCredentialDeletePayload,
DatasourceCredentialUpdatePayload,
DatasourceCustomClientPayload,
DatasourceDefaultPayload,
DatasourceUpdateNamePayload,
)
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@ -164,9 +121,16 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource):
@console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
@console_ns.expect(parser_datasource)
@setup_required
@login_required
@account_initialization_required
@ -174,7 +138,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
args = parser_datasource.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -182,8 +146,8 @@ class DatasourceAuth(Resource):
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_tenant_id,
provider_id=datasource_provider_id,
credentials=payload.credentials,
name=payload.name,
credentials=args["credentials"],
name=args["name"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
@ -205,9 +169,14 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource):
@console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
@console_ns.expect(parser_datasource_delete)
@setup_required
@login_required
@account_initialization_required
@ -219,20 +188,28 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
args = parser_datasource_delete.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id,
auth_id=payload.credential_id,
auth_id=args["credential_id"],
provider=provider_name,
plugin_id=plugin_id,
)
return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource):
@console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
@console_ns.expect(parser_datasource_update)
@setup_required
@login_required
@account_initialization_required
@ -241,16 +218,16 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id)
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
args = parser_datasource_update.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_tenant_id,
auth_id=payload.credential_id,
auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
credentials=payload.credentials or {},
name=payload.name,
credentials=args.get("credentials", {}),
name=args.get("name", None),
)
return {"result": "success"}, 201
@ -281,9 +258,16 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource):
@console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
@console_ns.expect(parser_datasource_custom)
@setup_required
@login_required
@account_initialization_required
@ -291,14 +275,14 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
args = parser_datasource_custom.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=payload.client_params or {},
enabled=payload.enable_oauth_custom_client or False,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
)
return {"result": "success"}, 200
@ -317,9 +301,12 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource):
@console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
@console_ns.expect(parser_default)
@setup_required
@login_required
@account_initialization_required
@ -327,20 +314,27 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
args = parser_default.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=payload.id,
credential_id=args["id"],
)
return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource):
@console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
@console_ns.expect(parser_update_name)
@setup_required
@login_required
@account_initialization_required
@ -348,13 +342,13 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
args = parser_update_name.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=payload.name,
credential_id=payload.credential_id,
name=args["name"],
credential_id=args["credential_id"],
)
return {"result": "success"}, 200

View File

@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__])
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
@setup_required
@login_required
@account_initialization_required

Some files were not shown because too many files have changed in this diff Show More