Merge branch 'main' into feature/plugin-credential-deletion-option

crazywoola 2025-12-26 11:15:04 +08:00 committed by GitHub
commit 29690062c6
4378 changed files with 235254 additions and 106195 deletions


@@ -0,0 +1,483 @@
---
name: component-refactoring
description: Refactor high-complexity React components in Dify frontend. Use when `pnpm analyze-component --json` shows complexity > 50 or lineCount > 300, when the user asks for code splitting, hook extraction, or complexity reduction, or when `pnpm analyze-component` warns to refactor before testing; avoid for simple/well-structured components, third-party wrappers, or when the user explicitly wants testing without refactoring.
---
# Dify Component Refactoring Skill
Refactor high-complexity React components in the Dify frontend codebase with the patterns and workflow below.
> **Complexity Threshold**: Components with complexity > 50 (measured by `pnpm analyze-component`) should be refactored before testing.
## Quick Reference
### Commands (run from `web/`)
Use paths relative to `web/` (e.g., `app/components/...`).
Use `refactor-component` for refactoring prompts and `analyze-component` for testing prompts and metrics.
```bash
cd web
# Generate refactoring prompt
pnpm refactor-component <path>
# Output refactoring analysis as JSON
pnpm refactor-component <path> --json
# Generate testing prompt (after refactoring)
pnpm analyze-component <path>
# Output testing analysis as JSON
pnpm analyze-component <path> --json
```
### Complexity Analysis
```bash
# Analyze component complexity
pnpm analyze-component <path> --json
# Key metrics to check:
# - complexity: normalized score 0-100 (target < 50)
# - maxComplexity: highest single function complexity
# - lineCount: total lines (target < 300)
```
### Complexity Score Interpretation
| Score | Level | Action |
|-------|-------|--------|
| 0-25 | 🟢 Simple | Ready for testing |
| 26-50 | 🟡 Medium | Consider minor refactoring |
| 51-75 | 🟠 Complex | **Refactor before testing** |
| 76-100 | 🔴 Very Complex | **Must refactor** |
## Core Refactoring Patterns
### Pattern 1: Extract Custom Hooks
**When**: Component has complex state management, multiple `useState`/`useEffect`, or business logic mixed with UI.
**Dify Convention**: Place hooks in a `hooks/` subdirectory or alongside the component as `use-<feature>.ts`.
```typescript
// ❌ Before: Complex state logic in component
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// 50+ lines of state management logic...
return <div>...</div>
}
// ✅ After: Extract to custom hook
// hooks/use-model-config.ts
export const useModelConfig = (appId: string) => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
// Related state management logic here
return { modelConfig, setModelConfig, completionParams, setCompletionParams }
}
// Component becomes cleaner
const Configuration: FC = () => {
const { modelConfig, setModelConfig } = useModelConfig(appId)
return <div>...</div>
}
```
**Dify Examples**:
- `web/app/components/app/configuration/hooks/use-advanced-prompt-config.ts`
- `web/app/components/app/configuration/debug/hooks.tsx`
- `web/app/components/workflow/hooks/use-workflow.ts`
### Pattern 2: Extract Sub-Components
**When**: Single component has multiple UI sections, conditional rendering blocks, or repeated patterns.
**Dify Convention**: Place sub-components in subdirectories or as separate files in the same directory.
```typescript
// ❌ Before: Monolithic JSX with multiple sections
const AppInfo = () => {
return (
<div>
{/* 100 lines of header UI */}
{/* 100 lines of operations UI */}
{/* 100 lines of modals */}
</div>
)
}
// ✅ After: Split into focused components
// app-info/
// ├── index.tsx (orchestration only)
// ├── app-header.tsx (header UI)
// ├── app-operations.tsx (operations UI)
// └── app-modals.tsx (modal management)
const AppInfo = () => {
const { showModal, setShowModal } = useAppInfoModals()
return (
<div>
<AppHeader appDetail={appDetail} />
<AppOperations onAction={handleAction} />
<AppModals show={showModal} onClose={() => setShowModal(null)} />
</div>
)
}
```
**Dify Examples**:
- `web/app/components/app/configuration/` directory structure
- `web/app/components/workflow/nodes/` per-node organization
### Pattern 3: Simplify Conditional Logic
**When**: Deep nesting (> 3 levels), complex ternaries, or multiple `if/else` chains.
```typescript
// ❌ Before: Deeply nested conditionals
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh />
case LanguagesSupported[7]:
return <TemplateChatJa />
default:
return <TemplateChatEn />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
// Another 15 lines...
}
// More conditions...
}, [appDetail, locale])
// ✅ After: Use lookup tables + early returns
const TEMPLATE_MAP = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
// ...
},
}
const Template = useMemo(() => {
const modeTemplates = TEMPLATE_MAP[appDetail?.mode]
if (!modeTemplates) return null
const TemplateComponent = modeTemplates[locale] || modeTemplates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
### Pattern 4: Extract API/Data Logic
**When**: Component directly handles API calls, data transformation, or complex async operations.
**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks. The project is migrating from SWR to React Query.
```typescript
// ❌ Before: API logic in component
const MCPServiceCard = () => {
const [basicAppConfig, setBasicAppConfig] = useState({})
useEffect(() => {
if (isBasicApp && appId) {
(async () => {
const res = await fetchAppDetail({ url: '/apps', id: appId })
setBasicAppConfig(res?.model_config || {})
})()
}
}, [appId, isBasicApp])
// More API-related logic...
}
// ✅ After: Extract to data hook using React Query
// use-app-config.ts
import { useQuery } from '@tanstack/react-query'
import { get } from '@/service/base'
const NAME_SPACE = 'appConfig'
export const useAppConfig = (appId: string, isBasicApp: boolean) => {
return useQuery({
enabled: isBasicApp && !!appId,
queryKey: [NAME_SPACE, 'detail', appId],
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || {},
})
}
// Component becomes cleaner
const MCPServiceCard = () => {
const { data: config, isLoading } = useAppConfig(appId, isBasicApp)
// UI only
}
```
**React Query Best Practices in Dify**:
- Define `NAME_SPACE` for query key organization
- Use `enabled` option for conditional fetching
- Use `select` for data transformation
- Export invalidation hooks: `useInvalidXxx` (see the sketch below)
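A minimal sketch of that invalidation convention, mirroring the pattern in `web/service` (here `useInvalid` is assumed to come from `@/service/use-base`, as in the hook-extraction reference later in this skill):
```typescript
// Invalidation hook: consumers call this to refresh everything under NAME_SPACE.
import { useInvalid } from '@/service/use-base'

const NAME_SPACE = 'appConfig'

export const useInvalidAppConfig = () => {
  return useInvalid([NAME_SPACE])
}
```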
**Dify Examples**:
- `web/service/use-workflow.ts`
- `web/service/use-common.ts`
- `web/service/knowledge/use-dataset.ts`
- `web/service/knowledge/use-document.ts`
### Pattern 5: Extract Modal/Dialog Management
**When**: Component manages multiple modals with complex open/close states.
**Dify Convention**: Modals should be extracted with their state management.
```typescript
// ❌ Before: Multiple modal states in component
const AppInfo = () => {
const [showEditModal, setShowEditModal] = useState(false)
const [showDuplicateModal, setShowDuplicateModal] = useState(false)
const [showConfirmDelete, setShowConfirmDelete] = useState(false)
const [showSwitchModal, setShowSwitchModal] = useState(false)
const [showImportDSLModal, setShowImportDSLModal] = useState(false)
// 5+ more modal states...
}
// ✅ After: Extract to modal management hook
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'import' | null
const useAppInfoModals = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const openModal = useCallback((type: ModalType) => setActiveModal(type), [])
const closeModal = useCallback(() => setActiveModal(null), [])
return {
activeModal,
openModal,
closeModal,
isOpen: (type: ModalType) => activeModal === type,
}
}
```
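A possible wiring in the component, reusing the hook above (`EditModal` and `DeleteConfirm` stand in for the real modals):
```typescript
const AppInfo = () => {
  const { openModal, closeModal, isOpen } = useAppInfoModals()
  return (
    <div>
      {/* Triggers open a single active modal instead of toggling many flags */}
      <Button onClick={() => openModal('edit')}>Edit</Button>
      {isOpen('edit') && <EditModal onClose={closeModal} />}
      {isOpen('delete') && <DeleteConfirm onClose={closeModal} />}
    </div>
  )
}
```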
### Pattern 6: Extract Form Logic
**When**: Complex form validation, submission handling, or field transformation.
**Dify Convention**: Use `@tanstack/react-form` patterns from `web/app/components/base/form/`.
```typescript
// ✅ Use existing form infrastructure
import { useAppForm } from '@/app/components/base/form'
const ConfigForm = () => {
const form = useAppForm({
defaultValues: { name: '', description: '' },
onSubmit: handleSubmit,
})
return <form.Provider>...</form.Provider>
}
```
## Dify-Specific Refactoring Guidelines
### 1. Context Provider Extraction
**When**: Component provides complex context values with multiple states.
```typescript
// ❌ Before: Large context value object
const value = {
appId, isAPIKeySet, isTrailFinished, mode, modelModeType,
promptMode, isAdvancedMode, isAgent, isOpenAI, isFunctionCall,
// 50+ more properties...
}
return <ConfigContext.Provider value={value}>...</ConfigContext.Provider>
// ✅ After: Split into domain-specific contexts
<ModelConfigProvider value={modelConfigValue}>
<DatasetConfigProvider value={datasetConfigValue}>
<UIConfigProvider value={uiConfigValue}>
{children}
</UIConfigProvider>
</DatasetConfigProvider>
</ModelConfigProvider>
```
**Dify Reference**: `web/context/` directory structure
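A minimal sketch of one split provider, assuming `ModelConfig` from `@/models/debug`; the other domains follow the same shape:
```typescript
import type { FC, PropsWithChildren } from 'react'
import { createContext, useContext } from 'react'
import type { ModelConfig } from '@/models/debug' // assumed path

type ModelConfigContextValue = {
  modelConfig: ModelConfig
  setModelConfig: (config: ModelConfig) => void
}

const ModelConfigContext = createContext<ModelConfigContextValue | null>(null)

export const ModelConfigProvider: FC<PropsWithChildren<{ value: ModelConfigContextValue }>> = ({ value, children }) => (
  <ModelConfigContext.Provider value={value}>{children}</ModelConfigContext.Provider>
)

export const useModelConfigContext = () => {
  const ctx = useContext(ModelConfigContext)
  if (!ctx)
    throw new Error('useModelConfigContext must be used inside ModelConfigProvider')
  return ctx
}
```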
### 2. Workflow Node Components
**When**: Refactoring workflow node components (`web/app/components/workflow/nodes/`).
**Conventions**:
- Keep node logic in `use-interactions.ts`
- Extract panel UI to separate files
- Use `_base` components for common patterns
```
nodes/<node-type>/
├── index.tsx # Node registration
├── node.tsx # Node visual component
├── panel.tsx # Configuration panel
├── use-interactions.ts # Node-specific hooks
└── types.ts # Type definitions
```
### 3. Configuration Components
**When**: Refactoring app configuration components.
**Conventions**:
- Separate config sections into subdirectories
- Use existing patterns from `web/app/components/app/configuration/`
- Keep feature toggles in dedicated components
### 4. Tool/Plugin Components
**When**: Refactoring tool-related components (`web/app/components/tools/`).
**Conventions**:
- Follow existing modal patterns
- Use service hooks from `web/service/use-tools.ts`
- Keep provider-specific logic isolated
## Refactoring Workflow
### Step 1: Generate Refactoring Prompt
```bash
pnpm refactor-component <path>
```
This command will:
- Analyze component complexity and features
- Identify specific refactoring actions needed
- Generate a prompt for the AI assistant (auto-copied to clipboard on macOS)
- Provide detailed requirements based on detected patterns
### Step 2: Analyze Details
```bash
pnpm analyze-component <path> --json
```
Identify:
- Total complexity score
- Max function complexity
- Line count
- Features detected (state, effects, API, etc.)
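For orientation, a hedged sketch of the JSON fields those checks read — the names come from the metrics and feature flags referenced in this skill, and the real tool output may contain more:
```typescript
// Hypothetical shape of `pnpm analyze-component --json` output.
// Treat this as an assumption, not the tool's contract.
type AnalyzeComponentResult = {
  complexity: number // normalized 0-100 score (target < 50)
  maxComplexity: number // highest single-function complexity
  lineCount: number // total lines (target < 300)
  hasState: boolean
  hasEffects: boolean
  hasAPI: boolean
  hasEvents: boolean
}
```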
### Step 3: Plan
Create a refactoring plan based on detected features:
| Detected Feature | Refactoring Action |
|------------------|-------------------|
| `hasState: true` + `hasEffects: true` | Extract custom hook |
| `hasAPI: true` | Extract data/service hook |
| `hasEvents: true` (many) | Extract event handlers |
| `lineCount > 300` | Split into sub-components |
| `maxComplexity > 50` | Simplify conditional logic |
### Step 4: Execute Incrementally
1. **Extract one piece at a time**
2. **Run lint, type-check, and tests after each extraction**
3. **Verify functionality before next step**
```
For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check:tsgo │
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │
│ FAIL? → Fix before continuing │
└────────────────────────────────────────┘
```
### Step 5: Verify
After refactoring:
```bash
# Re-run refactor command to verify improvements
pnpm refactor-component <path>
# If complexity < 25 and lines < 200, you'll see:
# ✅ COMPONENT IS WELL-STRUCTURED
# For detailed metrics:
pnpm analyze-component <path> --json
# Target metrics:
# - complexity < 50
# - lineCount < 300
# - maxComplexity < 30
```
## Common Mistakes to Avoid
### ❌ Over-Engineering
```typescript
// ❌ Too many tiny hooks
const useButtonText = () => useState('Click')
const useButtonDisabled = () => useState(false)
const useButtonLoading = () => useState(false)
// ✅ Cohesive hook with related state
const useButtonState = () => {
const [text, setText] = useState('Click')
const [disabled, setDisabled] = useState(false)
const [loading, setLoading] = useState(false)
return { text, setText, disabled, setDisabled, loading, setLoading }
}
```
### ❌ Breaking Existing Patterns
- Follow existing directory structures
- Maintain naming conventions
- Preserve export patterns for compatibility
### ❌ Premature Abstraction
- Only extract when there's clear complexity benefit
- Don't create abstractions for single-use code
- Keep refactored code in the same domain area
## References
### Dify Codebase Examples
- **Hook extraction**: `web/app/components/app/configuration/hooks/`
- **Component splitting**: `web/app/components/app/configuration/`
- **Service hooks**: `web/service/use-*.ts`
- **Workflow patterns**: `web/app/components/workflow/hooks/`
- **Form patterns**: `web/app/components/base/form/`
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/testing/testing.md` - Testing specification


@@ -0,0 +1,493 @@
# Complexity Reduction Patterns
This document provides patterns for reducing cognitive complexity in Dify React components.
## Understanding Complexity
### SonarJS Cognitive Complexity
The `pnpm analyze-component` tool uses SonarJS cognitive complexity metrics:
- **Total Complexity**: Sum of all functions' complexity in the file
- **Max Complexity**: Highest single function complexity
### What Increases Complexity
| Pattern | Complexity Impact |
|---------|-------------------|
| `if/else` | +1 per branch |
| Nested conditions | +1 per nesting level |
| `switch/case` | +1 per case |
| `for/while/do` | +1 per loop |
| `&&`/`\|\|` chains | +1 per operator |
| Nested callbacks | +1 per nesting level |
| `try/catch` | +1 per catch |
| Ternary expressions | +1 per nesting |
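To make the scoring concrete, an illustrative function annotated with approximate increments (`Item`, `process`, and `logError` are hypothetical, and exact SonarJS numbers can differ slightly — for example it counts operator *sequences* rather than single operators):
```typescript
// Illustrative scoring only; nested constructs pay a nesting penalty.
const describeComplexity = (items: Item[]) => {
  for (const item of items) { // +1 (loop)
    if (item.isValid && item.isActive) { // +1 (if) +1 (nesting) +1 (&&)
      try {
        process(item)
      }
      catch { // +1 (catch) +2 (nesting)
        logError(item)
      }
    }
  }
}
// Approximate cognitive complexity: 7
```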
## Pattern 1: Replace Conditionals with Lookup Tables
**Before** (complexity: ~15):
```typescript
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateChatJa appDetail={appDetail} />
default:
return <TemplateChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateAdvancedChatZh appDetail={appDetail} />
case LanguagesSupported[7]:
return <TemplateAdvancedChatJa appDetail={appDetail} />
default:
return <TemplateAdvancedChatEn appDetail={appDetail} />
}
}
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
// Similar pattern...
}
return null
}, [appDetail, locale])
```
**After** (complexity: ~3):
```typescript
// Define lookup table outside component
const TEMPLATE_MAP: Record<AppModeEnum, Record<string, FC<TemplateProps>>> = {
[AppModeEnum.CHAT]: {
[LanguagesSupported[1]]: TemplateChatZh,
[LanguagesSupported[7]]: TemplateChatJa,
default: TemplateChatEn,
},
[AppModeEnum.ADVANCED_CHAT]: {
[LanguagesSupported[1]]: TemplateAdvancedChatZh,
[LanguagesSupported[7]]: TemplateAdvancedChatJa,
default: TemplateAdvancedChatEn,
},
[AppModeEnum.WORKFLOW]: {
[LanguagesSupported[1]]: TemplateWorkflowZh,
[LanguagesSupported[7]]: TemplateWorkflowJa,
default: TemplateWorkflowEn,
},
// ...
}
// Clean component logic
const Template = useMemo(() => {
if (!appDetail?.mode) return null
const templates = TEMPLATE_MAP[appDetail.mode]
if (!templates) return null
const TemplateComponent = templates[locale] ?? templates.default
return <TemplateComponent appDetail={appDetail} />
}, [appDetail, locale])
```
## Pattern 2: Use Early Returns
**Before** (complexity: ~10):
```typescript
const handleSubmit = () => {
if (isValid) {
if (hasChanges) {
if (isConnected) {
submitData()
} else {
showConnectionError()
}
} else {
showNoChangesMessage()
}
} else {
showValidationError()
}
}
```
**After** (complexity: ~4):
```typescript
const handleSubmit = () => {
if (!isValid) {
showValidationError()
return
}
if (!hasChanges) {
showNoChangesMessage()
return
}
if (!isConnected) {
showConnectionError()
return
}
submitData()
}
```
## Pattern 3: Extract Complex Conditions
**Before** (complexity: high):
```typescript
const canPublish = (() => {
if (mode !== AppModeEnum.COMPLETION) {
if (!isAdvancedMode)
return true
if (modelModeType === ModelModeType.completion) {
if (!hasSetBlockStatus.history || !hasSetBlockStatus.query)
return false
return true
}
return true
}
return !promptEmpty
})()
```
**After** (complexity: lower):
```typescript
// Extract to named functions
const canPublishInCompletionMode = () => !promptEmpty
const canPublishInChatMode = () => {
if (!isAdvancedMode) return true
if (modelModeType !== ModelModeType.completion) return true
return hasSetBlockStatus.history && hasSetBlockStatus.query
}
// Clean main logic
const canPublish = mode === AppModeEnum.COMPLETION
? canPublishInCompletionMode()
: canPublishInChatMode()
```
## Pattern 4: Replace Chained Ternaries
**Before** (complexity: ~5):
```typescript
const statusText = serverActivated
? t('status.running')
: serverPublished
? t('status.inactive')
: appUnpublished
? t('status.unpublished')
: t('status.notConfigured')
```
**After** (complexity: ~2):
```typescript
const getStatusText = () => {
if (serverActivated) return t('status.running')
if (serverPublished) return t('status.inactive')
if (appUnpublished) return t('status.unpublished')
return t('status.notConfigured')
}
const statusText = getStatusText()
```
Or use lookup:
```typescript
const STATUS_TEXT_MAP = {
running: 'status.running',
inactive: 'status.inactive',
unpublished: 'status.unpublished',
notConfigured: 'status.notConfigured',
} as const
const getStatusKey = (): keyof typeof STATUS_TEXT_MAP => {
if (serverActivated) return 'running'
if (serverPublished) return 'inactive'
if (appUnpublished) return 'unpublished'
return 'notConfigured'
}
const statusText = t(STATUS_TEXT_MAP[getStatusKey()])
```
## Pattern 5: Flatten Nested Loops
**Before** (complexity: high):
```typescript
const processData = (items: Item[]) => {
const results: ProcessedItem[] = []
for (const item of items) {
if (item.isValid) {
for (const child of item.children) {
if (child.isActive) {
for (const prop of child.properties) {
if (prop.value !== null) {
results.push({
itemId: item.id,
childId: child.id,
propValue: prop.value,
})
}
}
}
}
}
}
return results
}
```
**After** (complexity: lower):
```typescript
// Use functional approach
const processData = (items: Item[]) => {
return items
.filter(item => item.isValid)
.flatMap(item =>
item.children
.filter(child => child.isActive)
.flatMap(child =>
child.properties
.filter(prop => prop.value !== null)
.map(prop => ({
itemId: item.id,
childId: child.id,
propValue: prop.value,
}))
)
)
}
```
## Pattern 6: Extract Event Handler Logic
**Before** (complexity: high in component):
```typescript
const Component = () => {
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
let newDatasets = data
if (data.find(item => !item.name)) {
const newSelected = produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const newItem = dataSets.find(i => i.id === item.id)
if (newItem)
draft[index] = newItem
}
})
})
setDataSets(newSelected)
newDatasets = newSelected
}
else {
setDataSets(data)
}
hideSelectDataSet()
// 40 more lines of logic...
}
return <div>...</div>
}
```
**After** (complexity: lower):
```typescript
// Extract to hook or utility
const useDatasetSelection = (dataSets: DataSet[], setDataSets: SetState<DataSet[]>) => {
const normalizeSelection = (data: DataSet[]) => {
const hasUnloadedItem = data.some(item => !item.name)
if (!hasUnloadedItem) return data
return produce(data, (draft) => {
data.forEach((item, index) => {
if (!item.name) {
const existing = dataSets.find(i => i.id === item.id)
if (existing) draft[index] = existing
}
})
})
}
const hasSelectionChanged = (newData: DataSet[]) => {
return !isEqual(
newData.map(item => item.id),
dataSets.map(item => item.id)
)
}
return { normalizeSelection, hasSelectionChanged }
}
// Component becomes cleaner
const Component = () => {
const { normalizeSelection, hasSelectionChanged } = useDatasetSelection(dataSets, setDataSets)
const handleSelect = (data: DataSet[]) => {
if (!hasSelectionChanged(data)) {
hideSelectDataSet()
return
}
formattingChangedDispatcher()
const normalized = normalizeSelection(data)
setDataSets(normalized)
hideSelectDataSet()
}
return <div>...</div>
}
```
## Pattern 7: Reduce Boolean Logic Complexity
**Before** (complexity: ~8):
```typescript
const toggleDisabled = hasInsufficientPermissions
|| appUnpublished
|| missingStartNode
|| triggerModeDisabled
|| (isAdvancedApp && !currentWorkflow?.graph)
|| (isBasicApp && !basicAppConfig.updated_at)
```
**After** (complexity: ~3):
```typescript
// Extract meaningful boolean functions
const isAppReady = () => {
if (isAdvancedApp) return !!currentWorkflow?.graph
return !!basicAppConfig.updated_at
}
const hasRequiredPermissions = () => !hasInsufficientPermissions
const canToggle = () => {
  if (!hasRequiredPermissions()) return false
  if (appUnpublished) return false
  if (!isAppReady()) return false
  if (missingStartNode) return false
  if (triggerModeDisabled) return false
  return true
}
const toggleDisabled = !canToggle()
```
## Pattern 8: Simplify useMemo/useCallback Dependencies
**Before** (complexity: high, mixed concerns):
```typescript
const payload = useMemo(() => {
let parameters: Parameter[] = []
let outputParameters: OutputParameter[] = []
if (!published) {
parameters = (inputs || []).map((item) => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
outputParameters = (outputs || []).map((item) => ({
name: item.variable,
description: '',
type: item.value_type,
}))
}
else if (detail && detail.tool) {
parameters = (inputs || []).map((item) => ({
// Complex transformation...
}))
outputParameters = (outputs || []).map((item) => ({
// Complex transformation...
}))
}
return {
icon: detail?.icon || icon,
label: detail?.label || name,
// ...more fields
}
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
```
**After** (complexity: lower, separated concerns):
```typescript
// Separate transformations
const useParameterTransform = (inputs: InputVar[], detail?: ToolDetail, published?: boolean) => {
return useMemo(() => {
if (!published) {
return inputs.map(item => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
}
if (!detail?.tool) return []
return inputs.map(item => ({
name: item.variable,
required: item.required,
type: item.type === 'paragraph' ? 'string' : item.type,
description: detail.tool.parameters.find(p => p.name === item.variable)?.llm_description || '',
form: detail.tool.parameters.find(p => p.name === item.variable)?.form || 'llm',
}))
}, [inputs, detail, published])
}
// Component uses hook
const parameters = useParameterTransform(inputs, detail, published)
const outputParameters = useOutputTransform(outputs, detail, published)
const payload = useMemo(() => ({
icon: detail?.icon || icon,
label: detail?.label || name,
parameters,
outputParameters,
// ...
}), [detail, icon, name, parameters, outputParameters])
```
## Target Metrics After Refactoring
| Metric | Target |
|--------|--------|
| Total Complexity | < 50 |
| Max Function Complexity | < 30 |
| Function Length | < 30 lines |
| Nesting Depth | ≤ 3 levels |
| Conditional Chains | ≤ 3 conditions |


@@ -0,0 +1,477 @@
# Component Splitting Patterns
This document provides detailed guidance on splitting large components into smaller, focused components in Dify.
## When to Split Components
Split a component when you identify:
1. **Multiple UI sections** - Distinct visual areas with minimal coupling that can be composed independently
1. **Conditional rendering blocks** - Large `{condition && <JSX />}` blocks
1. **Repeated patterns** - Similar UI structures used multiple times
1. **300+ lines** - Component exceeds manageable size
1. **Modal clusters** - Multiple modals rendered in one component
## Splitting Strategies
### Strategy 1: Section-Based Splitting
Identify visual sections and extract each as a component.
```typescript
// ❌ Before: Monolithic component (500+ lines)
const ConfigurationPage = () => {
return (
<div>
{/* Header Section - 50 lines */}
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher ... />
</div>
</div>
{/* Config Section - 200 lines */}
<div className="config">
<Config />
</div>
{/* Debug Section - 150 lines */}
<div className="debug">
<Debug ... />
</div>
{/* Modals Section - 100 lines */}
{showSelectDataSet && <SelectDataSet ... />}
{showHistoryModal && <EditHistoryModal ... />}
{showUseGPT4Confirm && <Confirm ... />}
</div>
)
}
// ✅ After: Split into focused components
// configuration/
// ├── index.tsx (orchestration)
// ├── configuration-header.tsx
// ├── configuration-content.tsx
// ├── configuration-debug.tsx
// └── configuration-modals.tsx
// configuration-header.tsx
interface ConfigurationHeaderProps {
isAdvancedMode: boolean
onPublish: () => void
}
const ConfigurationHeader: FC<ConfigurationHeaderProps> = ({
isAdvancedMode,
onPublish,
}) => {
const { t } = useTranslation()
return (
<div className="header">
<h1>{t('configuration.title')}</h1>
<div className="actions">
{isAdvancedMode && <Badge>Advanced</Badge>}
<ModelParameterModal ... />
<AppPublisher onPublish={onPublish} />
</div>
</div>
)
}
// index.tsx (orchestration only)
const ConfigurationPage = () => {
const { modelConfig, setModelConfig } = useModelConfig()
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
<ConfigurationHeader
isAdvancedMode={isAdvancedMode}
onPublish={handlePublish}
/>
<ConfigurationContent
modelConfig={modelConfig}
onConfigChange={setModelConfig}
/>
{!isMobile && (
<ConfigurationDebug
inputs={inputs}
onSetting={handleSetting}
/>
)}
<ConfigurationModals
activeModal={activeModal}
onClose={closeModal}
/>
</div>
)
}
```
### Strategy 2: Conditional Block Extraction
Extract large conditional rendering blocks.
```typescript
// ❌ Before: Large conditional blocks
const AppInfo = () => {
return (
<div>
{expand ? (
<div className="expanded">
{/* 100 lines of expanded view */}
</div>
) : (
<div className="collapsed">
{/* 50 lines of collapsed view */}
</div>
)}
</div>
)
}
// ✅ After: Separate view components
const AppInfoExpanded: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="expanded">
{/* Clean, focused expanded view */}
</div>
)
}
const AppInfoCollapsed: FC<AppInfoViewProps> = ({ appDetail, onAction }) => {
return (
<div className="collapsed">
{/* Clean, focused collapsed view */}
</div>
)
}
const AppInfo = () => {
return (
<div>
{expand
? <AppInfoExpanded appDetail={appDetail} onAction={handleAction} />
: <AppInfoCollapsed appDetail={appDetail} onAction={handleAction} />
}
</div>
)
}
```
### Strategy 3: Modal Extraction
Extract modals with their trigger logic.
```typescript
// ❌ Before: Multiple modals in one component
const AppInfo = () => {
const [showEdit, setShowEdit] = useState(false)
const [showDuplicate, setShowDuplicate] = useState(false)
const [showDelete, setShowDelete] = useState(false)
const [showSwitch, setShowSwitch] = useState(false)
const onEdit = async (data) => { /* 20 lines */ }
const onDuplicate = async (data) => { /* 20 lines */ }
const onDelete = async () => { /* 15 lines */ }
return (
<div>
{/* Main content */}
{showEdit && <EditModal onConfirm={onEdit} onClose={() => setShowEdit(false)} />}
{showDuplicate && <DuplicateModal onConfirm={onDuplicate} onClose={() => setShowDuplicate(false)} />}
{showDelete && <DeleteConfirm onConfirm={onDelete} onClose={() => setShowDelete(false)} />}
{showSwitch && <SwitchModal ... />}
</div>
)
}
// ✅ After: Modal manager component
// app-info-modals.tsx
type ModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | null
interface AppInfoModalsProps {
appDetail: AppDetail
activeModal: ModalType
onClose: () => void
onSuccess: () => void
}
const AppInfoModals: FC<AppInfoModalsProps> = ({
appDetail,
activeModal,
onClose,
onSuccess,
}) => {
const handleEdit = async (data) => { /* logic */ }
const handleDuplicate = async (data) => { /* logic */ }
const handleDelete = async () => { /* logic */ }
return (
<>
{activeModal === 'edit' && (
<EditModal
appDetail={appDetail}
onConfirm={handleEdit}
onClose={onClose}
/>
)}
{activeModal === 'duplicate' && (
<DuplicateModal
appDetail={appDetail}
onConfirm={handleDuplicate}
onClose={onClose}
/>
)}
{activeModal === 'delete' && (
<DeleteConfirm
onConfirm={handleDelete}
onClose={onClose}
/>
)}
{activeModal === 'switch' && (
<SwitchModal
appDetail={appDetail}
onClose={onClose}
/>
)}
</>
)
}
// Parent component
const AppInfo = () => {
const { activeModal, openModal, closeModal } = useModalState()
return (
<div>
{/* Main content with openModal triggers */}
<Button onClick={() => openModal('edit')}>Edit</Button>
<AppInfoModals
appDetail={appDetail}
activeModal={activeModal}
onClose={closeModal}
onSuccess={handleSuccess}
/>
</div>
)
}
```
### Strategy 4: List Item Extraction
Extract repeated item rendering.
```typescript
// ❌ Before: Inline item rendering
const OperationsList = () => {
return (
<div>
{operations.map(op => (
<div key={op.id} className="operation-item">
<span className="icon">{op.icon}</span>
<span className="title">{op.title}</span>
<span className="description">{op.description}</span>
<button onClick={() => op.onClick()}>
{op.actionLabel}
</button>
{op.badge && <Badge>{op.badge}</Badge>}
{/* More complex rendering... */}
</div>
))}
</div>
)
}
// ✅ After: Extracted item component
interface OperationItemProps {
operation: Operation
onAction: (id: string) => void
}
const OperationItem: FC<OperationItemProps> = ({ operation, onAction }) => {
return (
<div className="operation-item">
<span className="icon">{operation.icon}</span>
<span className="title">{operation.title}</span>
<span className="description">{operation.description}</span>
<button onClick={() => onAction(operation.id)}>
{operation.actionLabel}
</button>
{operation.badge && <Badge>{operation.badge}</Badge>}
</div>
)
}
const OperationsList = () => {
const handleAction = useCallback((id: string) => {
const op = operations.find(o => o.id === id)
op?.onClick()
}, [operations])
return (
<div>
{operations.map(op => (
<OperationItem
key={op.id}
operation={op}
onAction={handleAction}
/>
))}
</div>
)
}
```
## Directory Structure Patterns
### Pattern A: Flat Structure (Simple Components)
For components with 2-3 sub-components:
```
component-name/
├── index.tsx # Main component
├── sub-component-a.tsx
├── sub-component-b.tsx
└── types.ts # Shared types
```
### Pattern B: Nested Structure (Complex Components)
For components with many sub-components:
```
component-name/
├── index.tsx # Main orchestration
├── types.ts # Shared types
├── hooks/
│ ├── use-feature-a.ts
│ └── use-feature-b.ts
├── components/
│ ├── header/
│ │ └── index.tsx
│ ├── content/
│ │ └── index.tsx
│ └── modals/
│ └── index.tsx
└── utils/
└── helpers.ts
```
### Pattern C: Feature-Based Structure (Dify Standard)
Following Dify's existing patterns:
```
configuration/
├── index.tsx # Main page component
├── base/ # Base/shared components
│ ├── feature-panel/
│ ├── group-name/
│ └── operation-btn/
├── config/ # Config section
│ ├── index.tsx
│ ├── agent/
│ └── automatic/
├── dataset-config/ # Dataset section
│ ├── index.tsx
│ ├── card-item/
│ └── params-config/
├── debug/ # Debug section
│ ├── index.tsx
│ └── hooks.tsx
└── hooks/ # Shared hooks
└── use-advanced-prompt-config.ts
```
## Props Design
### Minimal Props Principle
Pass only what's needed:
```typescript
// ❌ Bad: Passing entire objects when only some fields needed
<ConfigHeader appDetail={appDetail} modelConfig={modelConfig} />
// ✅ Good: Destructure to minimum required
<ConfigHeader
appName={appDetail.name}
isAdvancedMode={modelConfig.isAdvanced}
onPublish={handlePublish}
/>
```
### Callback Props Pattern
Use callbacks for child-to-parent communication:
```typescript
// Parent
const Parent = () => {
const [value, setValue] = useState('')
return (
<Child
value={value}
onChange={setValue}
onSubmit={handleSubmit}
/>
)
}
// Child
interface ChildProps {
value: string
onChange: (value: string) => void
onSubmit: () => void
}
const Child: FC<ChildProps> = ({ value, onChange, onSubmit }) => {
return (
<div>
<input value={value} onChange={e => onChange(e.target.value)} />
<button onClick={onSubmit}>Submit</button>
</div>
)
}
```
### Render Props for Flexibility
When sub-components need parent context:
```typescript
interface ListProps<T> {
items: T[]
renderItem: (item: T, index: number) => React.ReactNode
renderEmpty?: () => React.ReactNode
}
function List<T>({ items, renderItem, renderEmpty }: ListProps<T>) {
if (items.length === 0 && renderEmpty) {
return <>{renderEmpty()}</>
}
return (
<div>
{items.map((item, index) => renderItem(item, index))}
</div>
)
}
// Usage
<List
items={operations}
renderItem={(op, i) => <OperationItem key={i} operation={op} />}
renderEmpty={() => <EmptyState message="No operations" />}
/>
```


@@ -0,0 +1,317 @@
# Hook Extraction Patterns
This document provides detailed guidance on extracting custom hooks from complex components in Dify.
## When to Extract Hooks
Extract a custom hook when you identify:
1. **Coupled state groups** - Multiple `useState` hooks that are always used together
1. **Complex effects** - `useEffect` with multiple dependencies or cleanup logic
1. **Business logic** - Data transformations, validations, or calculations
1. **Reusable patterns** - Logic that appears in multiple components
## Extraction Process
### Step 1: Identify State Groups
Look for state variables that are logically related:
```typescript
// ❌ These belong together - extract to hook
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
const [completionParams, setCompletionParams] = useState<FormValue>({})
const [modelModeType, setModelModeType] = useState<ModelModeType>(...)
// These are model-related state that should be in useModelConfig()
```
### Step 2: Identify Related Effects
Find effects that modify the grouped state:
```typescript
// ❌ These effects belong with the state above
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties.mode
if (mode) {
const newModelConfig = produce(modelConfig, (draft) => {
draft.mode = mode
})
setModelConfig(newModelConfig)
}
}
}, [textGenerationModelList, hasFetchedDetail, modelModeType, currModel])
```
### Step 3: Create the Hook
```typescript
// hooks/use-model-config.ts
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { ModelConfig } from '@/models/debug'
import { produce } from 'immer'
import { useEffect, useState } from 'react'
import { ModelModeType } from '@/types/app'
interface UseModelConfigParams {
initialConfig?: Partial<ModelConfig>
currModel?: { model_properties?: { mode?: ModelModeType } }
hasFetchedDetail: boolean
}
interface UseModelConfigReturn {
modelConfig: ModelConfig
setModelConfig: (config: ModelConfig) => void
completionParams: FormValue
setCompletionParams: (params: FormValue) => void
modelModeType: ModelModeType
}
export const useModelConfig = ({
initialConfig,
currModel,
hasFetchedDetail,
}: UseModelConfigParams): UseModelConfigReturn => {
const [modelConfig, setModelConfig] = useState<ModelConfig>({
provider: 'langgenius/openai/openai',
model_id: 'gpt-3.5-turbo',
mode: ModelModeType.unset,
// ... default values
...initialConfig,
})
const [completionParams, setCompletionParams] = useState<FormValue>({})
const modelModeType = modelConfig.mode
// Fill old app data missing model mode
useEffect(() => {
if (hasFetchedDetail && !modelModeType) {
const mode = currModel?.model_properties?.mode
if (mode) {
      setModelConfig(prev => produce(prev, (draft) => {
        draft.mode = mode
      }))
}
}
}, [hasFetchedDetail, modelModeType, currModel])
return {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
}
}
```
### Step 4: Update Component
```typescript
// Before: 50+ lines of state management
const Configuration: FC = () => {
const [modelConfig, setModelConfig] = useState<ModelConfig>(...)
// ... lots of related state and effects
}
// After: Clean component
const Configuration: FC = () => {
const {
modelConfig,
setModelConfig,
completionParams,
setCompletionParams,
modelModeType,
} = useModelConfig({
currModel,
hasFetchedDetail,
})
// Component now focuses on UI
}
```
## Naming Conventions
### Hook Names
- Use `use` prefix: `useModelConfig`, `useDatasetConfig`
- Be specific: `useAdvancedPromptConfig` not `usePrompt`
- Include domain: `useWorkflowVariables`, `useMCPServer`
### File Names
- Kebab-case: `use-model-config.ts`
- Place in `hooks/` subdirectory when multiple hooks exist
- Place alongside component for single-use hooks
### Return Type Names
- Suffix with `Return`: `UseModelConfigReturn`
- Suffix params with `Params`: `UseModelConfigParams`
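Putting those conventions together, a hedged skeleton — `DatasetConfigs` and its import path are assumptions:
```typescript
// use-dataset-config.ts — name, file, and type suffixes follow the conventions above
import { useState } from 'react'
import type { DatasetConfigs } from '@/models/debug' // assumed path

interface UseDatasetConfigParams {
  initialConfigs?: DatasetConfigs | null
}

interface UseDatasetConfigReturn {
  datasetConfigs: DatasetConfigs | null
  setDatasetConfigs: (configs: DatasetConfigs | null) => void
}

export const useDatasetConfig = ({
  initialConfigs = null,
}: UseDatasetConfigParams = {}): UseDatasetConfigReturn => {
  const [datasetConfigs, setDatasetConfigs] = useState<DatasetConfigs | null>(initialConfigs)
  return { datasetConfigs, setDatasetConfigs }
}
```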
## Common Hook Patterns in Dify
### 1. Data Fetching Hook (React Query)
```typescript
// Pattern: Use @tanstack/react-query for data fetching
import { useQuery, useQueryClient } from '@tanstack/react-query'
import { get } from '@/service/base'
import { useInvalid } from '@/service/use-base'
const NAME_SPACE = 'appConfig'
// Query keys for cache management
export const appConfigQueryKeys = {
detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const,
}
// Main data hook
export const useAppConfig = (appId: string) => {
return useQuery({
enabled: !!appId,
queryKey: appConfigQueryKeys.detail(appId),
queryFn: () => get<AppDetailResponse>(`/apps/${appId}`),
select: data => data?.model_config || null,
})
}
// Invalidation hook for refreshing data
export const useInvalidAppConfig = () => {
return useInvalid([NAME_SPACE])
}
// Usage in component
const Component = () => {
const { data: config, isLoading, error, refetch } = useAppConfig(appId)
const invalidAppConfig = useInvalidAppConfig()
const handleRefresh = () => {
invalidAppConfig() // Invalidates cache and triggers refetch
}
return <div>...</div>
}
```
### 2. Form State Hook
```typescript
// Pattern: Form state + validation + submission
export const useConfigForm = (initialValues: ConfigFormValues) => {
const [values, setValues] = useState(initialValues)
const [errors, setErrors] = useState<Record<string, string>>({})
const [isSubmitting, setIsSubmitting] = useState(false)
const validate = useCallback(() => {
const newErrors: Record<string, string> = {}
if (!values.name) newErrors.name = 'Name is required'
setErrors(newErrors)
return Object.keys(newErrors).length === 0
}, [values])
const handleChange = useCallback((field: string, value: any) => {
setValues(prev => ({ ...prev, [field]: value }))
}, [])
const handleSubmit = useCallback(async (onSubmit: (values: ConfigFormValues) => Promise<void>) => {
if (!validate()) return
setIsSubmitting(true)
try {
await onSubmit(values)
} finally {
setIsSubmitting(false)
}
}, [values, validate])
return { values, errors, isSubmitting, handleChange, handleSubmit }
}
```
### 3. Modal State Hook
```typescript
// Pattern: Multiple modal management
type ModalType = 'edit' | 'delete' | 'duplicate' | null
export const useModalState = () => {
const [activeModal, setActiveModal] = useState<ModalType>(null)
const [modalData, setModalData] = useState<any>(null)
const openModal = useCallback((type: ModalType, data?: any) => {
setActiveModal(type)
setModalData(data)
}, [])
const closeModal = useCallback(() => {
setActiveModal(null)
setModalData(null)
}, [])
return {
activeModal,
modalData,
openModal,
closeModal,
isOpen: useCallback((type: ModalType) => activeModal === type, [activeModal]),
}
}
```
### 4. Toggle/Boolean Hook
```typescript
// Pattern: Boolean state with convenience methods
export const useToggle = (initialValue = false) => {
const [value, setValue] = useState(initialValue)
const toggle = useCallback(() => setValue(v => !v), [])
const setTrue = useCallback(() => setValue(true), [])
const setFalse = useCallback(() => setValue(false), [])
return [value, { toggle, setTrue, setFalse, set: setValue }] as const
}
// Usage
const [isExpanded, { toggle, setTrue: expand, setFalse: collapse }] = useToggle()
```
## Testing Extracted Hooks
After extraction, test hooks in isolation:
```typescript
// use-model-config.spec.ts
import { renderHook, act } from '@testing-library/react'
import { useModelConfig } from './use-model-config'
describe('useModelConfig', () => {
it('should initialize with default values', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: false,
}))
expect(result.current.modelConfig.provider).toBe('langgenius/openai/openai')
expect(result.current.modelModeType).toBe(ModelModeType.unset)
})
it('should update model config', () => {
const { result } = renderHook(() => useModelConfig({
hasFetchedDetail: true,
}))
act(() => {
result.current.setModelConfig({
...result.current.modelConfig,
model_id: 'gpt-4',
})
})
expect(result.current.modelConfig.model_id).toBe('gpt-4')
})
})
```


@@ -0,0 +1,322 @@
---
name: frontend-testing
description: Generate Vitest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Vitest, RTL, unit tests, integration tests, or write/review test requests.
---
# Dify Frontend Testing Skill
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
## When to Apply This Skill
Apply this skill when the user:
- Asks to **write tests** for a component, hook, or utility
- Asks to **review existing tests** for completeness
- Mentions **Vitest**, **React Testing Library**, **RTL**, or **spec files**
- Requests **test coverage** improvement
- Uses `pnpm analyze-component` output as context
- Mentions **testing**, **unit tests**, or **integration tests** for frontend code
- Wants to understand **testing patterns** in the Dify codebase
**Do NOT apply** when:
- User is asking about backend/API tests (Python/pytest)
- User is asking about E2E tests (Playwright/Cypress)
- User is only asking conceptual questions without code context
## Quick Reference
### Tech Stack
| Tool | Version | Purpose |
|------|---------|---------|
| Vitest | 4.0.16 | Test runner |
| React Testing Library | 16.0 | Component testing |
| jsdom | - | Test environment |
| nock | 14.0 | HTTP mocking |
| TypeScript | 5.x | Type safety |
### Key Commands
```bash
# Run all tests
pnpm test
# Watch mode
pnpm test:watch
# Run specific file
pnpm test path/to/file.spec.tsx
# Generate coverage report
pnpm test:coverage
# Analyze component complexity
pnpm analyze-component <path>
# Review existing test
pnpm analyze-component <path> --review
```
### File Naming
- Test files: `ComponentName.spec.tsx` (same directory as component)
- Integration tests: `web/__tests__/` directory
## Test Structure Template
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import Component from './index'
// ✅ Import real project components (DO NOT mock these)
// import Loading from '@/app/components/base/loading'
// import { ChildComponent } from './child-component'
// ✅ Mock external dependencies only
vi.mock('@/service/api')
vi.mock('next/navigation', () => ({
useRouter: () => ({ push: vi.fn() }),
usePathname: () => '/test',
}))
// Shared state for mocks (if needed)
let mockSharedState = false
describe('ComponentName', () => {
beforeEach(() => {
vi.clearAllMocks() // ✅ Reset mocks BEFORE each test
mockSharedState = false // ✅ Reset shared state
})
// Rendering tests (REQUIRED)
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange
const props = { title: 'Test' }
// Act
render(<Component {...props} />)
// Assert
expect(screen.getByText('Test')).toBeInTheDocument()
})
})
// Props tests (REQUIRED)
describe('Props', () => {
it('should apply custom className', () => {
render(<Component className="custom" />)
expect(screen.getByRole('button')).toHaveClass('custom')
})
})
// User Interactions
describe('User Interactions', () => {
it('should handle click events', () => {
const handleClick = vi.fn()
render(<Component onClick={handleClick} />)
fireEvent.click(screen.getByRole('button'))
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle null data', () => {
render(<Component data={null} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
render(<Component items={[]} />)
expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
})
})
```
## Testing Workflow (CRITICAL)
### ⚠️ Incremental Approach Required
**NEVER generate all test files at once.** For complex components or multi-file directories:
1. **Analyze & Plan**: List all files, order by complexity (simple → complex)
1. **Process ONE at a time**: Write test → Run test → Fix if needed → Next
1. **Verify before proceeding**: Do NOT continue to next file until current passes
```
For each file:
┌────────────────────────────────────────┐
│ 1. Write test │
│ 2. Run: pnpm test <file>.spec.tsx │
│ 3. PASS? → Mark complete, next file │
│ FAIL? → Fix first, then continue │
└────────────────────────────────────────┘
```
### Complexity-Based Order
Process in this order for multi-file testing:
1. 🟢 Utility functions (simplest)
1. 🟢 Custom hooks
1. 🟡 Simple components (presentational)
1. 🟡 Medium components (state, effects)
1. 🔴 Complex components (API, routing)
1. 🔴 Integration tests (index files - last)
### When to Refactor First
- **Complexity > 50**: Break into smaller pieces before testing
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
## Testing Strategy
### Path-Level Testing (Directory Testing)
When assigned to test a directory/path, test **ALL content** within that path:
- Test all components, hooks, utilities in the directory (not just `index` file)
- Use incremental approach: one file at a time, verify each before proceeding
- Goal: 100% coverage of ALL files in the directory
### Integration Testing First
**Prefer integration testing** when writing tests for a directory:
- ✅ **Import real project components** directly (including base components and siblings)
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
- ❌ **DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
## Core Principles
### 1. AAA Pattern (Arrange-Act-Assert)
Every test should clearly separate:
- **Arrange**: Setup test data and render component
- **Act**: Perform user actions
- **Assert**: Verify expected outcomes
### 2. Black-Box Testing
- Test observable behavior, not implementation details
- Use semantic queries (getByRole, getByLabelText)
- Avoid testing internal state directly
- **Prefer pattern matching over hardcoded strings** in assertions:
```typescript
// ❌ Avoid: hardcoded text assertions
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Better: role-based queries
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Better: pattern matching
expect(screen.getByText(/loading/i)).toBeInTheDocument()
```
### 3. Single Behavior Per Test
Each test verifies ONE user-observable behavior:
```typescript
// ✅ Good: One behavior
it('should disable button when loading', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
})
// ❌ Bad: Multiple behaviors
it('should handle loading state', () => {
render(<Button loading />)
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByText('Loading...')).toBeInTheDocument()
expect(screen.getByRole('button')).toHaveClass('loading')
})
```
### 4. Semantic Naming
Use `should <behavior> when <condition>`:
```typescript
it('should show error message when validation fails')
it('should call onSubmit when form is valid')
it('should disable input when isReadOnly is true')
```
## Required Test Scenarios
### Always Required (All Components)
1. **Rendering**: Component renders without crashing
1. **Props**: Required props, optional props, default values
1. **Edge Cases**: null, undefined, empty values, boundary conditions
### Conditional (When Present)
| Feature | Test Focus |
|---------|-----------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | All onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
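For the `useCallback`/`useMemo` row, a small referential-equality check using RTL's `renderHook` (the hook under test is a stand-in):
```typescript
import { renderHook } from '@testing-library/react'
import { useCallback } from 'react'

// Stand-in hook: a handler memoized with an empty dependency array
const useStableHandler = () => useCallback(() => {}, [])

it('should keep the same callback reference across re-renders', () => {
  const { result, rerender } = renderHook(() => useStableHandler())
  const firstRender = result.current
  rerender()
  expect(result.current).toBe(firstRender)
})
```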
## Coverage Goals (Per File)
For each test file generated, aim for:
- ✅ **100%** function coverage
- ✅ **100%** statement coverage
- ✅ **>95%** branch coverage
- ✅ **>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
## Detailed Guides
For more detailed information, refer to:
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `references/mocking.md` - Mock patterns and best practices
- `references/async-testing.md` - Async operations and API calls
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `references/common-patterns.md` - Frequently used testing patterns
- `references/checklist.md` - Test generation checklist and validation steps
## Authoritative References
### Primary Specification (MUST follow)
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
- `web/app/components/base/button/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
- `web/vitest.config.ts` - Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/scripts/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
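For instance, a hedged local mock for `mime`, stubbing only the surface a test might touch (the real module exposes more):
```typescript
// Local mock: modules like 'mime' are not auto-mocked, so stub them per test file.
// getType is the only method assumed here.
vi.mock('mime', () => ({
  default: {
    getType: (path: string) => (path.endsWith('.png') ? 'image/png' : null),
  },
}))
```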


@@ -0,0 +1,296 @@
/**
* Test Template for React Components
*
* WHY THIS STRUCTURE?
* - Organized sections make tests easy to navigate and maintain
* - Mocks at top ensure consistent test isolation
* - Factory functions reduce duplication and improve readability
* - describe blocks group related scenarios for better debugging
*
* INSTRUCTIONS:
* 1. Replace `ComponentName` with your component name
* 2. Update import path
* 3. Add/remove test sections based on component features (use analyze-component)
* 4. Follow AAA pattern: Arrange Act Assert
*
* RUN FIRST: pnpm analyze-component <path> to identify required test scenarios
*/
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
// import ComponentName from './index'
// ============================================================================
// Mocks
// ============================================================================
// WHY: Mocks must be hoisted to top of file (Vitest requirement).
// They run BEFORE imports, so keep them before component imports.
// i18n (automatically mocked)
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// vi.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)
// WHY: Isolates tests from Next.js routing, enables testing navigation behavior
// const mockPush = vi.fn()
// vi.mock('next/navigation', () => ({
// useRouter: () => ({ push: mockPush }),
// usePathname: () => '/test-path',
// }))
// API services (if component fetches data)
// WHY: Prevents real network calls, enables testing all states (loading/success/error)
// vi.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = vi.mocked(api)
// Shared mock state (for portal/dropdown components)
// WHY: Portal components like PortalToFollowElem need shared state between
// parent and child mocks to correctly simulate open/close behavior
// let mockOpenState = false
// ============================================================================
// Test Data Factories
// ============================================================================
// WHY FACTORIES?
// - Avoid hard-coded test data scattered across tests
// - Easy to create variations with overrides
// - Type-safe when using actual types from source
// - Single source of truth for default test values
// const createMockProps = (overrides = {}) => ({
// // Default props that make component render successfully
// ...overrides,
// })
// const createMockItem = (overrides = {}) => ({
// id: 'item-1',
// name: 'Test Item',
// ...overrides,
// })
// ============================================================================
// Test Helpers
// ============================================================================
// const renderComponent = (props = {}) => {
// return render(<ComponentName {...createMockProps(props)} />)
// }
// ============================================================================
// Tests
// ============================================================================
describe('ComponentName', () => {
// WHY beforeEach with clearAllMocks?
// - Ensures each test starts with clean slate
// - Prevents mock call history from leaking between tests
// - MUST be beforeEach (not afterEach) to reset BEFORE assertions like toHaveBeenCalledTimes
beforeEach(() => {
vi.clearAllMocks()
// Reset shared mock state if used (CRITICAL for portal/dropdown tests)
// mockOpenState = false
})
// --------------------------------------------------------------------------
// Rendering Tests (REQUIRED - Every component MUST have these)
// --------------------------------------------------------------------------
// WHY: Catches import errors, missing providers, and basic render issues
describe('Rendering', () => {
it('should render without crashing', () => {
// Arrange - Setup data and mocks
// const props = createMockProps()
// Act - Render the component
// render(<ComponentName {...props} />)
// Assert - Verify expected output
// Prefer getByRole for accessibility; it's what users "see"
// expect(screen.getByRole('...')).toBeInTheDocument()
})
it('should render with default props', () => {
// WHY: Verifies component works without optional props
// render(<ComponentName />)
// expect(screen.getByText('...')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Props Tests (REQUIRED - Every component MUST test prop behavior)
// --------------------------------------------------------------------------
// WHY: Props are the component's API contract. Test them thoroughly.
describe('Props', () => {
it('should apply custom className', () => {
// WHY: Common pattern in Dify - components should merge custom classes
// render(<ComponentName className="custom-class" />)
// expect(screen.getByTestId('component')).toHaveClass('custom-class')
})
it('should use default values for optional props', () => {
// WHY: Verifies TypeScript defaults work at runtime
// render(<ComponentName />)
// expect(screen.getByRole('...')).toHaveAttribute('...', 'default-value')
})
})
// --------------------------------------------------------------------------
// User Interactions (if component has event handlers - on*, handle*)
// --------------------------------------------------------------------------
// WHY: Event handlers are core functionality. Test from user's perspective.
describe('User Interactions', () => {
it('should call onClick when clicked', async () => {
// WHY userEvent over fireEvent?
// - userEvent simulates real user behavior (focus, hover, then click)
// - fireEvent is lower-level, doesn't trigger all browser events
// const user = userEvent.setup()
// const handleClick = vi.fn()
// render(<ComponentName onClick={handleClick} />)
//
// await user.click(screen.getByRole('button'))
//
// expect(handleClick).toHaveBeenCalledTimes(1)
})
it('should call onChange when value changes', async () => {
// const user = userEvent.setup()
// const handleChange = vi.fn()
// render(<ComponentName onChange={handleChange} />)
//
// await user.type(screen.getByRole('textbox'), 'new value')
//
// expect(handleChange).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// State Management (if component uses useState/useReducer)
// --------------------------------------------------------------------------
// WHY: Test state through observable UI changes, not internal state values
describe('State Management', () => {
it('should update state on interaction', async () => {
// WHY test via UI, not state?
// - State is implementation detail; UI is what users see
// - If UI works correctly, state must be correct
// const user = userEvent.setup()
// render(<ComponentName />)
//
// // Initial state - verify what user sees
// expect(screen.getByText('Initial')).toBeInTheDocument()
//
// // Trigger state change via user action
// await user.click(screen.getByRole('button'))
//
// // New state - verify UI updated
// expect(screen.getByText('Updated')).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Async Operations (if component fetches data - useQuery, fetch)
// --------------------------------------------------------------------------
// WHY: Async operations have 3 states users experience: loading, success, error
describe('Async Operations', () => {
it('should show loading state', () => {
// WHY never-resolving promise?
// - Keeps component in loading state for assertion
// - Alternative: use fake timers
// mockedApi.fetchData.mockImplementation(() => new Promise(() => {}))
// render(<ComponentName />)
//
// expect(screen.getByText(/loading/i)).toBeInTheDocument()
})
it('should show data on success', async () => {
// WHY waitFor?
// - Component updates asynchronously after fetch resolves
// - waitFor retries assertion until it passes or times out
// mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText('Item 1')).toBeInTheDocument()
// })
})
it('should show error on failure', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
// render(<ComponentName />)
//
// await waitFor(() => {
// expect(screen.getByText(/error/i)).toBeInTheDocument()
// })
})
})
// --------------------------------------------------------------------------
// Edge Cases (REQUIRED - Every component MUST handle edge cases)
// --------------------------------------------------------------------------
// WHY: Real-world data is messy. Components must handle:
// - Null/undefined from API failures or optional fields
// - Empty arrays/strings from user clearing data
// - Boundary values (0, MAX_INT, special characters)
describe('Edge Cases', () => {
it('should handle null value', () => {
// WHY test null specifically?
// - API might return null for missing data
// - Prevents "Cannot read property of null" in production
// render(<ComponentName value={null} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle undefined value', () => {
// WHY test undefined separately from null?
// - TypeScript treats them differently
// - Optional props are undefined, not null
// render(<ComponentName value={undefined} />)
// expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
it('should handle empty array', () => {
// WHY: Empty state often needs special UI (e.g., "No items yet")
// render(<ComponentName items={[]} />)
// expect(screen.getByText(/empty/i)).toBeInTheDocument()
})
it('should handle empty string', () => {
// WHY: Empty strings are truthy in JS but visually empty
// render(<ComponentName text="" />)
// expect(screen.getByText(/placeholder/i)).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Accessibility (optional but recommended for Dify's enterprise users)
// --------------------------------------------------------------------------
// WHY: Dify has enterprise customers who may require accessibility compliance
describe('Accessibility', () => {
it('should have accessible name', () => {
// WHY getByRole with name?
// - Tests that screen readers can identify the element
// - Enforces proper labeling practices
// render(<ComponentName label="Test Label" />)
// expect(screen.getByRole('button', { name: /test label/i })).toBeInTheDocument()
})
it('should support keyboard navigation', async () => {
// WHY: Some users can't use a mouse
// const user = userEvent.setup()
// render(<ComponentName />)
//
// await user.tab()
// expect(screen.getByRole('button')).toHaveFocus()
})
})
})

View File

@ -0,0 +1,207 @@
/**
* Test Template for Custom Hooks
*
* Instructions:
* 1. Replace `useHookName` with your hook name
* 2. Update import path
* 3. Add/remove test sections based on hook features
*/
import { renderHook, act, waitFor } from '@testing-library/react'
// import { useHookName } from './use-hook-name'
// ============================================================================
// Mocks
// ============================================================================
// API services (if hook fetches data)
// vi.mock('@/service/api')
// import * as api from '@/service/api'
// const mockedApi = vi.mocked(api)
// ============================================================================
// Test Helpers
// ============================================================================
// Wrapper for hooks that need context
// const createWrapper = (contextValue = {}) => {
// return ({ children }: { children: React.ReactNode }) => (
// <SomeContext.Provider value={contextValue}>
// {children}
// </SomeContext.Provider>
// )
// }
// ============================================================================
// Tests
// ============================================================================
describe('useHookName', () => {
beforeEach(() => {
vi.clearAllMocks()
})
// --------------------------------------------------------------------------
// Initial State
// --------------------------------------------------------------------------
describe('Initial State', () => {
it('should return initial state', () => {
// const { result } = renderHook(() => useHookName())
//
// expect(result.current.value).toBe(initialValue)
// expect(result.current.isLoading).toBe(false)
})
it('should accept initial value from props', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'custom' }))
//
// expect(result.current.value).toBe('custom')
})
})
// --------------------------------------------------------------------------
// State Updates
// --------------------------------------------------------------------------
describe('State Updates', () => {
it('should update value when setValue is called', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(result.current.value).toBe('new value')
})
it('should reset to initial value', () => {
// const { result } = renderHook(() => useHookName({ initialValue: 'initial' }))
//
// act(() => {
// result.current.setValue('changed')
// })
// expect(result.current.value).toBe('changed')
//
// act(() => {
// result.current.reset()
// })
// expect(result.current.value).toBe('initial')
})
})
// --------------------------------------------------------------------------
// Async Operations
// --------------------------------------------------------------------------
describe('Async Operations', () => {
it('should fetch data on mount', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result } = renderHook(() => useHookName())
//
// // Initially loading
// expect(result.current.isLoading).toBe(true)
//
// // Wait for data
// await waitFor(() => {
// expect(result.current.isLoading).toBe(false)
// })
//
// expect(result.current.data).toEqual({ data: 'test' })
})
it('should handle fetch error', async () => {
// mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
//
// const { result } = renderHook(() => useHookName())
//
// await waitFor(() => {
// expect(result.current.error).toBeTruthy()
// })
//
// expect(result.current.error?.message).toBe('Network error')
})
it('should refetch when dependency changes', async () => {
// mockedApi.fetchData.mockResolvedValue({ data: 'test' })
//
// const { result, rerender } = renderHook(
// ({ id }) => useHookName(id),
// { initialProps: { id: '1' } }
// )
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('1')
// })
//
// rerender({ id: '2' })
//
// await waitFor(() => {
// expect(mockedApi.fetchData).toHaveBeenCalledWith('2')
// })
})
})
// --------------------------------------------------------------------------
// Side Effects
// --------------------------------------------------------------------------
describe('Side Effects', () => {
it('should call callback when value changes', () => {
// const callback = vi.fn()
// const { result } = renderHook(() => useHookName({ onChange: callback }))
//
// act(() => {
// result.current.setValue('new value')
// })
//
// expect(callback).toHaveBeenCalledWith('new value')
})
it('should cleanup on unmount', () => {
    // vi.spyOn(window, 'addEventListener')
    // vi.spyOn(window, 'removeEventListener')
//
// const { unmount } = renderHook(() => useHookName())
//
// expect(window.addEventListener).toHaveBeenCalled()
//
// unmount()
//
// expect(window.removeEventListener).toHaveBeenCalled()
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle null input', () => {
// const { result } = renderHook(() => useHookName(null))
//
// expect(result.current.value).toBeNull()
})
it('should handle rapid updates', () => {
// const { result } = renderHook(() => useHookName())
//
// act(() => {
// result.current.setValue('1')
// result.current.setValue('2')
// result.current.setValue('3')
// })
//
// expect(result.current.value).toBe('3')
})
})
// --------------------------------------------------------------------------
// With Context (if hook uses context)
// --------------------------------------------------------------------------
describe('With Context', () => {
it('should use context value', () => {
// const wrapper = createWrapper({ someValue: 'context-value' })
// const { result } = renderHook(() => useHookName(), { wrapper })
//
// expect(result.current.contextValue).toBe('context-value')
})
})
})

View File

@ -0,0 +1,154 @@
/**
* Test Template for Utility Functions
*
* Instructions:
* 1. Replace `utilityFunction` with your function name
* 2. Update import path
* 3. Use test.each for data-driven tests
*/
// import { utilityFunction } from './utility'
// ============================================================================
// Tests
// ============================================================================
describe('utilityFunction', () => {
// --------------------------------------------------------------------------
// Basic Functionality
// --------------------------------------------------------------------------
describe('Basic Functionality', () => {
it('should return expected result for valid input', () => {
// expect(utilityFunction('input')).toBe('expected-output')
})
it('should handle multiple arguments', () => {
// expect(utilityFunction('a', 'b', 'c')).toBe('abc')
})
})
// --------------------------------------------------------------------------
// Data-Driven Tests
// --------------------------------------------------------------------------
describe('Input/Output Mapping', () => {
test.each([
// [input, expected]
['input1', 'output1'],
['input2', 'output2'],
['input3', 'output3'],
    ])('should map input %s to %s', (input, expected) => {
// expect(utilityFunction(input)).toBe(expected)
})
})
// --------------------------------------------------------------------------
// Edge Cases
// --------------------------------------------------------------------------
describe('Edge Cases', () => {
it('should handle empty string', () => {
// expect(utilityFunction('')).toBe('')
})
it('should handle null', () => {
// expect(utilityFunction(null)).toBe(null)
// or
// expect(() => utilityFunction(null)).toThrow()
})
it('should handle undefined', () => {
// expect(utilityFunction(undefined)).toBe(undefined)
// or
// expect(() => utilityFunction(undefined)).toThrow()
})
it('should handle empty array', () => {
// expect(utilityFunction([])).toEqual([])
})
it('should handle empty object', () => {
// expect(utilityFunction({})).toEqual({})
})
})
// --------------------------------------------------------------------------
// Boundary Conditions
// --------------------------------------------------------------------------
describe('Boundary Conditions', () => {
it('should handle minimum value', () => {
// expect(utilityFunction(0)).toBe(0)
})
it('should handle maximum value', () => {
// expect(utilityFunction(Number.MAX_SAFE_INTEGER)).toBe(...)
})
it('should handle negative numbers', () => {
// expect(utilityFunction(-1)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Type Coercion (if applicable)
// --------------------------------------------------------------------------
describe('Type Handling', () => {
it('should handle numeric string', () => {
// expect(utilityFunction('123')).toBe(123)
})
it('should handle boolean', () => {
// expect(utilityFunction(true)).toBe(...)
})
})
// --------------------------------------------------------------------------
// Error Cases
// --------------------------------------------------------------------------
describe('Error Handling', () => {
it('should throw for invalid input', () => {
// expect(() => utilityFunction('invalid')).toThrow('Error message')
})
it('should throw with specific error type', () => {
// expect(() => utilityFunction('invalid')).toThrow(ValidationError)
})
})
// --------------------------------------------------------------------------
// Complex Objects (if applicable)
// --------------------------------------------------------------------------
describe('Object Handling', () => {
it('should preserve object structure', () => {
// const input = { a: 1, b: 2 }
// expect(utilityFunction(input)).toEqual({ a: 1, b: 2 })
})
it('should handle nested objects', () => {
// const input = { nested: { deep: 'value' } }
// expect(utilityFunction(input)).toEqual({ nested: { deep: 'transformed' } })
})
it('should not mutate input', () => {
// const input = { a: 1 }
// const inputCopy = { ...input }
// utilityFunction(input)
// expect(input).toEqual(inputCopy)
})
})
// --------------------------------------------------------------------------
// Array Handling (if applicable)
// --------------------------------------------------------------------------
describe('Array Handling', () => {
it('should process all elements', () => {
// expect(utilityFunction([1, 2, 3])).toEqual([2, 4, 6])
})
it('should handle single element array', () => {
// expect(utilityFunction([1])).toEqual([2])
})
it('should preserve order', () => {
// expect(utilityFunction(['c', 'a', 'b'])).toEqual(['c', 'a', 'b'])
})
})
})

View File

@ -0,0 +1,345 @@
# Async Testing Guide
## Core Async Patterns
### 1. waitFor - Wait for Condition
```typescript
import { render, screen, waitFor } from '@testing-library/react'
it('should load and display data', async () => {
render(<DataComponent />)
// Wait for element to appear
await waitFor(() => {
expect(screen.getByText('Loaded Data')).toBeInTheDocument()
})
})
it('should hide loading spinner after load', async () => {
render(<DataComponent />)
// Wait for element to disappear
await waitFor(() => {
expect(screen.queryByText('Loading...')).not.toBeInTheDocument()
})
})
```
### 2. findBy\* - Async Queries
```typescript
it('should show user name after fetch', async () => {
render(<UserProfile />)
// findBy returns a promise, auto-waits up to 1000ms
const userName = await screen.findByText('John Doe')
expect(userName).toBeInTheDocument()
// findByRole with options
const button = await screen.findByRole('button', { name: /submit/i })
expect(button).toBeEnabled()
})
```
### 3. userEvent for Async Interactions
```typescript
import userEvent from '@testing-library/user-event'
it('should submit form', async () => {
const user = userEvent.setup()
const onSubmit = vi.fn()
render(<Form onSubmit={onSubmit} />)
// userEvent methods are async
await user.type(screen.getByLabelText('Email'), 'test@example.com')
await user.click(screen.getByRole('button', { name: /submit/i }))
await waitFor(() => {
expect(onSubmit).toHaveBeenCalledWith({ email: 'test@example.com' })
})
})
```
## Fake Timers
### When to Use Fake Timers
- Testing components with `setTimeout`/`setInterval`
- Testing debounce/throttle behavior
- Testing animations or delayed transitions
- Testing polling or retry logic
### Basic Fake Timer Setup
```typescript
describe('Debounced Search', () => {
beforeEach(() => {
vi.useFakeTimers()
})
afterEach(() => {
vi.useRealTimers()
})
it('should debounce search input', async () => {
const onSearch = vi.fn()
render(<SearchInput onSearch={onSearch} debounceMs={300} />)
// Type in the input
fireEvent.change(screen.getByRole('textbox'), { target: { value: 'query' } })
// Search not called immediately
expect(onSearch).not.toHaveBeenCalled()
// Advance timers
vi.advanceTimersByTime(300)
// Now search is called
expect(onSearch).toHaveBeenCalledWith('query')
})
})
```
### Fake Timers with Async Code
```typescript
it('should retry on failure', async () => {
vi.useFakeTimers()
const fetchData = vi.fn()
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ data: 'success' })
render(<RetryComponent fetchData={fetchData} retryDelayMs={1000} />)
// First call fails
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
// Advance timer for retry
vi.advanceTimersByTime(1000)
// Second call succeeds
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(2)
expect(screen.getByText('success')).toBeInTheDocument()
})
vi.useRealTimers()
})
```
### Common Fake Timer Utilities
```typescript
// Run all pending timers
vi.runAllTimers()
// Run only pending timers (not new ones created during execution)
vi.runOnlyPendingTimers()
// Advance by specific time
vi.advanceTimersByTime(1000)
// Get current fake time
Date.now()
// Clear all timers
vi.clearAllTimers()
```
## API Testing Patterns
### Loading → Success → Error States
```typescript
describe('DataFetcher', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should show loading state', () => {
mockedApi.fetchData.mockImplementation(() => new Promise(() => {})) // Never resolves
render(<DataFetcher />)
expect(screen.getByTestId('loading-spinner')).toBeInTheDocument()
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1', 'Item 2'] })
render(<DataFetcher />)
// Use findBy* for multiple async elements (better error messages than waitFor with multiple assertions)
const item1 = await screen.findByText('Item 1')
const item2 = await screen.findByText('Item 2')
expect(item1).toBeInTheDocument()
expect(item2).toBeInTheDocument()
expect(screen.queryByTestId('loading-spinner')).not.toBeInTheDocument()
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Failed to fetch'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByText(/failed to fetch/i)).toBeInTheDocument()
})
})
it('should retry on error', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<DataFetcher />)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
mockedApi.fetchData.mockResolvedValue({ items: ['Item 1'] })
fireEvent.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText('Item 1')).toBeInTheDocument()
})
})
})
```
### Testing Mutations
```typescript
it('should submit form and show success', async () => {
const user = userEvent.setup()
mockedApi.createItem.mockResolvedValue({ id: '1', name: 'New Item' })
render(<CreateItemForm />)
await user.type(screen.getByLabelText('Name'), 'New Item')
await user.click(screen.getByRole('button', { name: /create/i }))
// Button should be disabled during submission
expect(screen.getByRole('button', { name: /creating/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByText(/created successfully/i)).toBeInTheDocument()
})
expect(mockedApi.createItem).toHaveBeenCalledWith({ name: 'New Item' })
})
```
## useEffect Testing
### Testing Effect Execution
```typescript
it('should fetch data on mount', async () => {
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
render(<ComponentWithEffect fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledTimes(1)
})
})
```
### Testing Effect Dependencies
```typescript
it('should refetch when id changes', async () => {
const fetchData = vi.fn().mockResolvedValue({ data: 'test' })
const { rerender } = render(<ComponentWithEffect id="1" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('1')
})
rerender(<ComponentWithEffect id="2" fetchData={fetchData} />)
await waitFor(() => {
expect(fetchData).toHaveBeenCalledWith('2')
expect(fetchData).toHaveBeenCalledTimes(2)
})
})
```
### Testing Effect Cleanup
```typescript
it('should cleanup subscription on unmount', () => {
const subscribe = vi.fn()
const unsubscribe = vi.fn()
subscribe.mockReturnValue(unsubscribe)
const { unmount } = render(<SubscriptionComponent subscribe={subscribe} />)
expect(subscribe).toHaveBeenCalledTimes(1)
unmount()
expect(unsubscribe).toHaveBeenCalledTimes(1)
})
```
## Common Async Pitfalls
### ❌ Don't: Forget to await
```typescript
// Bad - test may pass even if assertion fails
it('should load data', () => {
render(<Component />)
waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
// Good - properly awaited
it('should load data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### ❌ Don't: Use multiple assertions in single waitFor
```typescript
// Bad - if first assertion fails, won't know about second
await waitFor(() => {
expect(screen.getByText('Title')).toBeInTheDocument()
expect(screen.getByText('Description')).toBeInTheDocument()
})
// Good - separate waitFor or use findBy
const title = await screen.findByText('Title')
const description = await screen.findByText('Description')
expect(title).toBeInTheDocument()
expect(description).toBeInTheDocument()
```
### ❌ Don't: Mix fake timers with real async
```typescript
// Bad - fake timers don't work well with real Promises
vi.useFakeTimers()
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
}) // May timeout!
// Good - use runAllTimers or advanceTimersByTime
vi.useFakeTimers()
render(<Component />)
vi.runAllTimers()
expect(screen.getByText('Data')).toBeInTheDocument()
```

View File

@ -0,0 +1,205 @@
# Test Generation Checklist
Use this checklist when generating or reviewing tests for Dify frontend components.
## Pre-Generation
- [ ] Read the component source code completely
- [ ] Identify component type (component, hook, utility, page)
- [ ] Run `pnpm analyze-component <path>` if available
- [ ] Note complexity score and features detected
- [ ] Check for existing tests in the same directory
- [ ] **Identify ALL files in the directory** that need testing (not just index)
## Testing Strategy
### ⚠️ Incremental Workflow (CRITICAL for Multi-File)
- [ ] **NEVER generate all tests at once** - process one file at a time
- [ ] Order files by complexity: utilities → hooks → simple → complex → integration
- [ ] Create a todo list to track progress before starting
- [ ] For EACH file: write → run test → verify pass → then next
- [ ] **DO NOT proceed** to next file until current one passes
### Path-Level Coverage
- [ ] **Test ALL files** in the assigned directory/path
- [ ] List all components, hooks, utilities that need coverage
- [ ] Decide: single spec file (integration) or multiple spec files (unit)
### Complexity Assessment
- [ ] Run `pnpm analyze-component <path>` for complexity score
- [ ] **Complexity > 50**: Consider refactoring before testing
- [ ] **500+ lines**: Consider splitting before testing
- [ ] **30-50 complexity**: Use multiple describe blocks, organized structure
### Integration vs Mocking
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
## Required Test Sections
### All Components MUST Have
- [ ] **Rendering tests** - Component renders without crashing
- [ ] **Props tests** - Required props, optional props, default values
- [ ] **Edge cases** - null, undefined, empty values, boundaries
### Conditional Sections (Add When Feature Present)
| Feature | Add Tests For |
|---------|---------------|
| `useState` | Initial state, transitions, cleanup |
| `useEffect` | Execution, dependencies, cleanup |
| Event handlers | onClick, onChange, onSubmit, keyboard |
| API calls | Loading, success, error states |
| Routing | Navigation, params, query strings |
| `useCallback`/`useMemo` | Referential equality (sketch below) |
| Context | Provider values, consumer behavior |
| Forms | Validation, submission, error display |
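For the `useCallback`/`useMemo` row, a minimal sketch (assuming a hypothetical `useItemHandlers` hook that wraps its handler in `useCallback`):
```typescript
import { renderHook } from '@testing-library/react'
// Hypothetical hook, shown only to illustrate the pattern
import { useItemHandlers } from './use-item-handlers'

it('should keep handler identity across re-renders', () => {
  const { result, rerender } = renderHook(() => useItemHandlers())
  const firstHandler = result.current.onItemClick

  rerender()

  // Same reference means the useCallback dependencies did not change
  expect(result.current.onItemClick).toBe(firstHandler)
})
```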
## Code Quality Checklist
### Structure
- [ ] Uses `describe` blocks to group related tests
- [ ] Test names follow `should <behavior> when <condition>` pattern
- [ ] AAA pattern (Arrange-Act-Assert) is clear
- [ ] Comments explain complex test scenarios
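A hedged illustration of the naming and AAA items above (`LoginForm` is a stand-in component):
```typescript
it('should disable submit when the form is empty', () => {
  // Arrange
  render(<LoginForm onSubmit={vi.fn()} />)

  // Act - intentionally nothing: the empty initial state is under test

  // Assert
  expect(screen.getByRole('button', { name: /sign in/i })).toBeDisabled()
})
```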
### Mocks
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
- [ ] Router mocks match actual Next.js API
- [ ] Mocks reflect actual component conditional behavior
- [ ] Only mock: API services, complex context providers, third-party libs
### Queries
- [ ] Prefer semantic queries (`getByRole`, `getByLabelText`)
- [ ] Use `queryBy*` for absence assertions
- [ ] Use `findBy*` for async elements
- [ ] `getByTestId` only as last resort
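A minimal sketch of the absence and async rules above (element names are illustrative):
```typescript
it('should use the right query variant per situation', async () => {
  render(<Component />)

  // queryBy* returns null instead of throwing, so it is safe for absence checks
  expect(screen.queryByText(/error/i)).not.toBeInTheDocument()

  // findBy* returns a promise and retries until the element appears
  expect(await screen.findByRole('button', { name: /retry/i })).toBeInTheDocument()
})
```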
### Async
- [ ] All async tests use `async/await`
- [ ] `waitFor` wraps async assertions
- [ ] Fake timers properly setup/teardown
- [ ] No floating promises
### TypeScript
- [ ] No `any` types without justification
- [ ] Mock data uses actual types from source
- [ ] Factory functions have proper return types
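A minimal sketch of a typed factory; `Item` stands in for a type that a real test would import from the source under test:
```typescript
// In real tests, import the actual type instead of declaring it locally
type Item = { id: string; name: string; enabled?: boolean }

const createMockItem = (overrides: Partial<Item> = {}): Item => ({
  id: 'item-1',
  name: 'Test Item',
  ...overrides,
})
```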
## Coverage Goals (Per File)
For the current file being tested:
- [ ] 100% function coverage
- [ ] 100% statement coverage
- [ ] >95% branch coverage
- [ ] >95% line coverage
## Post-Generation (Per File)
**Run these checks after EACH test file, not just at the end:**
- [ ] Run `pnpm test path/to/file.spec.tsx` - **MUST PASS before next file**
- [ ] Fix any failures immediately
- [ ] Mark file as complete in todo list
- [ ] Only then proceed to next file
### After All Files Complete
- [ ] Run full directory test: `pnpm test path/to/directory/`
- [ ] Check coverage report: `pnpm test:coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check:tsgo`
## Common Issues to Watch
### False Positives
```typescript
// ❌ Mock doesn't match actual behavior
vi.mock('./Component', () => ({
  default: () => <div>Mocked</div>,
}))
// ✅ Mock matches actual conditional logic (factory returns the module shape)
vi.mock('./Component', () => ({
  default: ({ isOpen }: any) => (isOpen ? <div>Content</div> : null),
}))
```
### State Leakage
```typescript
// ❌ Shared state not reset
let mockState = false
vi.mock('./useHook', () => ({ default: () => mockState }))
// ✅ Reset in beforeEach
beforeEach(() => {
mockState = false
})
```
### Async Race Conditions
```typescript
// ❌ Not awaited
it('loads data', () => {
render(<Component />)
expect(screen.getByText('Data')).toBeInTheDocument()
})
// ✅ Properly awaited
it('loads data', async () => {
render(<Component />)
await waitFor(() => {
expect(screen.getByText('Data')).toBeInTheDocument()
})
})
```
### Missing Edge Cases
Always test these scenarios (a compact data-driven sketch follows the list):
- `null` / `undefined` inputs
- Empty strings / arrays / objects
- Boundary values (0, -1, MAX_INT)
- Error states
- Loading states
- Disabled states
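A compact way to sweep several of these at once (hedged sketch; `formatLabel` is a hypothetical utility under test):
```typescript
// Hedged sketch: `formatLabel` is a hypothetical utility under test
test.each([
  [null, ''],
  [undefined, ''],
  ['', ''],
  [0, '0'],
  [-1, '-1'],
])('should handle %s', (input, expected) => {
  expect(formatLabel(input)).toBe(expected)
})
```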
## Quick Commands
```bash
# Run specific test
pnpm test path/to/file.spec.tsx
# Run with coverage
pnpm test:coverage path/to/file.spec.tsx
# Watch mode
pnpm test:watch path/to/file.spec.tsx
# Update snapshots (use sparingly)
pnpm test -u path/to/file.spec.tsx
# Analyze component
pnpm analyze-component path/to/component.tsx
# Review existing test
pnpm analyze-component path/to/component.tsx --review
```

View File

@ -0,0 +1,449 @@
# Common Testing Patterns
## Query Priority
Use queries in this order (most to least preferred):
```typescript
// 1. getByRole - Most recommended (accessibility)
screen.getByRole('button', { name: /submit/i })
screen.getByRole('textbox', { name: /email/i })
screen.getByRole('heading', { level: 1 })
// 2. getByLabelText - Form fields
screen.getByLabelText('Email address')
screen.getByLabelText(/password/i)
// 3. getByPlaceholderText - When no label
screen.getByPlaceholderText('Search...')
// 4. getByText - Non-interactive elements
screen.getByText('Welcome to Dify')
screen.getByText(/loading/i)
// 5. getByDisplayValue - Current input value
screen.getByDisplayValue('current value')
// 6. getByAltText - Images
screen.getByAltText('Company logo')
// 7. getByTitle - Tooltip elements
screen.getByTitle('Close')
// 8. getByTestId - Last resort only!
screen.getByTestId('custom-element')
```
## Event Handling Patterns
### Click Events
```typescript
// Basic click
fireEvent.click(screen.getByRole('button'))
// With userEvent (preferred for realistic interaction)
const user = userEvent.setup()
await user.click(screen.getByRole('button'))
// Double click
await user.dblClick(screen.getByRole('button'))
// Right click
await user.pointer({ keys: '[MouseRight]', target: screen.getByRole('button') })
```
### Form Input
```typescript
const user = userEvent.setup()
// Type in input
await user.type(screen.getByRole('textbox'), 'Hello World')
// Clear and type
await user.clear(screen.getByRole('textbox'))
await user.type(screen.getByRole('textbox'), 'New value')
// Select option
await user.selectOptions(screen.getByRole('combobox'), 'option-value')
// Check checkbox
await user.click(screen.getByRole('checkbox'))
// Upload file
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
```
### Keyboard Events
```typescript
const user = userEvent.setup()
// Press Enter
await user.keyboard('{Enter}')
// Press Escape
await user.keyboard('{Escape}')
// Keyboard shortcut
await user.keyboard('{Control>}a{/Control}') // Ctrl+A
// Tab navigation
await user.tab()
// Arrow keys
await user.keyboard('{ArrowDown}')
await user.keyboard('{ArrowUp}')
```
## Component State Testing
### Testing State Transitions
```typescript
describe('Counter', () => {
it('should increment count', async () => {
const user = userEvent.setup()
render(<Counter initialCount={0} />)
// Initial state
expect(screen.getByText('Count: 0')).toBeInTheDocument()
// Trigger transition
await user.click(screen.getByRole('button', { name: /increment/i }))
// New state
expect(screen.getByText('Count: 1')).toBeInTheDocument()
})
})
```
### Testing Controlled Components
```typescript
describe('ControlledInput', () => {
it('should call onChange with new value', async () => {
const user = userEvent.setup()
const handleChange = vi.fn()
render(<ControlledInput value="" onChange={handleChange} />)
await user.type(screen.getByRole('textbox'), 'a')
expect(handleChange).toHaveBeenCalledWith('a')
})
it('should display controlled value', () => {
render(<ControlledInput value="controlled" onChange={vi.fn()} />)
expect(screen.getByRole('textbox')).toHaveValue('controlled')
})
})
```
## Conditional Rendering Testing
```typescript
describe('ConditionalComponent', () => {
it('should show loading state', () => {
render(<DataDisplay isLoading={true} data={null} />)
expect(screen.getByText(/loading/i)).toBeInTheDocument()
expect(screen.queryByTestId('data-content')).not.toBeInTheDocument()
})
it('should show error state', () => {
render(<DataDisplay isLoading={false} data={null} error="Failed to load" />)
expect(screen.getByText(/failed to load/i)).toBeInTheDocument()
})
it('should show data when loaded', () => {
render(<DataDisplay isLoading={false} data={{ name: 'Test' }} />)
expect(screen.getByText('Test')).toBeInTheDocument()
})
it('should show empty state when no data', () => {
render(<DataDisplay isLoading={false} data={[]} />)
expect(screen.getByText(/no data/i)).toBeInTheDocument()
})
})
```
## List Rendering Testing
```typescript
describe('ItemList', () => {
const items = [
{ id: '1', name: 'Item 1' },
{ id: '2', name: 'Item 2' },
{ id: '3', name: 'Item 3' },
]
it('should render all items', () => {
render(<ItemList items={items} />)
expect(screen.getAllByRole('listitem')).toHaveLength(3)
items.forEach(item => {
expect(screen.getByText(item.name)).toBeInTheDocument()
})
})
it('should handle item selection', async () => {
const user = userEvent.setup()
const onSelect = vi.fn()
render(<ItemList items={items} onSelect={onSelect} />)
await user.click(screen.getByText('Item 2'))
expect(onSelect).toHaveBeenCalledWith(items[1])
})
it('should handle empty list', () => {
render(<ItemList items={[]} />)
expect(screen.getByText(/no items/i)).toBeInTheDocument()
})
})
```
## Modal/Dialog Testing
```typescript
describe('Modal', () => {
it('should not render when closed', () => {
render(<Modal isOpen={false} onClose={vi.fn()} />)
expect(screen.queryByRole('dialog')).not.toBeInTheDocument()
})
it('should render when open', () => {
render(<Modal isOpen={true} onClose={vi.fn()} />)
expect(screen.getByRole('dialog')).toBeInTheDocument()
})
it('should call onClose when clicking overlay', async () => {
const user = userEvent.setup()
const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.click(screen.getByTestId('modal-overlay'))
expect(handleClose).toHaveBeenCalled()
})
it('should call onClose when pressing Escape', async () => {
const user = userEvent.setup()
const handleClose = vi.fn()
render(<Modal isOpen={true} onClose={handleClose} />)
await user.keyboard('{Escape}')
expect(handleClose).toHaveBeenCalled()
})
it('should trap focus inside modal', async () => {
const user = userEvent.setup()
render(
<Modal isOpen={true} onClose={vi.fn()}>
<button>First</button>
<button>Second</button>
</Modal>
)
// Focus should cycle within modal
await user.tab()
expect(screen.getByText('First')).toHaveFocus()
await user.tab()
expect(screen.getByText('Second')).toHaveFocus()
await user.tab()
expect(screen.getByText('First')).toHaveFocus() // Cycles back
})
})
```
## Form Testing
```typescript
describe('LoginForm', () => {
it('should submit valid form', async () => {
const user = userEvent.setup()
const onSubmit = vi.fn()
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(onSubmit).toHaveBeenCalledWith({
email: 'test@example.com',
password: 'password123',
})
})
it('should show validation errors', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={vi.fn()} />)
// Submit empty form
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/email is required/i)).toBeInTheDocument()
expect(screen.getByText(/password is required/i)).toBeInTheDocument()
})
it('should validate email format', async () => {
const user = userEvent.setup()
render(<LoginForm onSubmit={vi.fn()} />)
await user.type(screen.getByLabelText(/email/i), 'invalid-email')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByText(/invalid email/i)).toBeInTheDocument()
})
it('should disable submit button while submitting', async () => {
const user = userEvent.setup()
const onSubmit = vi.fn(() => new Promise(resolve => setTimeout(resolve, 100)))
render(<LoginForm onSubmit={onSubmit} />)
await user.type(screen.getByLabelText(/email/i), 'test@example.com')
await user.type(screen.getByLabelText(/password/i), 'password123')
await user.click(screen.getByRole('button', { name: /sign in/i }))
expect(screen.getByRole('button', { name: /signing in/i })).toBeDisabled()
await waitFor(() => {
expect(screen.getByRole('button', { name: /sign in/i })).toBeEnabled()
})
})
})
```
## Data-Driven Tests with test.each
```typescript
describe('StatusBadge', () => {
test.each([
['success', 'bg-green-500'],
['warning', 'bg-yellow-500'],
['error', 'bg-red-500'],
['info', 'bg-blue-500'],
])('should apply correct class for %s status', (status, expectedClass) => {
render(<StatusBadge status={status} />)
expect(screen.getByTestId('status-badge')).toHaveClass(expectedClass)
})
test.each([
{ input: null, expected: 'Unknown' },
{ input: undefined, expected: 'Unknown' },
{ input: '', expected: 'Unknown' },
{ input: 'invalid', expected: 'Unknown' },
])('should show "Unknown" for invalid input: $input', ({ input, expected }) => {
render(<StatusBadge status={input} />)
expect(screen.getByText(expected)).toBeInTheDocument()
})
})
```
## Debugging Tips
```typescript
// Print entire DOM
screen.debug()
// Print specific element
screen.debug(screen.getByRole('button'))
// Log testing playground URL
screen.logTestingPlaygroundURL()
// Pretty print DOM
import { prettyDOM } from '@testing-library/react'
console.log(prettyDOM(screen.getByRole('dialog')))
// Check available roles
import { getRoles } from '@testing-library/react'
console.log(getRoles(container))
```
## Common Mistakes to Avoid
### ❌ Don't Use Implementation Details
```typescript
// Bad - testing implementation
expect(component.state.isOpen).toBe(true)
expect(wrapper.find('.internal-class').length).toBe(1)
// Good - testing behavior
expect(screen.getByRole('dialog')).toBeInTheDocument()
```
### ❌ Don't Forget Cleanup
```typescript
// Bad - may leak state between tests
it('test 1', () => {
render(<Component />)
})
// Good - cleanup is automatic with RTL, but reset mocks
beforeEach(() => {
vi.clearAllMocks()
})
```
### ❌ Don't Use Exact String Matching (Prefer Black-Box Assertions)
```typescript
// ❌ Bad - hardcoded strings are brittle
expect(screen.getByText('Submit Form')).toBeInTheDocument()
expect(screen.getByText('Loading...')).toBeInTheDocument()
// ✅ Good - role-based queries (most semantic)
expect(screen.getByRole('button', { name: /submit/i })).toBeInTheDocument()
expect(screen.getByRole('status')).toBeInTheDocument()
// ✅ Good - pattern matching (flexible)
expect(screen.getByText(/submit/i)).toBeInTheDocument()
expect(screen.getByText(/loading/i)).toBeInTheDocument()
// ✅ Good - test behavior, not exact UI text
expect(screen.getByRole('button')).toBeDisabled()
expect(screen.getByRole('alert')).toBeInTheDocument()
```
**Why prefer black-box assertions?**
- Text content may change (i18n, copy updates)
- Role-based queries test accessibility
- Pattern matching is resilient to minor changes
- Tests focus on behavior, not implementation details
### ❌ Don't Assert on Absence Without Query
```typescript
// Bad - throws if not found
expect(screen.getByText('Error')).not.toBeInTheDocument() // Error!
// Good - use queryBy for absence assertions
expect(screen.queryByText('Error')).not.toBeInTheDocument()
```

View File

@ -0,0 +1,523 @@
# Domain-Specific Component Testing
This guide covers testing patterns for Dify's domain-specific components.
## Workflow Components (`workflow/`)
Workflow components handle node configuration, data flow, and graph operations.
### Key Test Areas
1. **Node Configuration**
1. **Data Validation**
1. **Variable Passing**
1. **Edge Connections**
1. **Error Handling**
### Example: Node Configuration Panel
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import NodeConfigPanel from './node-config-panel'
import { createMockNode } from '@/__mocks__/workflow'
// Mock workflow context
vi.mock('@/app/components/workflow/hooks', () => ({
useWorkflowStore: () => mockWorkflowStore,
useNodesInteractions: () => mockNodesInteractions,
}))
let mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: vi.fn(),
}
let mockNodesInteractions = {
handleNodeSelect: vi.fn(),
handleNodeDelete: vi.fn(),
}
describe('NodeConfigPanel', () => {
beforeEach(() => {
vi.clearAllMocks()
mockWorkflowStore = {
nodes: [],
edges: [],
updateNode: vi.fn(),
}
})
describe('Node Configuration', () => {
it('should render node type selector', () => {
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
expect(screen.getByLabelText(/model/i)).toBeInTheDocument()
})
it('should update node config on change', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
await user.selectOptions(screen.getByLabelText(/model/i), 'gpt-4')
expect(mockWorkflowStore.updateNode).toHaveBeenCalledWith(
node.id,
expect.objectContaining({ model: 'gpt-4' })
)
})
})
describe('Data Validation', () => {
it('should show error for invalid input', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'code' })
render(<NodeConfigPanel node={node} />)
// Enter invalid code
const codeInput = screen.getByLabelText(/code/i)
await user.clear(codeInput)
await user.type(codeInput, 'invalid syntax {{{')
await waitFor(() => {
expect(screen.getByText(/syntax error/i)).toBeInTheDocument()
})
})
it('should validate required fields', async () => {
const node = createMockNode({ type: 'http', data: { url: '' } })
render(<NodeConfigPanel node={node} />)
fireEvent.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/url is required/i)).toBeInTheDocument()
})
})
})
describe('Variable Passing', () => {
it('should display available variables from upstream nodes', () => {
const upstreamNode = createMockNode({
id: 'node-1',
type: 'start',
data: { outputs: [{ name: 'user_input', type: 'string' }] },
})
const currentNode = createMockNode({
id: 'node-2',
type: 'llm',
})
mockWorkflowStore.nodes = [upstreamNode, currentNode]
mockWorkflowStore.edges = [{ source: 'node-1', target: 'node-2' }]
render(<NodeConfigPanel node={currentNode} />)
// Variable selector should show upstream variables
fireEvent.click(screen.getByRole('button', { name: /add variable/i }))
expect(screen.getByText('user_input')).toBeInTheDocument()
})
it('should insert variable into prompt template', async () => {
const user = userEvent.setup()
const node = createMockNode({ type: 'llm' })
render(<NodeConfigPanel node={node} />)
// Click variable button
await user.click(screen.getByRole('button', { name: /insert variable/i }))
await user.click(screen.getByText('user_input'))
const promptInput = screen.getByLabelText(/prompt/i)
expect(promptInput).toHaveValue(expect.stringContaining('{{user_input}}'))
})
})
})
```
## Dataset Components (`dataset/`)
Dataset components handle file uploads, data display, and search/filter operations.
### Key Test Areas
1. **File Upload**
1. **File Type Validation**
1. **Pagination**
1. **Search & Filtering**
1. **Data Format Handling**
### Example: Document Uploader
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DocumentUploader from './document-uploader'
vi.mock('@/service/datasets', () => ({
uploadDocument: vi.fn(),
parseDocument: vi.fn(),
}))
import * as datasetService from '@/service/datasets'
const mockedService = vi.mocked(datasetService)
describe('DocumentUploader', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('File Upload', () => {
it('should accept valid file types', async () => {
const user = userEvent.setup()
const onUpload = vi.fn()
mockedService.uploadDocument.mockResolvedValue({ id: 'doc-1' })
render(<DocumentUploader onUpload={onUpload} />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
await waitFor(() => {
expect(mockedService.uploadDocument).toHaveBeenCalledWith(
expect.any(FormData)
)
})
})
it('should reject invalid file types', async () => {
const user = userEvent.setup()
render(<DocumentUploader />)
const file = new File(['content'], 'test.exe', { type: 'application/x-msdownload' })
const input = screen.getByLabelText(/upload/i)
await user.upload(input, file)
expect(screen.getByText(/unsupported file type/i)).toBeInTheDocument()
expect(mockedService.uploadDocument).not.toHaveBeenCalled()
})
it('should show upload progress', async () => {
const user = userEvent.setup()
// Mock upload with progress
mockedService.uploadDocument.mockImplementation(() => {
return new Promise((resolve) => {
setTimeout(() => resolve({ id: 'doc-1' }), 100)
})
})
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
expect(screen.getByRole('progressbar')).toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByRole('progressbar')).not.toBeInTheDocument()
})
})
})
describe('Error Handling', () => {
it('should handle upload failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument.mockRejectedValue(new Error('Upload failed'))
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByText(/upload failed/i)).toBeInTheDocument()
})
})
it('should allow retry after failure', async () => {
const user = userEvent.setup()
mockedService.uploadDocument
.mockRejectedValueOnce(new Error('Network error'))
.mockResolvedValueOnce({ id: 'doc-1' })
render(<DocumentUploader />)
const file = new File(['content'], 'test.pdf', { type: 'application/pdf' })
await user.upload(screen.getByLabelText(/upload/i), file)
await waitFor(() => {
expect(screen.getByRole('button', { name: /retry/i })).toBeInTheDocument()
})
await user.click(screen.getByRole('button', { name: /retry/i }))
await waitFor(() => {
expect(screen.getByText(/uploaded successfully/i)).toBeInTheDocument()
})
})
})
})
```
### Example: Document List with Pagination
```typescript
describe('DocumentList', () => {
describe('Pagination', () => {
it('should load first page on mount', async () => {
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
expect(mockedService.getDocuments).toHaveBeenCalledWith('ds-1', { page: 1 })
})
it('should navigate to next page', async () => {
const user = userEvent.setup()
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '1', name: 'Doc 1' }],
total: 50,
page: 1,
pageSize: 10,
})
render(<DocumentList datasetId="ds-1" />)
await waitFor(() => {
expect(screen.getByText('Doc 1')).toBeInTheDocument()
})
mockedService.getDocuments.mockResolvedValue({
data: [{ id: '11', name: 'Doc 11' }],
total: 50,
page: 2,
pageSize: 10,
})
await user.click(screen.getByRole('button', { name: /next/i }))
await waitFor(() => {
expect(screen.getByText('Doc 11')).toBeInTheDocument()
})
})
})
describe('Search & Filtering', () => {
it('should filter by search query', async () => {
vi.useFakeTimers()
// userEvent must be told to advance fake timers, or typing hangs
const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime })
render(<DocumentList datasetId="ds-1" />)
await user.type(screen.getByPlaceholderText(/search/i), 'test query')
// Debounce
vi.advanceTimersByTime(300)
await waitFor(() => {
expect(mockedService.getDocuments).toHaveBeenCalledWith(
'ds-1',
expect.objectContaining({ search: 'test query' })
)
})
vi.useRealTimers()
})
})
})
```
## Configuration Components (`app/configuration/`, `config/`)
Configuration components handle forms, validation, and data persistence.
### Key Test Areas
1. **Form Validation**
1. **Save/Reset**
1. **Required vs Optional Fields**
1. **Configuration Persistence**
1. **Error Feedback**
### Example: App Configuration Form
```typescript
import { render, screen, fireEvent, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import AppConfigForm from './app-config-form'
vi.mock('@/service/apps', () => ({
updateAppConfig: vi.fn(),
getAppConfig: vi.fn(),
}))
import * as appService from '@/service/apps'
const mockedService = vi.mocked(appService)
describe('AppConfigForm', () => {
const defaultConfig = {
name: 'My App',
description: '',
icon: 'default',
openingStatement: '',
}
beforeEach(() => {
vi.clearAllMocks()
mockedService.getAppConfig.mockResolvedValue(defaultConfig)
})
describe('Form Validation', () => {
it('should require app name', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Clear name field
await user.clear(screen.getByLabelText(/name/i))
await user.click(screen.getByRole('button', { name: /save/i }))
expect(screen.getByText(/name is required/i)).toBeInTheDocument()
expect(mockedService.updateAppConfig).not.toHaveBeenCalled()
})
it('should validate name length', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toBeInTheDocument()
})
// Enter very long name
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'a'.repeat(101))
expect(screen.getByText(/name must be less than 100 characters/i)).toBeInTheDocument()
})
it('should allow empty optional fields', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Leave description empty (optional)
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalled()
})
})
})
describe('Save/Reset Functionality', () => {
it('should save configuration', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockResolvedValue({ success: true })
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Updated App')
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(mockedService.updateAppConfig).toHaveBeenCalledWith(
'app-1',
expect.objectContaining({ name: 'Updated App' })
)
})
expect(screen.getByText(/saved successfully/i)).toBeInTheDocument()
})
it('should reset to default values', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.clear(screen.getByLabelText(/name/i))
await user.type(screen.getByLabelText(/name/i), 'Changed Name')
// Reset
await user.click(screen.getByRole('button', { name: /reset/i }))
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
it('should show unsaved changes warning', async () => {
const user = userEvent.setup()
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
// Make changes
await user.type(screen.getByLabelText(/name/i), ' Updated')
expect(screen.getByText(/unsaved changes/i)).toBeInTheDocument()
})
})
describe('Error Handling', () => {
it('should show error on save failure', async () => {
const user = userEvent.setup()
mockedService.updateAppConfig.mockRejectedValue(new Error('Server error'))
render(<AppConfigForm appId="app-1" />)
await waitFor(() => {
expect(screen.getByLabelText(/name/i)).toHaveValue('My App')
})
await user.click(screen.getByRole('button', { name: /save/i }))
await waitFor(() => {
expect(screen.getByText(/failed to save/i)).toBeInTheDocument()
})
})
})
})
```

View File

@ -0,0 +1,343 @@
# Mocking Guide for Dify Frontend Tests
## ⚠️ Important: What NOT to Mock
### DO NOT Mock Base Components
**Never mock components from `@/app/components/base/`** such as:
- `Loading`, `Spinner`
- `Button`, `Input`, `Select`
- `Tooltip`, `Modal`, `Dropdown`
- `Icon`, `Badge`, `Tag`
**Why?**
- Base components will have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
// ❌ WRONG: Don't mock base components
vi.mock('@/app/components/base/loading', () => ({ default: () => <div>Loading</div> }))
vi.mock('@/app/components/base/button', () => ({ default: ({ children }: any) => <button>{children}</button> }))
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
import Button from '@/app/components/base/button'
// They will render normally in tests
```
### What TO Mock
Only mock these categories:
1. **API services** (`@/service/*`) - Network calls
1. **Complex context providers** - When setup is too difficult
1. **Third-party libraries with side effects** - `next/navigation`, external SDKs
1. **i18n** - Always mock to return keys
## Mock Placement
| Location | Purpose |
|----------|---------|
| `web/vitest.setup.ts` | Global mocks shared by all tests (for example `react-i18next`, `next/image`) |
| `web/__mocks__/` | Reusable mock factories shared across multiple test files |
| Test file | Test-specific mocks, inline with `vi.mock()` |
Modules are not mocked automatically. Use `vi.mock` in test files, or add global mocks in `web/vitest.setup.ts`.
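For the `web/__mocks__/` row, a hedged sketch of what a reusable factory might look like; the exported name mirrors the `@/__mocks__/provider-context` helpers used later in this guide, but the field names here are assumptions:
```typescript
// web/__mocks__/provider-context.ts (illustrative shape, not the real file)
export const createMockProviderContextValue = (overrides: Record<string, unknown> = {}) => ({
  plan: { type: 'sandbox' }, // assumed field
  isFetchedPlan: true, // assumed field
  ...overrides,
})
```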
## Essential Mocks
### 1. i18n (Auto-loaded via Global Mock)
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.
For tests requiring custom translations, override the mock:
```typescript
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
}))
```
### 2. Next.js Router
```typescript
const mockPush = vi.fn()
const mockReplace = vi.fn()
vi.mock('next/navigation', () => ({
useRouter: () => ({
push: mockPush,
replace: mockReplace,
back: vi.fn(),
prefetch: vi.fn(),
}),
usePathname: () => '/current-path',
useSearchParams: () => new URLSearchParams('?key=value'),
}))
describe('Component', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should navigate on click', () => {
render(<Component />)
fireEvent.click(screen.getByRole('button'))
expect(mockPush).toHaveBeenCalledWith('/expected-path')
})
})
```
### 3. Portal Components (with Shared State)
```typescript
// ⚠️ Important: Use shared state for components that depend on each other
let mockPortalOpenState = false
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children, open, ...props }: any) => {
mockPortalOpenState = open || false // Update shared state
return <div data-testid="portal" data-open={open}>{children}</div>
},
PortalToFollowElemContent: ({ children }: any) => {
// ✅ Matches the real component: returns null while the portal is closed
if (!mockPortalOpenState) return null
return <div data-testid="portal-content">{children}</div>
},
PortalToFollowElemTrigger: ({ children }: any) => (
<div data-testid="portal-trigger">{children}</div>
),
}))
describe('Component', () => {
beforeEach(() => {
vi.clearAllMocks()
mockPortalOpenState = false // ✅ Reset shared state
})
})
```
### 4. API Service Mocks
```typescript
import * as api from '@/service/api'
vi.mock('@/service/api')
const mockedApi = vi.mocked(api)
describe('Component', () => {
beforeEach(() => {
vi.clearAllMocks()
// Setup default mock implementation
mockedApi.fetchData.mockResolvedValue({ data: [] })
})
it('should show data on success', async () => {
mockedApi.fetchData.mockResolvedValue({ data: [{ id: 1 }] })
render(<Component />)
await waitFor(() => {
expect(screen.getByText('1')).toBeInTheDocument()
})
})
it('should show error on failure', async () => {
mockedApi.fetchData.mockRejectedValue(new Error('Network error'))
render(<Component />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
### 5. HTTP Mocking with Nock
```typescript
import nock from 'nock'
const GITHUB_HOST = 'https://api.github.com'
const GITHUB_PATH = '/repos/owner/repo'
const mockGithubApi = (status: number, body: Record<string, unknown>, delayMs = 0) => {
return nock(GITHUB_HOST)
.get(GITHUB_PATH)
.delay(delayMs)
.reply(status, body)
}
describe('GithubComponent', () => {
afterEach(() => {
nock.cleanAll()
})
it('should display repo info', async () => {
mockGithubApi(200, { name: 'dify', stars: 1000 })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText('dify')).toBeInTheDocument()
})
})
it('should handle API error', async () => {
mockGithubApi(500, { message: 'Server error' })
render(<GithubComponent />)
await waitFor(() => {
expect(screen.getByText(/error/i)).toBeInTheDocument()
})
})
})
```
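The examples above assume every outgoing request has an interceptor. If stray, un-mocked requests should fail fast instead, nock can block network access entirely; a minimal sketch:
```typescript
import nock from 'nock'
import { afterAll, beforeAll } from 'vitest'

// Fail fast if a test hits an endpoint that has no nock interceptor.
beforeAll(() => nock.disableNetConnect())
// Restore normal networking for suites that run afterwards.
afterAll(() => nock.enableNetConnect())
```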
### 6. Context Providers
```typescript
import { ProviderContext } from '@/context/provider-context'
import { createMockProviderContextValue, createMockPlan } from '@/__mocks__/provider-context'
describe('Component with Context', () => {
it('should render for free plan', () => {
const mockContext = createMockPlan('sandbox')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.getByText('Upgrade')).toBeInTheDocument()
})
it('should render for pro plan', () => {
const mockContext = createMockPlan('professional')
render(
<ProviderContext.Provider value={mockContext}>
<Component />
</ProviderContext.Provider>
)
expect(screen.queryByText('Upgrade')).not.toBeInTheDocument()
})
})
```
### 7. React Query
```typescript
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
const createTestQueryClient = () => new QueryClient({
defaultOptions: {
queries: { retry: false },
mutations: { retry: false },
},
})
const renderWithQueryClient = (ui: React.ReactElement) => {
const queryClient = createTestQueryClient()
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>
)
}
```
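The helper is then used in place of plain `render`. A short sketch with a hypothetical `ItemList` component (its data source is assumed to be mocked as shown in the API service section):
```typescript
import { screen, waitFor } from '@testing-library/react'
import { expect, it } from 'vitest'
// `ItemList` is a hypothetical component used only to illustrate the helper.
import ItemList from './item-list'

it('renders fetched items', async () => {
  renderWithQueryClient(<ItemList />)
  await waitFor(() => {
    expect(screen.getByText('Item 1')).toBeInTheDocument()
  })
})
```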
## Mock Best Practices
### ✅ DO
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
1. **Reset mocks in `beforeEach`**, not `afterEach`
1. **Match actual component behavior** in mocks (when mocking is necessary)
1. **Use factory functions** for complex mock data
1. **Import actual types** for type safety
1. **Reset shared mock state** in `beforeEach`
### ❌ DON'T
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
1. Don't forget to clean up nock after each test
1. Don't use `any` types in mocks without necessity
### Mock Decision Tree
```
Need to use a component in test?
├─ Is it from @/app/components/base/*?
│ └─ YES → Import real component, DO NOT mock
├─ Is it a project component?
│ └─ YES → Prefer importing real component
│ Only mock if setup is extremely complex
├─ Is it an API service (@/service/*)?
│ └─ YES → Mock it
├─ Is it a third-party lib with side effects?
│ └─ YES → Mock it (next/navigation, external SDKs)
└─ Is it i18n?
└─ YES → Uses shared mock (auto-loaded). Override only for custom translations
```
## Factory Function Pattern
```typescript
// __mocks__/data-factories.ts
import type { User, Project } from '@/types'
export const createMockUser = (overrides: Partial<User> = {}): User => ({
id: 'user-1',
name: 'Test User',
email: 'test@example.com',
role: 'member',
createdAt: new Date().toISOString(),
...overrides,
})
export const createMockProject = (overrides: Partial<Project> = {}): Project => ({
id: 'project-1',
name: 'Test Project',
description: 'A test project',
owner: createMockUser(),
members: [],
createdAt: new Date().toISOString(),
...overrides,
})
// Usage in tests
it('should display project owner', () => {
const project = createMockProject({
owner: createMockUser({ name: 'John Doe' }),
})
render(<ProjectCard project={project} />)
expect(screen.getByText('John Doe')).toBeInTheDocument()
})
```

View File

@ -0,0 +1,269 @@
# Testing Workflow Guide
This guide defines the workflow for generating tests, especially for complex components or directories with multiple files.
## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
| Scope | Rule |
|-------|------|
| **Single file** | Complete coverage in one generation (100% function, >95% branch) |
| **Multi-file directory** | Process one file at a time, verify each before proceeding |
## ⚠️ Critical Rule: Incremental Approach for Multi-File Testing
When testing a **directory with multiple files**, **NEVER generate all test files at once.** Use an incremental, verify-as-you-go approach.
### Why Incremental?
| Batch Approach (❌) | Incremental Approach (✅) |
|---------------------|---------------------------|
| Generate 5+ tests at once | Generate 1 test at a time |
| Run tests only at the end | Run test immediately after each file |
| Multiple failures compound | Single point of failure, easy to debug |
| Hard to identify root cause | Clear cause-effect relationship |
| Mock issues affect many files | Mock issues caught early |
| Messy git history | Clean, atomic commits possible |
## Single File Workflow
When testing a **single component, hook, or utility**:
```
1. Read source code completely
2. Run `pnpm analyze-component <path>` (if available)
3. Check complexity score and features detected
4. Write the test file
5. Run test: `pnpm test <file>.spec.tsx`
6. Fix any failures
7. Verify coverage meets goals (100% function, >95% branch)
```
## Directory/Multi-File Workflow (MUST FOLLOW)
When testing a **directory or multiple files**, follow this strict workflow:
### Step 1: Analyze and Plan
1. **List all files** that need tests in the directory
1. **Categorize by complexity**:
- 🟢 **Simple**: Utility functions, simple hooks, presentational components
- 🟡 **Medium**: Components with state, effects, or event handlers
- 🔴 **Complex**: Components with API calls, routing, or many dependencies
1. **Order by dependency**: Test dependencies before dependents
1. **Create a todo list** to track progress
### Step 2: Determine Processing Order
Process files in this recommended order:
```
1. Utility functions (simplest, no React)
2. Custom hooks (isolated logic)
3. Simple presentational components (few/no props)
4. Medium complexity components (state, effects)
5. Complex components (API, routing, many deps)
6. Container/index components (integration tests - last)
```
**Rationale**:
- Simpler files help establish mock patterns
- Hooks used by components should be tested first
- Integration tests (index files) depend on child components working
### Step 3: Process Each File Incrementally
**For EACH file in the ordered list:**
```
┌─────────────────────────────────────────────┐
│ 1. Write test file │
│ 2. Run: pnpm test <file>.spec.tsx │
│ 3. If FAIL → Fix immediately, re-run │
│ 4. If PASS → Mark complete in todo list │
│ 5. ONLY THEN proceed to next file │
└─────────────────────────────────────────────┘
```
**DO NOT proceed to the next file until the current one passes.**
### Step 4: Final Verification
After all individual tests pass:
```bash
# Run all tests in the directory together
pnpm test path/to/directory/
# Check coverage
pnpm test:coverage path/to/directory/
```
## Component Complexity Guidelines
Use `pnpm analyze-component <path>` to assess complexity before testing.
### 🔴 Very Complex Components (Complexity > 50)
**Consider refactoring BEFORE testing:**
- Break component into smaller, testable pieces
- Extract complex logic into custom hooks
- Separate container and presentational layers
**If testing as-is:**
- Use integration tests for complex workflows
- Use `test.each()` for data-driven testing (see the sketch after this list)
- Multiple `describe` blocks for organization
- Consider testing major sections separately
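A minimal `test.each()` sketch; `formatFileSize` here is a hypothetical utility standing in for whatever logic is being table-tested:
```typescript
import { describe, expect, it } from 'vitest'
// Hypothetical utility; substitute the real function under test.
import { formatFileSize } from './format'

describe('formatFileSize', () => {
  it.each([
    [0, '0 B'],
    [1024, '1 KB'],
    [1048576, '1 MB'],
  ])('formats %d bytes as "%s"', (input, expected) => {
    expect(formatFileSize(input)).toBe(expected)
  })
})
```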
### 🟡 Medium Complexity (Complexity 30-50)
- Group related tests in `describe` blocks
- Test integration scenarios between internal parts
- Focus on state transitions and side effects
- Use helper functions to reduce test complexity (see the sketch below)
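One common shape for such helpers is a shared render function with overridable defaults. In the sketch below, `Panel` and `PanelProps` are hypothetical placeholders for the component under test:
```typescript
import { render, screen } from '@testing-library/react'
import { describe, expect, it } from 'vitest'
// Hypothetical component and props type, standing in for the real component.
import Panel, { type PanelProps } from './panel'

// Centralize default props so each test states only what it cares about.
const renderPanel = (overrides: Partial<PanelProps> = {}) =>
  render(<Panel title="Test panel" collapsed={false} {...overrides} />)

describe('Panel', () => {
  it('renders the title', () => {
    renderPanel()
    expect(screen.getByText('Test panel')).toBeInTheDocument()
  })

  it('hides the body when collapsed', () => {
    renderPanel({ collapsed: true })
    // Assumes the panel body is exposed as a landmark region when visible.
    expect(screen.queryByRole('region')).not.toBeInTheDocument()
  })
})
```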
### 🟢 Simple Components (Complexity < 30)
- Standard test structure
- Focus on props, rendering, and edge cases
- Usually straightforward to test
### 📏 Large Files (500+ lines)
Regardless of complexity score:
- **Strongly consider refactoring** before testing
- If testing as-is, test major sections separately
- Create helper functions for test setup
- May need multiple test files
## Todo List Format
When testing multiple files, use a todo list like this:
```
Testing: path/to/directory/
Ordered by complexity (simple → complex):
☐ utils/helper.ts [utility, simple]
☐ hooks/use-custom-hook.ts [hook, simple]
☐ empty-state.tsx [component, simple]
☐ item-card.tsx [component, medium]
☐ list.tsx [component, complex]
☐ index.tsx [integration]
Progress: 0/6 complete
```
Update status as you complete each:
- ☐ → ⏳ (in progress)
- ⏳ → ✅ (complete and verified)
- ⏳ → ❌ (blocked, needs attention)
## When to Stop and Verify
**Always run tests after:**
- Completing a test file
- Making changes to fix a failure
- Modifying shared mocks
- Updating test utilities or helpers
**Signs you should pause:**
- More than 2 consecutive test failures
- Mock-related errors appearing
- Unclear why a test is failing
- Test passing but coverage unexpectedly low
## Common Pitfalls to Avoid
### ❌ Don't: Generate Everything First
```
# BAD: Writing all files then testing
Write component-a.spec.tsx
Write component-b.spec.tsx
Write component-c.spec.tsx
Write component-d.spec.tsx
Run pnpm test ← Multiple failures, hard to debug
```
### ✅ Do: Verify Each Step
```
# GOOD: Incremental with verification
Write component-a.spec.tsx
Run pnpm test component-a.spec.tsx ✅
Write component-b.spec.tsx
Run pnpm test component-b.spec.tsx ✅
...continue...
```
### ❌ Don't: Skip Verification for "Simple" Components
Even simple components can have:
- Import errors
- Missing mock setup
- Incorrect assumptions about props
**Always verify, regardless of perceived simplicity.**
### ❌ Don't: Continue When Tests Fail
Failing tests compound:
- A mock issue in file A affects files B, C, D
- Fixing A later requires revisiting all dependent tests
- Time wasted on debugging cascading failures
**Fix failures immediately before proceeding.**
## Integration with Claude's Todo Feature
When using Claude for multi-file testing:
1. **Ask Claude to create a todo list** before starting
1. **Request one file at a time** or ensure Claude processes incrementally
1. **Verify each test passes** before asking for the next
1. **Mark todos complete** as you progress
Example prompt:
```
Test all components in `path/to/directory/`.
First, analyze the directory and create a todo list ordered by complexity.
Then, process ONE file at a time, waiting for my confirmation that tests pass
before proceeding to the next.
```
## Summary Checklist
Before starting multi-file testing:
- [ ] Listed all files needing tests
- [ ] Ordered by complexity (simple → complex)
- [ ] Created todo list for tracking
- [ ] Understand dependencies between files
During testing:
- [ ] Processing ONE file at a time
- [ ] Running tests after EACH file
- [ ] Fixing failures BEFORE proceeding
- [ ] Updating todo list progress
After completion:
- [ ] All individual tests pass
- [ ] Full directory test run passes
- [ ] Coverage goals met
- [ ] Todo list shows all complete

1
.codex/skills Symbolic link
View File

@ -0,0 +1 @@
../.claude/skills

5
.coveragerc Normal file
View File

@ -0,0 +1,5 @@
[run]
omit =
api/tests/*
api/migrations/*
api/core/rag/datasource/vdb/*

View File

@ -6,6 +6,9 @@
"context": "..", "context": "..",
"dockerfile": "Dockerfile" "dockerfile": "Dockerfile"
}, },
"mounts": [
"source=dify-dev-tmp,target=/tmp,type=volume"
],
"features": { "features": {
"ghcr.io/devcontainers/features/node:1": { "ghcr.io/devcontainers/features/node:1": {
"nodeGypDependencies": true, "nodeGypDependencies": true,
@ -34,19 +37,13 @@
}, },
"postStartCommand": "./.devcontainer/post_start_command.sh", "postStartCommand": "./.devcontainer/post_start_command.sh",
"postCreateCommand": "./.devcontainer/post_create_command.sh" "postCreateCommand": "./.devcontainer/post_create_command.sh"
// Features to add to the dev container. More info: https://containers.dev/features. // Features to add to the dev container. More info: https://containers.dev/features.
// "features": {}, // "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally. // Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [], // "forwardPorts": [],
// Use 'postCreateCommand' to run commands after the container is created. // Use 'postCreateCommand' to run commands after the container is created.
// "postCreateCommand": "python --version", // "postCreateCommand": "python --version",
// Configure tool-specific properties. // Configure tool-specific properties.
// "customizations": {}, // "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
} }

View File

@ -1,12 +1,13 @@
 #!/bin/bash
 WORKSPACE_ROOT=$(pwd)
+export COREPACK_ENABLE_DOWNLOAD_PROMPT=0
 corepack enable
 cd web && pnpm install
 pipx install uv
 echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
-echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
 echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
 echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
 echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

309
.github/CODEOWNERS vendored
View File

@ -6,221 +6,244 @@
 * @crazywoola @laipz8200 @Yeuoly
+# CODEOWNERS file
+/.github/CODEOWNERS @laipz8200 @crazywoola
+# Docs
+/docs/ @crazywoola
 # Backend (default owner, more specific rules below will override)
-api/ @QuantumGhost
+/api/ @QuantumGhost
+# Backend - MCP
+/api/core/mcp/ @Nov1c444
+/api/core/entities/mcp_provider.py @Nov1c444
+/api/services/tools/mcp_tools_manage_service.py @Nov1c444
+/api/controllers/mcp/ @Nov1c444
+/api/controllers/console/app/mcp_server.py @Nov1c444
+/api/tests/**/*mcp* @Nov1c444
 # Backend - Workflow - Engine (Core graph execution engine)
-api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
-api/core/workflow/runtime/ @laipz8200 @QuantumGhost
+/api/core/workflow/runtime/ @laipz8200 @QuantumGhost
-api/core/workflow/graph/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph/ @laipz8200 @QuantumGhost
-api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
+/api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
-api/core/workflow/node_events/ @laipz8200 @QuantumGhost
+/api/core/workflow/node_events/ @laipz8200 @QuantumGhost
-api/core/model_runtime/ @laipz8200 @QuantumGhost
+/api/core/model_runtime/ @laipz8200 @QuantumGhost
 # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
-api/core/workflow/nodes/agent/ @Nov1c444
+/api/core/workflow/nodes/agent/ @Nov1c444
-api/core/workflow/nodes/iteration/ @Nov1c444
+/api/core/workflow/nodes/iteration/ @Nov1c444
-api/core/workflow/nodes/loop/ @Nov1c444
+/api/core/workflow/nodes/loop/ @Nov1c444
-api/core/workflow/nodes/llm/ @Nov1c444
+/api/core/workflow/nodes/llm/ @Nov1c444
 # Backend - RAG (Retrieval Augmented Generation)
-api/core/rag/ @JohnJyong
+/api/core/rag/ @JohnJyong
-api/services/rag_pipeline/ @JohnJyong
+/api/services/rag_pipeline/ @JohnJyong
-api/services/dataset_service.py @JohnJyong
+/api/services/dataset_service.py @JohnJyong
-api/services/knowledge_service.py @JohnJyong
+/api/services/knowledge_service.py @JohnJyong
-api/services/external_knowledge_service.py @JohnJyong
+/api/services/external_knowledge_service.py @JohnJyong
-api/services/hit_testing_service.py @JohnJyong
+/api/services/hit_testing_service.py @JohnJyong
-api/services/metadata_service.py @JohnJyong
+/api/services/metadata_service.py @JohnJyong
-api/services/vector_service.py @JohnJyong
+/api/services/vector_service.py @JohnJyong
-api/services/entities/knowledge_entities/ @JohnJyong
+/api/services/entities/knowledge_entities/ @JohnJyong
-api/services/entities/external_knowledge_entities/ @JohnJyong
+/api/services/entities/external_knowledge_entities/ @JohnJyong
-api/controllers/console/datasets/ @JohnJyong
+/api/controllers/console/datasets/ @JohnJyong
-api/controllers/service_api/dataset/ @JohnJyong
+/api/controllers/service_api/dataset/ @JohnJyong
-api/models/dataset.py @JohnJyong
+/api/models/dataset.py @JohnJyong
-api/tasks/rag_pipeline/ @JohnJyong
+/api/tasks/rag_pipeline/ @JohnJyong
-api/tasks/add_document_to_index_task.py @JohnJyong
+/api/tasks/add_document_to_index_task.py @JohnJyong
-api/tasks/batch_clean_document_task.py @JohnJyong
+/api/tasks/batch_clean_document_task.py @JohnJyong
-api/tasks/clean_document_task.py @JohnJyong
+/api/tasks/clean_document_task.py @JohnJyong
-api/tasks/clean_notion_document_task.py @JohnJyong
+/api/tasks/clean_notion_document_task.py @JohnJyong
-api/tasks/document_indexing_task.py @JohnJyong
+/api/tasks/document_indexing_task.py @JohnJyong
-api/tasks/document_indexing_sync_task.py @JohnJyong
+/api/tasks/document_indexing_sync_task.py @JohnJyong
-api/tasks/document_indexing_update_task.py @JohnJyong
+/api/tasks/document_indexing_update_task.py @JohnJyong
-api/tasks/duplicate_document_indexing_task.py @JohnJyong
+/api/tasks/duplicate_document_indexing_task.py @JohnJyong
-api/tasks/recover_document_indexing_task.py @JohnJyong
+/api/tasks/recover_document_indexing_task.py @JohnJyong
-api/tasks/remove_document_from_index_task.py @JohnJyong
+/api/tasks/remove_document_from_index_task.py @JohnJyong
-api/tasks/retry_document_indexing_task.py @JohnJyong
+/api/tasks/retry_document_indexing_task.py @JohnJyong
-api/tasks/sync_website_document_indexing_task.py @JohnJyong
+/api/tasks/sync_website_document_indexing_task.py @JohnJyong
-api/tasks/batch_create_segment_to_index_task.py @JohnJyong
+/api/tasks/batch_create_segment_to_index_task.py @JohnJyong
-api/tasks/create_segment_to_index_task.py @JohnJyong
+/api/tasks/create_segment_to_index_task.py @JohnJyong
-api/tasks/delete_segment_from_index_task.py @JohnJyong
+/api/tasks/delete_segment_from_index_task.py @JohnJyong
-api/tasks/disable_segment_from_index_task.py @JohnJyong
+/api/tasks/disable_segment_from_index_task.py @JohnJyong
-api/tasks/disable_segments_from_index_task.py @JohnJyong
+/api/tasks/disable_segments_from_index_task.py @JohnJyong
-api/tasks/enable_segment_to_index_task.py @JohnJyong
+/api/tasks/enable_segment_to_index_task.py @JohnJyong
-api/tasks/enable_segments_to_index_task.py @JohnJyong
+/api/tasks/enable_segments_to_index_task.py @JohnJyong
-api/tasks/clean_dataset_task.py @JohnJyong
+/api/tasks/clean_dataset_task.py @JohnJyong
-api/tasks/deal_dataset_index_update_task.py @JohnJyong
+/api/tasks/deal_dataset_index_update_task.py @JohnJyong
-api/tasks/deal_dataset_vector_index_task.py @JohnJyong
+/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
 # Backend - Plugins
-api/core/plugin/ @Mairuis @Yeuoly @Stream29
+/api/core/plugin/ @Mairuis @Yeuoly @Stream29
-api/services/plugin/ @Mairuis @Yeuoly @Stream29
+/api/services/plugin/ @Mairuis @Yeuoly @Stream29
-api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
+/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
-api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
+/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
-api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
+/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
 # Backend - Trigger/Schedule/Webhook
-api/controllers/trigger/ @Mairuis @Yeuoly
+/api/controllers/trigger/ @Mairuis @Yeuoly
-api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
+/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
-api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
+/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
-api/core/trigger/ @Mairuis @Yeuoly
+/api/core/trigger/ @Mairuis @Yeuoly
-api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
+/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
-api/services/trigger/ @Mairuis @Yeuoly
+/api/services/trigger/ @Mairuis @Yeuoly
-api/models/trigger.py @Mairuis @Yeuoly
+/api/models/trigger.py @Mairuis @Yeuoly
-api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
+/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
-api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
+/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
-api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
+/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
-api/libs/schedule_utils.py @Mairuis @Yeuoly
+/api/libs/schedule_utils.py @Mairuis @Yeuoly
-api/services/workflow/scheduler.py @Mairuis @Yeuoly
+/api/services/workflow/scheduler.py @Mairuis @Yeuoly
-api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
+/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
-api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
+/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
-api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
+/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
-api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
+/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
-api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
+/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
-api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
+/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
-api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
+/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
-api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
+/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
-api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
+/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
-api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
+/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
 # Backend - Async Workflow
-api/services/async_workflow_service.py @Mairuis @Yeuoly
+/api/services/async_workflow_service.py @Mairuis @Yeuoly
-api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
+/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
 # Backend - Billing
-api/services/billing_service.py @hj24 @zyssyz123
+/api/services/billing_service.py @hj24 @zyssyz123
-api/controllers/console/billing/ @hj24 @zyssyz123
+/api/controllers/console/billing/ @hj24 @zyssyz123
 # Backend - Enterprise
-api/configs/enterprise/ @GarfieldDai @GareArc
+/api/configs/enterprise/ @GarfieldDai @GareArc
-api/services/enterprise/ @GarfieldDai @GareArc
+/api/services/enterprise/ @GarfieldDai @GareArc
-api/services/feature_service.py @GarfieldDai @GareArc
+/api/services/feature_service.py @GarfieldDai @GareArc
-api/controllers/console/feature.py @GarfieldDai @GareArc
+/api/controllers/console/feature.py @GarfieldDai @GareArc
-api/controllers/web/feature.py @GarfieldDai @GareArc
+/api/controllers/web/feature.py @GarfieldDai @GareArc
 # Backend - Database Migrations
-api/migrations/ @snakevash @laipz8200
+/api/migrations/ @snakevash @laipz8200 @MRZHUH
+# Backend - Vector DB Middleware
+/api/configs/middleware/vdb/* @JohnJyong
 # Frontend
-web/ @iamjoel
+/web/ @iamjoel
+# Frontend - Web Tests
+/.github/workflows/web-tests.yml @iamjoel
 # Frontend - App - Orchestration
-web/app/components/workflow/ @iamjoel @zxhlyh
+/web/app/components/workflow/ @iamjoel @zxhlyh
-web/app/components/workflow-app/ @iamjoel @zxhlyh
+/web/app/components/workflow-app/ @iamjoel @zxhlyh
-web/app/components/app/configuration/ @iamjoel @zxhlyh
+/web/app/components/app/configuration/ @iamjoel @zxhlyh
-web/app/components/app/app-publisher/ @iamjoel @zxhlyh
+/web/app/components/app/app-publisher/ @iamjoel @zxhlyh
 # Frontend - WebApp - Chat
-web/app/components/base/chat/ @iamjoel @zxhlyh
+/web/app/components/base/chat/ @iamjoel @zxhlyh
 # Frontend - WebApp - Completion
-web/app/components/share/text-generation/ @iamjoel @zxhlyh
+/web/app/components/share/text-generation/ @iamjoel @zxhlyh
 # Frontend - App - List and Creation
-web/app/components/apps/ @JzoNgKVO @iamjoel
+/web/app/components/apps/ @JzoNgKVO @iamjoel
-web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
-web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
-web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
+/web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
 # Frontend - App - API Documentation
-web/app/components/develop/ @JzoNgKVO @iamjoel
+/web/app/components/develop/ @JzoNgKVO @iamjoel
 # Frontend - App - Logs and Annotations
-web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
+/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
-web/app/components/app/log/ @JzoNgKVO @iamjoel
+/web/app/components/app/log/ @JzoNgKVO @iamjoel
-web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
+/web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
-web/app/components/app/annotation/ @JzoNgKVO @iamjoel
+/web/app/components/app/annotation/ @JzoNgKVO @iamjoel
 # Frontend - App - Monitoring
-web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
+/web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
-web/app/components/app/overview/ @JzoNgKVO @iamjoel
+/web/app/components/app/overview/ @JzoNgKVO @iamjoel
 # Frontend - App - Settings
-web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
+/web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
 # Frontend - RAG - Hit Testing
-web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
+/web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
 # Frontend - RAG - List and Creation
-web/app/components/datasets/list/ @iamjoel @WTW0313
+/web/app/components/datasets/list/ @iamjoel @WTW0313
-web/app/components/datasets/create/ @iamjoel @WTW0313
+/web/app/components/datasets/create/ @iamjoel @WTW0313
-web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
+/web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
-web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
+/web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
 # Frontend - RAG - Orchestration (general rule first, specific rules below override)
-web/app/components/rag-pipeline/ @iamjoel @WTW0313
+/web/app/components/rag-pipeline/ @iamjoel @WTW0313
-web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
+/web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
-web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
+/web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
 # Frontend - RAG - Documents List
-web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
+/web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
-web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
+/web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
 # Frontend - RAG - Segments List
-web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
+/web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
 # Frontend - RAG - Settings
-web/app/components/datasets/settings/ @iamjoel @WTW0313
+/web/app/components/datasets/settings/ @iamjoel @WTW0313
 # Frontend - Ecosystem - Plugins
-web/app/components/plugins/ @iamjoel @zhsama
+/web/app/components/plugins/ @iamjoel @zhsama
 # Frontend - Ecosystem - Tools
-web/app/components/tools/ @iamjoel @Yessenia-d
+/web/app/components/tools/ @iamjoel @Yessenia-d
 # Frontend - Ecosystem - MarketPlace
-web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
+/web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
 # Frontend - Login and Registration
-web/app/signin/ @douxc @iamjoel
+/web/app/signin/ @douxc @iamjoel
-web/app/signup/ @douxc @iamjoel
+/web/app/signup/ @douxc @iamjoel
-web/app/reset-password/ @douxc @iamjoel
+/web/app/reset-password/ @douxc @iamjoel
-web/app/install/ @douxc @iamjoel
+/web/app/install/ @douxc @iamjoel
-web/app/init/ @douxc @iamjoel
+/web/app/init/ @douxc @iamjoel
-web/app/forgot-password/ @douxc @iamjoel
+/web/app/forgot-password/ @douxc @iamjoel
-web/app/account/ @douxc @iamjoel
+/web/app/account/ @douxc @iamjoel
 # Frontend - Service Authentication
-web/service/base.ts @douxc @iamjoel
+/web/service/base.ts @douxc @iamjoel
 # Frontend - WebApp Authentication and Access Control
-web/app/(shareLayout)/components/ @douxc @iamjoel
+/web/app/(shareLayout)/components/ @douxc @iamjoel
-web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
+/web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
-web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
+/web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
-web/app/components/app/app-access-control/ @douxc @iamjoel
+/web/app/components/app/app-access-control/ @douxc @iamjoel
 # Frontend - Explore Page
-web/app/components/explore/ @CodingOnStar @iamjoel
+/web/app/components/explore/ @CodingOnStar @iamjoel
 # Frontend - Personal Settings
-web/app/components/header/account-setting/ @CodingOnStar @iamjoel
+/web/app/components/header/account-setting/ @CodingOnStar @iamjoel
-web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
+/web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
 # Frontend - Analytics
-web/app/components/base/ga/ @CodingOnStar @iamjoel
+/web/app/components/base/ga/ @CodingOnStar @iamjoel
 # Frontend - Base Components
-web/app/components/base/ @iamjoel @zxhlyh
+/web/app/components/base/ @iamjoel @zxhlyh
 # Frontend - Utils and Hooks
-web/utils/classnames.ts @iamjoel @zxhlyh
+/web/utils/classnames.ts @iamjoel @zxhlyh
-web/utils/time.ts @iamjoel @zxhlyh
+/web/utils/time.ts @iamjoel @zxhlyh
-web/utils/format.ts @iamjoel @zxhlyh
+/web/utils/format.ts @iamjoel @zxhlyh
-web/utils/clipboard.ts @iamjoel @zxhlyh
+/web/utils/clipboard.ts @iamjoel @zxhlyh
-web/hooks/use-document-title.ts @iamjoel @zxhlyh
+/web/hooks/use-document-title.ts @iamjoel @zxhlyh
 # Frontend - Billing and Education
-web/app/components/billing/ @iamjoel @zxhlyh
+/web/app/components/billing/ @iamjoel @zxhlyh
-web/app/education-apply/ @iamjoel @zxhlyh
+/web/app/education-apply/ @iamjoel @zxhlyh
 # Frontend - Workspace
-web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
+/web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
+# Docker
+/docker/* @laipz8200

View File

@ -1,8 +1,6 @@
name: "✨ Refactor" name: "✨ Refactor or Chore"
description: Refactor existing code for improved readability and maintainability. description: Refactor existing code or perform maintenance chores to improve readability and reliability.
title: "[Chore/Refactor] " title: "[Refactor/Chore] "
labels:
- refactor
body: body:
- type: checkboxes - type: checkboxes
attributes: attributes:
@ -11,7 +9,7 @@ body:
options: options:
- label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true required: true
- label: This is only for refactoring, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). - label: This is only for refactors or chores; if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true required: true
@ -25,14 +23,14 @@ body:
id: description id: description
attributes: attributes:
label: Description label: Description
placeholder: "Describe the refactor you are proposing." placeholder: "Describe the refactor or chore you are proposing."
validations: validations:
required: true required: true
- type: textarea - type: textarea
id: motivation id: motivation
attributes: attributes:
label: Motivation label: Motivation
placeholder: "Explain why this refactor is necessary." placeholder: "Explain why this refactor or chore is necessary."
validations: validations:
required: false required: false
- type: textarea - type: textarea

View File

@ -1,13 +0,0 @@
name: "👾 Tracker"
description: For inner usages, please do not use this template.
title: "[Tracker] "
labels:
- tracker
body:
- type: textarea
id: content
attributes:
label: Blockers
placeholder: "- [ ] ..."
validations:
required: true

View File

@ -1,12 +0,0 @@
# Copilot Instructions
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
Key reminders:
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
- Target >95% line and branch coverage and 100% function/statement coverage.
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.

View File

@ -71,18 +71,18 @@ jobs:
         run: |
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
-      - name: Run Workflow
-        run: uv run --project api bash dev/pytest/pytest_workflow.sh
-      - name: Run Tool
-        run: uv run --project api bash dev/pytest/pytest_tools.sh
-      - name: Run TestContainers
-        run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
-      - name: Run Unit tests
-        run: |
-          uv run --project api bash dev/pytest/pytest_unit_tests.sh
+      - name: Run API Tests
+        env:
+          STORAGE_TYPE: opendal
+          OPENDAL_SCHEME: fs
+          OPENDAL_FS_ROOT: /tmp/dify-storage
+        run: |
+          uv run --project api pytest \
+            --timeout "${PYTEST_TIMEOUT:-180}" \
+            api/tests/integration_tests/workflow \
+            api/tests/integration_tests/tools \
+            api/tests/test_containers_integration_tests \
+            api/tests/unit_tests
       - name: Coverage Summary
         run: |
@ -93,5 +93,12 @@ jobs:
           # Create a detailed coverage summary
           echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
           echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
-          uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
+          {
+            echo ""
+            echo "<details><summary>File-level coverage (click to expand)</summary>"
+            echo ""
+            echo '```'
+            uv run --project api coverage report -m
+            echo '```'
+            echo "</details>"
+          } >> $GITHUB_STEP_SUMMARY

View File

@ -14,10 +14,27 @@ jobs:
     steps:
       - uses: actions/checkout@v4
-      # Use uv to ensure we have the same ruff version in CI and locally.
-      - uses: astral-sh/setup-uv@v6
+      - name: Check Docker Compose inputs
+        id: docker-compose-changes
+        uses: tj-actions/changed-files@v46
+        with:
+          files: |
+            docker/generate_docker_compose
+            docker/.env.example
+            docker/docker-compose-template.yaml
+            docker/docker-compose.yaml
+      - uses: actions/setup-python@v5
         with:
           python-version: "3.11"
+      - uses: astral-sh/setup-uv@v6
+      - name: Generate Docker Compose
+        if: steps.docker-compose-changes.outputs.any_changed == 'true'
+        run: |
+          cd docker
+          ./generate_docker_compose
       - run: |
           cd api
           uv sync --dev
@ -35,10 +52,11 @@
       - name: ast-grep
         run: |
-          uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
-          uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
-          uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
-          uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
+          # ast-grep exits 1 if no matches are found; allow idempotent runs.
+          uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true
+          uvx --from ast-grep-cli ast-grep --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all || true
+          uvx --from ast-grep-cli ast-grep -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all || true
+          uvx --from ast-grep-cli ast-grep -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all || true
           # Convert Optional[T] to T | None (ignoring quoted types)
           cat > /tmp/optional-rule.yml << 'EOF'
           id: convert-optional-to-union
@ -56,35 +74,14 @@
             pattern: $T
           fix: $T | None
           EOF
-          uvx --from ast-grep-cli sg scan --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
+          uvx --from ast-grep-cli ast-grep scan . --inline-rules "$(cat /tmp/optional-rule.yml)" --update-all
           # Fix forward references that were incorrectly converted (Python doesn't support "Type" | None syntax)
           find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
           find . -name "*.py.bak" -type f -delete
+      # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
       - name: mdformat
         run: |
-          uvx mdformat .
-      - name: Install pnpm
-        uses: pnpm/action-setup@v4
-        with:
-          package_json_file: web/package.json
-          run_install: false
-      - name: Setup NodeJS
-        uses: actions/setup-node@v4
-        with:
-          node-version: 22
-          cache: pnpm
-          cache-dependency-path: ./web/package.json
-      - name: Web dependencies
-        working-directory: ./web
-        run: pnpm install --frozen-lockfile
-      - name: oxlint
-        working-directory: ./web
-        run: |
-          pnpx oxlint --fix
+          uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
       - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

View File

@ -0,0 +1,21 @@
name: Semantic Pull Request
on:
pull_request:
types:
- opened
- edited
- reopened
- synchronize
jobs:
lint:
name: Validate PR title
permissions:
pull-requests: read
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -90,7 +90,7 @@ jobs:
         with:
           node-version: 22
           cache: pnpm
-          cache-dependency-path: ./web/package.json
+          cache-dependency-path: ./web/pnpm-lock.yaml
       - name: Web dependencies
         if: steps.changed-files.outputs.any_changed == 'true'
@ -106,37 +106,7 @@
       - name: Web type check
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
-        run: pnpm run type-check
+        run: pnpm run type-check:tsgo
-  docker-compose-template:
-    name: Docker Compose Template
-    runs-on: ubuntu-latest
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-        with:
-          persist-credentials: false
-      - name: Check changed files
-        id: changed-files
-        uses: tj-actions/changed-files@v46
-        with:
-          files: |
-            docker/generate_docker_compose
-            docker/.env.example
-            docker/docker-compose-template.yaml
-            docker/docker-compose.yaml
-      - name: Generate Docker Compose
-        if: steps.changed-files.outputs.any_changed == 'true'
-        run: |
-          cd docker
-          ./generate_docker_compose
-      - name: Check for changes
-        if: steps.changed-files.outputs.any_changed == 'true'
-        run: git diff --exit-code
   superlinter:
     name: SuperLinter

View File

@ -1,4 +1,4 @@
-name: Check i18n Files and Create PR
+name: Translate i18n Files Based on English
 on:
   push:
@ -55,7 +55,7 @@ jobs:
         with:
           node-version: 'lts/*'
           cache: pnpm
-          cache-dependency-path: ./web/package.json
+          cache-dependency-path: ./web/pnpm-lock.yaml
       - name: Install dependencies
         if: env.FILES_CHANGED == 'true'
@ -67,25 +67,19 @@
         working-directory: ./web
         run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
-      - name: Generate i18n type definitions
-        if: env.FILES_CHANGED == 'true'
-        working-directory: ./web
-        run: pnpm run gen:i18n-types
       - name: Create Pull Request
         if: env.FILES_CHANGED == 'true'
         uses: peter-evans/create-pull-request@v6
         with:
           token: ${{ secrets.GITHUB_TOKEN }}
           commit-message: 'chore(i18n): update translations based on en-US changes'
-          title: 'chore(i18n): translate i18n files and update type definitions'
+          title: 'chore(i18n): translate i18n files based on en-US changes'
           body: |
-            This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
+            This PR was automatically created to update i18n translation files based on changes in en-US locale.
             **Triggered by:** ${{ github.sha }}
             **Changes included:**
             - Updated translation files for all locales
-            - Regenerated TypeScript type definitions for type safety
           branch: chore/automated-i18n-updates-${{ github.sha }}
           delete-branch: true

View File

@ -13,6 +13,7 @@ jobs:
     runs-on: ubuntu-latest
     defaults:
       run:
+        shell: bash
         working-directory: ./web
     steps:
@ -21,14 +22,7 @@
         with:
           persist-credentials: false
-      - name: Check changed files
-        id: changed-files
-        uses: tj-actions/changed-files@v46
-        with:
-          files: web/**
       - name: Install pnpm
-        if: steps.changed-files.outputs.any_changed == 'true'
         uses: pnpm/action-setup@v4
         with:
           package_json_file: web/package.json
@ -36,23 +30,339 @@
       - name: Setup Node.js
         uses: actions/setup-node@v4
-        if: steps.changed-files.outputs.any_changed == 'true'
         with:
           node-version: 22
           cache: pnpm
-          cache-dependency-path: ./web/package.json
+          cache-dependency-path: ./web/pnpm-lock.yaml
       - name: Install dependencies
-        if: steps.changed-files.outputs.any_changed == 'true'
-        working-directory: ./web
         run: pnpm install --frozen-lockfile
-      - name: Check i18n types synchronization
-        if: steps.changed-files.outputs.any_changed == 'true'
-        working-directory: ./web
-        run: pnpm run check:i18n-types
       - name: Run tests
-        if: steps.changed-files.outputs.any_changed == 'true'
-        working-directory: ./web
-        run: pnpm test
+        run: pnpm test:coverage
+      - name: Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Vitest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v4
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error
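The inline Node script above folds Istanbul coverage maps into a Markdown summary, and its `formatLineRanges` helper compresses sorted uncovered line numbers into compact ranges for the per-file table. A minimal Python sketch of that same compression, shown only to illustrate the algorithm (the workflow runs the inline Node version, not this):

```python
def format_line_ranges(lines: list[int]) -> str:
    """Collapse sorted line numbers into range notation, e.g. [1, 2, 3, 5] -> '1-3,5'."""
    if not lines:
        return ""
    ranges: list[str] = []
    start = end = lines[0]
    for current in lines[1:]:
        if current == end + 1:
            end = current  # extend the current consecutive run
            continue
        ranges.append(str(start) if start == end else f"{start}-{end}")
        start = end = current
    ranges.append(str(start) if start == end else f"{start}-{end}")
    return ",".join(ranges)

assert format_line_ranges([1, 2, 3, 5, 7, 8, 9]) == "1-3,5,7-9"
```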

13
.gitignore vendored
View File

@ -139,7 +139,6 @@ pyrightconfig.json
.idea/' .idea/'
.DS_Store .DS_Store
web/.vscode/settings.json
# Intellij IDEA Files # Intellij IDEA Files
.idea/* .idea/*
@ -189,12 +188,14 @@ docker/volumes/matrixone/*
docker/volumes/mysql/* docker/volumes/mysql/*
docker/volumes/seekdb/* docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d !docker/volumes/oceanbase/init.d
docker/volumes/iris/*
docker/nginx/conf.d/default.conf docker/nginx/conf.d/default.conf
docker/nginx/ssl/* docker/nginx/ssl/*
!docker/nginx/ssl/.gitkeep !docker/nginx/ssl/.gitkeep
docker/middleware.env docker/middleware.env
docker/docker-compose.override.yaml docker/docker-compose.override.yaml
docker/env-backup/*
sdks/python-client/build sdks/python-client/build
sdks/python-client/dist sdks/python-client/dist
@ -204,7 +205,6 @@ sdks/python-client/dify_client.egg-info
!.vscode/launch.json.template !.vscode/launch.json.template
!.vscode/README.md !.vscode/README.md
api/.vscode api/.vscode
web/.vscode
# vscode Code History Extension # vscode Code History Extension
.history .history
@ -219,15 +219,6 @@ plugins.jsonl
# mise # mise
mise.toml mise.toml
# Next.js build output
.next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant # AI Assistant
.roo/ .roo/

1
.nvmrc Normal file
View File

@ -0,0 +1 @@
22.11.0

View File

@ -37,7 +37,7 @@
"-c", "-c",
"1", "1",
"-Q", "-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor", "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel", "--loglevel",
"INFO" "INFO"
], ],

View File

@ -1,5 +0,0 @@
# Windsurf Testing Rules
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
- Honor every requirement in that document when generating or accepting tests.
- When proposing or saving tests, re-read that document and follow every requirement.

View File

@ -24,8 +24,8 @@ The codebase is split into:
```bash ```bash
cd web cd web
pnpm lint
pnpm lint:fix pnpm lint:fix
pnpm type-check:tsgo
pnpm test pnpm test
``` ```
@ -39,7 +39,7 @@ pnpm test
## Language Style ## Language Style
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). - **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types. - **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
## General Practices ## General Practices

View File

@ -139,6 +139,19 @@ Star Dify on GitHub and be instantly notified of new releases.
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
#### Customizing Suggested Questions
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
```bash
# In your .env file
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
SUGGESTED_QUESTIONS_MAX_TOKENS=512
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
```
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
### Metrics Monitoring with Grafana ### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more. Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.

View File

@ -116,6 +116,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names. # Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path ALIYUN_OSS_PATH=your-path
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Google Storage configuration # Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
@ -133,6 +134,7 @@ HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url HUAWEI_OBS_SERVER=your-server-url
HUAWEI_OBS_PATH_STYLE=false
# Baidu OBS Storage Configuration # Baidu OBS Storage Configuration
BAIDU_OBS_BUCKET_NAME=your-bucket-name BAIDU_OBS_BUCKET_NAME=your-bucket-name
@ -543,6 +545,25 @@ APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0 APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0
# Aliyun SLS Logstore Configuration
# Aliyun Access Key ID
ALIYUN_SLS_ACCESS_KEY_ID=
# Aliyun Access Key Secret
ALIYUN_SLS_ACCESS_KEY_SECRET=
# Aliyun SLS Endpoint (e.g., cn-hangzhou.log.aliyuncs.com)
ALIYUN_SLS_ENDPOINT=
# Aliyun SLS Region (e.g., cn-hangzhou)
ALIYUN_SLS_REGION=
# Aliyun SLS Project Name
ALIYUN_SLS_PROJECT_NAME=
# Number of days to retain workflow run logs (default: 365 days; set to 3650 for permanent storage)
ALIYUN_SLS_LOGSTORE_TTL=365
# Enable dual-write to both SLS LogStore and SQL database (default: false)
LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
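The two flags above describe a common migration pattern: while data moves from SQL to the LogStore, writes go to both stores and reads fall back to SQL whenever the LogStore comes back empty. A rough sketch of the read-side fallback; `query_logstore` and `query_sql` are hypothetical stand-ins, not Dify's actual repository API:

```python
def query_logstore(query: dict) -> list[dict]:
    """Hypothetical SLS LogStore lookup; returns no rows in this demo."""
    return []

def query_sql(query: dict) -> list[dict]:
    """Hypothetical SQL lookup holding historical rows."""
    return [{"workflow_run_id": "run-1"}]

def fetch_workflow_runs(query: dict, *, dual_read_enabled: bool = True) -> list[dict]:
    """Prefer the LogStore; fall back to SQL when it returns nothing."""
    rows = query_logstore(query)
    if rows:
        return rows
    if dual_read_enabled:
        # During migration, historical data may exist only in the SQL database.
        return query_sql(query)
    return []

print(fetch_workflow_runs({"app_id": "demo"}))  # falls back to the SQL rows
```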
# Celery beat configuration # Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1 CELERY_BEAT_SCHEDULER_TIME=1
@ -633,8 +654,45 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Set to false to export dataset IDs as plain text for easier cross-environment import # Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Suggested Questions After Answer Configuration
# These environment variables allow customization of the suggested questions feature
#
# Custom prompt for generating suggested questions (optional)
# If not set, uses the default prompt that generates 3 questions under 20 characters each
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
# SUGGESTED_QUESTIONS_PROMPT=
# Maximum number of tokens for suggested questions generation (default: 256)
# Adjust this value for longer questions or more questions
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
# Temperature for suggested questions generation (default: 0.0)
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
# SUGGESTED_QUESTIONS_TEMPERATURE=0
# Tenant isolated task queue configuration # Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1 TENANT_ISOLATED_TASK_CONCURRENCY=1
# Maximum number of segments for dataset segments API (0 for unlimited) # Maximum number of segments for dataset segments API (0 for unlimited)
DATASET_MAX_SEGMENTS_PER_REQUEST=0 DATASET_MAX_SEGMENTS_PER_REQUEST=0
# Multimodal knowledge base limits
SINGLE_CHUNK_ATTACHMENT_LIMIT=10
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
IMAGE_FILE_BATCH_LIMIT=10
# Maximum allowed CSV file size for annotation import in megabytes
ANNOTATION_IMPORT_FILE_SIZE_LIMIT=2
# Maximum number of annotation records allowed in a single import
ANNOTATION_IMPORT_MAX_RECORDS=10000
# Minimum number of annotation records required in a single import
ANNOTATION_IMPORT_MIN_RECORDS=1
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE=5
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR=20
# Maximum number of concurrent annotation import tasks per tenant
ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records cleanup configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
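These three knobs bound when expired sandbox data becomes eligible for deletion (graceful period plus retention) and how much is removed per pass (batch size). A sketch of how they might interact in a batched purge loop; `FakeRepo` and its `delete_batch` interface are hypothetical, not Dify's actual models:

```python
from datetime import datetime, timedelta, timezone

class FakeRepo:
    """Hypothetical store standing in for the real sandbox-record repository."""
    def __init__(self, rows: int):
        self.rows = rows

    def delete_batch(self, cutoff: datetime, limit: int) -> int:
        deleted = min(self.rows, limit)
        self.rows -= deleted
        return deleted

def clean_expired_records(repo: FakeRepo, *, graceful_period_days: int = 21,
                          retention_days: int = 30, batch_size: int = 1000) -> int:
    """Delete records older than graceful period + retention, in bounded batches."""
    cutoff = datetime.now(timezone.utc) - timedelta(days=graceful_period_days + retention_days)
    total = 0
    while True:
        deleted = repo.delete_batch(cutoff, limit=batch_size)
        total += deleted
        if deleted < batch_size:  # a short batch means nothing is left
            return total

print(clean_expired_records(FakeRepo(2500)))  # 2500, purged in three batches
```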

View File

@ -36,17 +36,20 @@ select = [
"UP", # pyupgrade rules "UP", # pyupgrade rules
"W191", # tab-indentation "W191", # tab-indentation
"W605", # invalid-escape-sequence "W605", # invalid-escape-sequence
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum,
"S110", # disallow the try-except-pass pattern.
# security related linting rules # security related linting rules
# RCE protection (sort of) # RCE protection (sort of)
"S102", # exec-builtin, disallow use of `exec` "S102", # exec-builtin, disallow use of `exec`
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval` "S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module "S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage "S311", # suspicious-non-cryptographic-random-usage,
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
"UP042", # use StrEnum
] ]
ignore = [ ignore = [
@ -91,18 +94,16 @@ ignore = [
"configs/*" = [ "configs/*" = [
"N802", # invalid-function-name "N802", # invalid-function-name
] ]
"core/model_runtime/callbacks/base_callback.py" = [ "core/model_runtime/callbacks/base_callback.py" = ["T201"]
"T201", "core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
]
"core/workflow/callbacks/workflow_logging_callback.py" = [
"T201",
]
"libs/gmpy2_pkcs10aep_cipher.py" = [ "libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name "N803", # invalid-argument-name
] ]
"tests/*" = [ "tests/*" = [
"F811", # redefined-while-unused "F811", # redefined-while-unused
"T201", # allow print in tests "T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
] ]
[lint.pyflakes] [lint.pyflakes]

View File

@ -84,7 +84,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
``` ```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -1,6 +1,8 @@
import logging import logging
import time import time
from opentelemetry.trace import get_current_span
from configs import dify_config from configs import dify_config
from contexts.wrapper import RecyclableContextVar from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp from dify_app import DifyApp
@ -26,8 +28,25 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request # add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles() RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
@dify_app.after_request
def add_trace_id_header(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction # Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request _ = before_request
_ = add_trace_id_header
return dify_app return dify_app
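With the hook above, every HTTP response carries the current OpenTelemetry trace id, so a failing request can be correlated with server-side traces. A minimal client-side sketch; the URL is a placeholder and assumes a locally running API:

```python
import requests  # assumes the requests package is installed

resp = requests.get("http://localhost:5001/console/api/apps")  # placeholder URL
trace_id = resp.headers.get("X-Trace-Id")
if trace_id:
    # Search server logs or the tracing backend for this 32-char hex id.
    print(f"server trace id: {trace_id}")
```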
@ -51,10 +70,12 @@ def initialize_extensions(app: DifyApp):
ext_commands, ext_commands,
ext_compress, ext_compress,
ext_database, ext_database,
ext_forward_refs,
ext_hosting_provider, ext_hosting_provider,
ext_import_modules, ext_import_modules,
ext_logging, ext_logging,
ext_login, ext_login,
ext_logstore,
ext_mail, ext_mail,
ext_migrate, ext_migrate,
ext_orjson, ext_orjson,
@ -63,6 +84,7 @@ def initialize_extensions(app: DifyApp):
ext_redis, ext_redis,
ext_request_logging, ext_request_logging,
ext_sentry, ext_sentry,
ext_session_factory,
ext_set_secretkey, ext_set_secretkey,
ext_storage, ext_storage,
ext_timezone, ext_timezone,
@ -75,6 +97,7 @@ def initialize_extensions(app: DifyApp):
ext_warnings, ext_warnings,
ext_import_modules, ext_import_modules,
ext_orjson, ext_orjson,
ext_forward_refs,
ext_set_secretkey, ext_set_secretkey,
ext_compress, ext_compress,
ext_code_based_extension, ext_code_based_extension,
@ -83,6 +106,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate, ext_migrate,
ext_redis, ext_redis,
ext_storage, ext_storage,
ext_logstore, # Initialize logstore after storage, before celery
ext_celery, ext_celery,
ext_login, ext_login,
ext_mail, ext_mail,
@ -93,6 +117,7 @@ def initialize_extensions(app: DifyApp):
ext_commands, ext_commands,
ext_otel, ext_otel,
ext_request_logging, ext_request_logging,
ext_session_factory,
] ]
for ext in extensions: for ext in extensions:
short_name = ext.__name__.split(".")[-1] short_name = ext.__name__.split(".")[-1]

View File

@ -1139,6 +1139,7 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
except Exception as e: except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
return
all_files_on_storage = [] all_files_on_storage = []
for storage_path in storage_paths: for storage_path in storage_paths:

View File

@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field( PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)", description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=300.0, default=600.0,
) )
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key") INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
@ -360,6 +360,57 @@ class FileUploadConfig(BaseSettings):
default=10, default=10,
) )
IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a image batch upload operation",
default=10,
)
SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a single chunk attachment",
default=10,
)
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed image file size for attachments in megabytes",
default=2,
)
ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
description="Timeout for downloading image attachments in seconds",
default=60,
)
# Annotation Import Security Configurations
ANNOTATION_IMPORT_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="Maximum allowed CSV file size for annotation import in megabytes",
default=2,
)
ANNOTATION_IMPORT_MAX_RECORDS: PositiveInt = Field(
description="Maximum number of annotation records allowed in a single import",
default=10000,
)
ANNOTATION_IMPORT_MIN_RECORDS: PositiveInt = Field(
description="Minimum number of annotation records required in a single import",
default=1,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of annotation import requests per minute per tenant",
default=5,
)
ANNOTATION_IMPORT_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
description="Maximum number of annotation import requests per hour per tenant",
default=20,
)
ANNOTATION_IMPORT_MAX_CONCURRENT: PositiveInt = Field(
description="Maximum number of concurrent annotation import tasks per tenant",
default=2,
)
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field( inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
description=( description=(
"Comma-separated list of file extensions that are blocked from upload. " "Comma-separated list of file extensions that are blocked from upload. "
@ -553,7 +604,10 @@ class LoggingConfig(BaseSettings):
LOG_FORMAT: str = Field( LOG_FORMAT: str = Field(
description="Format string for log messages", description="Format string for log messages",
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", default=(
"%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] "
"[%(filename)s:%(lineno)d] %(trace_id)s - %(message)s"
),
) )
LOG_DATEFORMAT: str | None = Field( LOG_DATEFORMAT: str | None = Field(
@ -1216,6 +1270,21 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
) )
class SandboxExpiredRecordsCleanConfig(BaseSettings):
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
description="Graceful period in days for sandbox records clean after subscription expiration",
default=21,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
AppExecutionConfig, AppExecutionConfig,
@ -1241,6 +1310,7 @@ class FeatureConfig(
PositionConfig, PositionConfig,
RagEtlConfig, RagEtlConfig,
RepositoryConfig, RepositoryConfig,
SandboxExpiredRecordsCleanConfig,
SecurityConfig, SecurityConfig,
TenantIsolatedTaskQueueConfig, TenantIsolatedTaskQueueConfig,
ToolConfig, ToolConfig,
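The new LOG_FORMAT default above references a `%(trace_id)s` field that plain `LogRecord`s do not carry, so something must inject it on every record. A minimal sketch of one way to do that with a `logging.Filter`; this mirrors the after-request hook shown earlier, while Dify's actual wiring lives in its logging extension:

```python
import logging

from opentelemetry.trace import get_current_span

class TraceIdFilter(logging.Filter):
    """Attach the current OpenTelemetry trace id to every record as `trace_id`."""
    def filter(self, record: logging.LogRecord) -> bool:
        ctx = get_current_span().get_span_context()
        record.trace_id = format(ctx.trace_id, "032x") if ctx.is_valid else "-"
        return True  # never drop the record, only annotate it

handler = logging.StreamHandler()
handler.addFilter(TraceIdFilter())
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(trace_id)s - %(message)s"))
logging.getLogger().addHandler(handler)
logging.getLogger().warning("hello")  # prints '-' for trace_id outside a span
```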

View File

@ -26,6 +26,7 @@ from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig from .vdb.milvus_config import MilvusConfig
@ -106,7 +107,7 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings): class DatabaseConfig(BaseSettings):
# Database type selector # Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field( DB_TYPE: Literal["postgresql", "mysql", "oceanbase", "seekdb"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.", description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql", default="postgresql",
) )
@ -336,6 +337,7 @@ class MiddlewareConfig(
ChromaConfig, ChromaConfig,
ClickzettaConfig, ClickzettaConfig,
HuaweiCloudConfig, HuaweiCloudConfig,
IrisVectorConfig,
MilvusConfig, MilvusConfig,
AlibabaCloudMySQLConfig, AlibabaCloudMySQLConfig,
MyScaleConfig, MyScaleConfig,

View File

@ -41,3 +41,8 @@ class AliyunOSSStorageConfig(BaseSettings):
description="Base path within the bucket to store objects (e.g., 'my-app-data/')", description="Base path within the bucket to store objects (e.g., 'my-app-data/')",
default=None, default=None,
) )
ALIYUN_CLOUDBOX_ID: str | None = Field(
description="Cloudbox id for aliyun cloudbox service",
default=None,
)

View File

@ -26,3 +26,8 @@ class HuaweiCloudOBSStorageConfig(BaseSettings):
description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')", description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')",
default=None, default=None,
) )
HUAWEI_OBS_PATH_STYLE: bool = Field(
description="Flag to indicate whether to use path-style URLs for OBS requests",
default=False,
)
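The flag above switches between the two standard object-storage URL shapes: path-style puts the bucket in the path, while virtual-hosted style puts it in the hostname. An illustrative sketch of the URL shapes only, not the OBS SDK's actual request signing:

```python
def obs_object_url(server: str, bucket: str, key: str, *, path_style: bool) -> str:
    """Show how HUAWEI_OBS_PATH_STYLE changes the request URL shape."""
    host = server.removeprefix("https://")
    if path_style:
        return f"https://{host}/{bucket}/{key}"  # bucket in the path
    return f"https://{bucket}.{host}/{key}"      # bucket in the hostname

print(obs_object_url("https://obs.cn-north-4.myhuaweicloud.com", "my-bucket", "a.txt", path_style=True))
print(obs_object_url("https://obs.cn-north-4.myhuaweicloud.com", "my-bucket", "a.txt", path_style=False))
```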

View File

@ -0,0 +1,91 @@
"""Configuration for InterSystems IRIS vector database."""
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
class IrisVectorConfig(BaseSettings):
"""Configuration settings for IRIS vector database connection and pooling."""
IRIS_HOST: str | None = Field(
description="Hostname or IP address of the IRIS server.",
default="localhost",
)
IRIS_SUPER_SERVER_PORT: PositiveInt | None = Field(
description="Port number for IRIS connection.",
default=1972,
)
IRIS_USER: str | None = Field(
description="Username for IRIS authentication.",
default="_SYSTEM",
)
IRIS_PASSWORD: str | None = Field(
description="Password for IRIS authentication.",
default="Dify@1234",
)
IRIS_SCHEMA: str | None = Field(
description="Schema name for IRIS tables.",
default="dify",
)
IRIS_DATABASE: str | None = Field(
description="Database namespace for IRIS connection.",
default="USER",
)
IRIS_CONNECTION_URL: str | None = Field(
description="Full connection URL for IRIS (overrides individual fields if provided).",
default=None,
)
IRIS_MIN_CONNECTION: PositiveInt = Field(
description="Minimum number of connections in the pool.",
default=1,
)
IRIS_MAX_CONNECTION: PositiveInt = Field(
description="Maximum number of connections in the pool.",
default=3,
)
IRIS_TEXT_INDEX: bool = Field(
description="Enable full-text search index using %iFind.Index.Basic.",
default=True,
)
IRIS_TEXT_INDEX_LANGUAGE: str = Field(
description="Language for full-text search index (e.g., 'en', 'ja', 'zh', 'de').",
default="en",
)
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""Validate IRIS configuration values.
Args:
values: Configuration dictionary
Returns:
Validated configuration dictionary
Raises:
ValueError: If required fields are missing or pool settings are invalid
"""
# Only validate required fields if IRIS is being used as the vector store
# This allows the config to be loaded even when IRIS is not in use
# vector_store = os.environ.get("VECTOR_STORE", "")
# We rely on Pydantic defaults for required fields if they are missing from env.
# Strict existence check is removed to allow defaults to work.
min_conn = values.get("IRIS_MIN_CONNECTION", 1)
max_conn = values.get("IRIS_MAX_CONNECTION", 3)
if min_conn > max_conn:
raise ValueError("IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION")
return values
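A quick illustration of the pool-size check above: the defaults pass, while an inverted pair raises at construction time (Pydantic wraps the ValueError in a ValidationError, itself a ValueError subclass). The import path is an assumption based on the `configs/middleware` wiring shown earlier:

```python
from configs.middleware.vdb.iris_config import IrisVectorConfig  # assumed module path

IrisVectorConfig()  # defaults: min 1 <= max 3, passes validation

try:
    IrisVectorConfig(IRIS_MIN_CONNECTION=5, IRIS_MAX_CONNECTION=3)
except ValueError as exc:
    print(exc)  # IRIS_MIN_CONNECTION must be less than or equal to IRIS_MAX_CONNECTION
```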

View File

@ -20,6 +20,7 @@ language_timezone_mapping = {
"sl-SI": "Europe/Ljubljana", "sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok", "th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta", "id-ID": "Asia/Jakarta",
"ar-TN": "Africa/Tunis",
} }
languages = list(language_timezone_mapping.keys()) languages = list(language_timezone_mapping.keys())

View File

@ -0,0 +1,57 @@
import os
from email.message import Message
from urllib.parse import quote
from flask import Response
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
HTML_EXTENSIONS = frozenset({"html", "htm"})
def _normalize_mime_type(mime_type: str | None) -> str:
if not mime_type:
return ""
message = Message()
message["Content-Type"] = mime_type
return message.get_content_type().strip().lower()
def _is_html_extension(extension: str | None) -> bool:
if not extension:
return False
return extension.lstrip(".").lower() in HTML_EXTENSIONS
def is_html_content(mime_type: str | None, filename: str | None, extension: str | None = None) -> bool:
normalized_mime_type = _normalize_mime_type(mime_type)
if normalized_mime_type in HTML_MIME_TYPES:
return True
if _is_html_extension(extension):
return True
if filename:
return _is_html_extension(os.path.splitext(filename)[1])
return False
def enforce_download_for_html(
response: Response,
*,
mime_type: str | None,
filename: str | None,
extension: str | None = None,
) -> bool:
if not is_html_content(mime_type, filename, extension):
return False
if filename:
encoded_filename = quote(filename)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
else:
response.headers["Content-Disposition"] = "attachment"
response.headers["Content-Type"] = "application/octet-stream"
response.headers["X-Content-Type-Options"] = "nosniff"
return True
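A short usage sketch for the helpers above: HTML content is rewritten into a forced download with sniffing disabled, while non-HTML responses pass through untouched (assumes the functions are in scope; the new module's path is not shown in this diff):

```python
from flask import Response

resp = Response("<script>alert(1)</script>", mimetype="text/html")
assert enforce_download_for_html(resp, mime_type="text/html", filename="page.html")
assert resp.headers["Content-Type"] == "application/octet-stream"
assert resp.headers["Content-Disposition"].startswith("attachment")

resp = Response("{}", mimetype="application/json")
assert not enforce_download_for_html(resp, mime_type="application/json", filename="a.json")
```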

View File

@ -0,0 +1,26 @@
"""Helpers for registering Pydantic models with Flask-RESTX namespaces."""
from flask_restx import Namespace
from pydantic import BaseModel
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
"""Register a single BaseModel with a namespace for Swagger documentation."""
namespace.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> None:
"""Register multiple BaseModels with a namespace."""
for model in models:
register_schema_model(namespace, model)
__all__ = [
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
"register_schema_model",
"register_schema_models",
]
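Usage is a one-liner per namespace; a sketch with a hypothetical namespace and model:

```python
from flask_restx import Namespace
from pydantic import BaseModel

ns = Namespace("demo")  # hypothetical namespace

class PingPayload(BaseModel):  # hypothetical model
    message: str

register_schema_models(ns, PingPayload)  # helper defined above
assert PingPayload.__name__ in ns.models  # now usable in @ns.expect(...)
```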

View File

@ -3,21 +3,47 @@ from functools import wraps
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
from flask import request from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import only_edition_cloud from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db from extensions.ext_database import db
from libs.token import extract_access_token from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp from models.model import App, InstalledApp, RecommendedApp
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InsertExploreAppPayload(BaseModel):
app_id: str = Field(...)
desc: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]): def admin_required(view: Callable[P, R]):
@wraps(view) @wraps(view)
@ -40,59 +66,34 @@ def admin_required(view: Callable[P, R]):
class InsertExploreAppListApi(Resource): class InsertExploreAppListApi(Resource):
@console_ns.doc("insert_explore_app") @console_ns.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list") @console_ns.doc(description="Insert or update an app in the explore list")
@console_ns.expect( @console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
console_ns.model(
"InsertExploreAppRequest",
{
"app_id": fields.String(required=True, description="Application ID"),
"desc": fields.String(description="App description"),
"copyright": fields.String(description="Copyright information"),
"privacy_policy": fields.String(description="Privacy policy"),
"custom_disclaimer": fields.String(description="Custom disclaimer"),
"language": fields.String(required=True, description="Language code"),
"category": fields.String(required=True, description="App category"),
"position": fields.Integer(required=True, description="Display position"),
},
)
)
@console_ns.response(200, "App updated successfully") @console_ns.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully") @console_ns.response(201, "App inserted successfully")
@console_ns.response(404, "App not found") @console_ns.response(404, "App not found")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
parser = ( payload = InsertExploreAppPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
.add_argument("desc", type=str, location="json")
.add_argument("copyright", type=str, location="json")
.add_argument("privacy_policy", type=str, location="json")
.add_argument("custom_disclaimer", type=str, location="json")
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
.add_argument("category", type=str, required=True, nullable=False, location="json")
.add_argument("position", type=int, required=True, nullable=False, location="json")
)
args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
if not app: if not app:
raise NotFound(f"App '{args['app_id']}' is not found") raise NotFound(f"App '{payload.app_id}' is not found")
site = app.site site = app.site
if not site: if not site:
desc = args["desc"] or "" desc = payload.desc or ""
copy_right = args["copyright"] or "" copy_right = payload.copyright or ""
privacy_policy = args["privacy_policy"] or "" privacy_policy = payload.privacy_policy or ""
custom_disclaimer = args["custom_disclaimer"] or "" custom_disclaimer = payload.custom_disclaimer or ""
else: else:
desc = site.description or args["desc"] or "" desc = site.description or payload.desc or ""
copy_right = site.copyright or args["copyright"] or "" copy_right = site.copyright or payload.copyright or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or "" privacy_policy = site.privacy_policy or payload.privacy_policy or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
with Session(db.engine) as session: with session_factory.create_session() as session:
recommended_app = session.execute( recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
).scalar_one_or_none() ).scalar_one_or_none()
if not recommended_app: if not recommended_app:
@ -102,9 +103,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right, copyright=copy_right,
privacy_policy=privacy_policy, privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer, custom_disclaimer=custom_disclaimer,
language=args["language"], language=payload.language,
category=args["category"], category=payload.category,
position=args["position"], position=payload.position,
) )
db.session.add(recommended_app) db.session.add(recommended_app)
@ -118,9 +119,9 @@ class InsertExploreAppListApi(Resource):
recommended_app.copyright = copy_right recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"] recommended_app.language = payload.language
recommended_app.category = args["category"] recommended_app.category = payload.category
recommended_app.position = args["position"] recommended_app.position = payload.position
app.is_public = True app.is_public = True
@ -138,7 +139,7 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def delete(self, app_id): def delete(self, app_id):
with Session(db.engine) as session: with session_factory.create_session() as session:
recommended_app = session.execute( recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none() ).scalar_one_or_none()
@ -146,13 +147,13 @@ class InsertExploreAppApi(Resource):
if not recommended_app: if not recommended_app:
return {"result": "success"}, 204 return {"result": "success"}, 204
with Session(db.engine) as session: with session_factory.create_session() as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none() app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app: if app:
app.is_public = False app.is_public = False
with Session(db.engine) as session: with session_factory.create_session() as session:
installed_apps = ( installed_apps = (
session.execute( session.execute(
select(InstalledApp).where( select(InstalledApp).where(

View File

@ -1,16 +1,23 @@
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService from services.advanced_prompt_template_service import AdvancedPromptTemplateService
parser = (
reqparse.RequestParser() class AdvancedPromptTemplateQuery(BaseModel):
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode") app_mode: str = Field(..., description="Application mode")
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode") model_mode: str = Field(..., description="Model mode")
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context") has_context: str = Field(default="true", description="Whether has context")
.add_argument("model_name", type=str, required=True, location="args", help="Model name") model_name: str = Field(..., description="Model name")
console_ns.schema_model(
AdvancedPromptTemplateQuery.__name__,
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
) )
@ -18,7 +25,7 @@ parser = (
class AdvancedPromptTemplateList(Resource): class AdvancedPromptTemplateList(Resource):
@console_ns.doc("get_advanced_prompt_templates") @console_ns.doc("get_advanced_prompt_templates")
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration") @console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
@console_ns.expect(parser) @console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
@console_ns.response( @console_ns.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
) )
@ -27,6 +34,6 @@ class AdvancedPromptTemplateList(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
args = parser.parse_args() args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AdvancedPromptTemplateService.get_prompt(args) return AdvancedPromptTemplateService.get_prompt(args.model_dump())
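This file and the surrounding controller changes all apply the same migration: the module-level `reqparse.RequestParser` becomes a Pydantic model, registered on the namespace so `@console_ns.expect` can still document it in Swagger, then validated with `model_validate`. A distilled sketch of the pattern with hypothetical names:

```python
from flask import request
from flask_restx import Namespace, Resource
from pydantic import BaseModel, Field

ns = Namespace("demo")  # stands in for console_ns

class GreetQuery(BaseModel):  # hypothetical payload
    name: str = Field(..., description="Who to greet")

# Register the JSON schema so @ns.expect can reference it in Swagger docs.
ns.schema_model(GreetQuery.__name__, GreetQuery.model_json_schema(ref_template="#/definitions/{model}"))

@ns.route("/greet")
class GreetApi(Resource):
    @ns.expect(ns.models[GreetQuery.__name__])
    def get(self):
        # Query strings arrive flat; model_validate replaces parser.parse_args().
        args = GreetQuery.model_validate(request.args.to_dict(flat=True))
        return {"greeting": f"Hello, {args.name}"}
```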

View File

@ -1,4 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -8,10 +10,21 @@ from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from services.agent_service import AgentService from services.agent_service import AgentService
parser = ( DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
reqparse.RequestParser()
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID") class AgentLogQuery(BaseModel):
message_id: str = Field(..., description="Message UUID")
conversation_id: str = Field(..., description="Conversation UUID")
@field_validator("message_id", "conversation_id")
@classmethod
def validate_uuid(cls, value: str) -> str:
return uuid_value(value)
console_ns.schema_model(
AgentLogQuery.__name__, AgentLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
) )
@ -20,7 +33,7 @@ class AgentLogApi(Resource):
@console_ns.doc("get_agent_logs") @console_ns.doc("get_agent_logs")
@console_ns.doc(description="Get agent execution logs for an application") @console_ns.doc(description="Get agent execution logs for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[AgentLogQuery.__name__])
@console_ns.response( @console_ns.response(
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")) 200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
) )
@ -31,6 +44,6 @@ class AgentLogApi(Resource):
@get_app_model(mode=[AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.AGENT_CHAT])
def get(self, app_model): def get(self, app_model):
"""Get agent logs""" """Get agent logs"""
args = parser.parse_args() args = AgentLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) return AgentService.get_agent_logs(app_model, args.conversation_id, args.message_id)

View File

@ -1,12 +1,15 @@
from typing import Literal from typing import Any, Literal
from flask import request from flask import abort, make_response, request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
annotation_import_concurrency_limit,
annotation_import_rate_limit,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required, edit_permission_required,
setup_required, setup_required,
@ -21,22 +24,79 @@ from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AnnotationReplyPayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold for annotation matching")
embedding_provider_name: str = Field(..., description="Embedding provider name")
embedding_model_name: str = Field(..., description="Embedding model name")
class AnnotationSettingUpdatePayload(BaseModel):
score_threshold: float = Field(..., description="Score threshold")
class AnnotationListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, description="Page size")
keyword: str = Field(default="", description="Search keyword")
class CreateAnnotationPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
question: str | None = Field(default=None, description="Question text")
answer: str | None = Field(default=None, description="Answer text")
content: str | None = Field(default=None, description="Content text")
annotation_reply: dict[str, Any] | None = Field(default=None, description="Annotation reply data")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class UpdateAnnotationPayload(BaseModel):
question: str | None = None
answer: str | None = None
content: str | None = None
annotation_reply: dict[str, Any] | None = None
class AnnotationReplyStatusQuery(BaseModel):
action: Literal["enable", "disable"]
class AnnotationFilePayload(BaseModel):
message_id: str = Field(..., description="Message ID")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
def reg(model: type[BaseModel]) -> None:
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AnnotationReplyPayload)
reg(AnnotationSettingUpdatePayload)
reg(AnnotationListQuery)
reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>") @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@console_ns.doc("annotation_reply_action") @console_ns.doc("annotation_reply_action")
@console_ns.doc(description="Enable or disable annotation reply for an app") @console_ns.doc(description="Enable or disable annotation reply for an app")
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) @console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationReplyPayload.__name__])
console_ns.model(
"AnnotationReplyActionRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider name"),
"embedding_model_name": fields.String(required=True, description="Embedding model name"),
},
)
)
@console_ns.response(200, "Action completed successfully") @console_ns.response(200, "Action completed successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -46,15 +106,9 @@ class AnnotationReplyActionApi(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_id, action: Literal["enable", "disable"]): def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id) app_id = str(app_id)
parser = ( args = AnnotationReplyPayload.model_validate(console_ns.payload)
reqparse.RequestParser()
.add_argument("score_threshold", required=True, type=float, location="json")
.add_argument("embedding_provider_name", required=True, type=str, location="json")
.add_argument("embedding_model_name", required=True, type=str, location="json")
)
args = parser.parse_args()
if action == "enable": if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable": elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id) result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200 return result, 200
@ -82,16 +136,7 @@ class AppAnnotationSettingUpdateApi(Resource):
@console_ns.doc("update_annotation_setting") @console_ns.doc("update_annotation_setting")
@console_ns.doc(description="Update annotation settings for an app") @console_ns.doc(description="Update annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) @console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationSettingUpdatePayload.__name__])
console_ns.model(
"AnnotationSettingUpdateRequest",
{
"score_threshold": fields.Float(required=True, description="Score threshold"),
"embedding_provider_name": fields.String(required=True, description="Embedding provider"),
"embedding_model_name": fields.String(required=True, description="Embedding model"),
},
)
)
@console_ns.response(200, "Settings updated successfully") @console_ns.response(200, "Settings updated successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -102,10 +147,9 @@ class AppAnnotationSettingUpdateApi(Resource):
app_id = str(app_id) app_id = str(app_id)
annotation_setting_id = str(annotation_setting_id) annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json") args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200 return result, 200
@ -142,12 +186,7 @@ class AnnotationApi(Resource):
@console_ns.doc("list_annotations") @console_ns.doc("list_annotations")
@console_ns.doc(description="Get annotations for an app with pagination") @console_ns.doc(description="Get annotations for an app with pagination")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[AnnotationListQuery.__name__])
console_ns.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
)
@console_ns.response(200, "Annotations retrieved successfully") @console_ns.response(200, "Annotations retrieved successfully")
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -155,9 +194,10 @@ class AnnotationApi(Resource):
@account_initialization_required @account_initialization_required
@edit_permission_required @edit_permission_required
def get(self, app_id): def get(self, app_id):
page = request.args.get("page", default=1, type=int) args = AnnotationListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
limit = request.args.get("limit", default=20, type=int) page = args.page
keyword = request.args.get("keyword", default="", type=str) limit = args.limit
keyword = args.keyword
app_id = str(app_id) app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
@ -173,18 +213,7 @@ class AnnotationApi(Resource):
@console_ns.doc("create_annotation") @console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app") @console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
console_ns.model(
"CreateAnnotationRequest",
{
"message_id": fields.String(description="Message ID (optional)"),
"question": fields.String(description="Question text (required when message_id not provided)"),
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
"content": fields.String(description="Content text (use 'answer' or 'content')"),
"annotation_reply": fields.Raw(description="Annotation reply data"),
},
)
)
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) @console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions") @console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@ -195,16 +224,9 @@ class AnnotationApi(Resource):
@edit_permission_required @edit_permission_required
def post(self, app_id): def post(self, app_id):
app_id = str(app_id) app_id = str(app_id)
parser = ( args = CreateAnnotationPayload.model_validate(console_ns.payload)
reqparse.RequestParser() data = args.model_dump(exclude_none=True)
.add_argument("message_id", required=False, type=uuid_value, location="json") annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
.add_argument("question", required=False, type=str, location="json")
.add_argument("answer", required=False, type=str, location="json")
.add_argument("content", required=False, type=str, location="json")
.add_argument("annotation_reply", required=False, type=dict, location="json")
)
args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
return annotation return annotation
@setup_required @setup_required
@@ -237,7 +259,7 @@ class AnnotationApi(Resource):
 @console_ns.route("/apps/<uuid:app_id>/annotations/export")
 class AnnotationExportApi(Resource):
     @console_ns.doc("export_annotations")
-    @console_ns.doc(description="Export all annotations for an app")
+    @console_ns.doc(description="Export all annotations for an app with CSV injection protection")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.response(
         200,
@@ -252,15 +274,14 @@ class AnnotationExportApi(Resource):
     def get(self, app_id):
         app_id = str(app_id)
         annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
-        response = {"data": marshal(annotation_list, annotation_fields)}
-        return response, 200
+        response_data = {"data": marshal(annotation_list, annotation_fields)}

+        # Create response with secure headers for CSV export
+        response = make_response(response_data, 200)
+        response.headers["Content-Type"] = "application/json; charset=utf-8"
+        response.headers["X-Content-Type-Options"] = "nosniff"

-parser = (
-    reqparse.RequestParser()
-    .add_argument("question", required=True, type=str, location="json")
-    .add_argument("answer", required=True, type=str, location="json")
-)
+        return response

 @console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
@@ -271,7 +292,7 @@ class AnnotationUpdateDeleteApi(Resource):
     @console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
     @console_ns.response(204, "Annotation deleted successfully")
     @console_ns.response(403, "Insufficient permissions")
-    @console_ns.expect(parser)
+    @console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -281,8 +302,10 @@ class AnnotationUpdateDeleteApi(Resource):
     def post(self, app_id, annotation_id):
         app_id = str(app_id)
         annotation_id = str(annotation_id)
-        args = parser.parse_args()
-        annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
+        args = UpdateAnnotationPayload.model_validate(console_ns.payload)
+        annotation = AppAnnotationService.update_app_annotation_directly(
+            args.model_dump(exclude_none=True), app_id, annotation_id
+        )
         return annotation

     @setup_required
@@ -299,18 +322,25 @@ class AnnotationUpdateDeleteApi(Resource):
 @console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
 class AnnotationBatchImportApi(Resource):
     @console_ns.doc("batch_import_annotations")
-    @console_ns.doc(description="Batch import annotations from CSV file")
+    @console_ns.doc(description="Batch import annotations from CSV file with rate limiting and security checks")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.response(200, "Batch import started successfully")
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(400, "No file uploaded or too many files")
+    @console_ns.response(413, "File too large")
+    @console_ns.response(429, "Too many requests or concurrent imports")
     @setup_required
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("annotation")
+    @annotation_import_rate_limit
+    @annotation_import_concurrency_limit
     @edit_permission_required
     def post(self, app_id):
+        from configs import dify_config
+
         app_id = str(app_id)
         # check file
         if "file" not in request.files:
             raise NoFileUploadedError()
@@ -320,9 +350,27 @@ class AnnotationBatchImportApi(Resource):
         # get file from request
         file = request.files["file"]
         # check file type
         if not file.filename or not file.filename.lower().endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")
+
+        # Check file size before processing
+        file.seek(0, 2)  # Seek to end of file
+        file_size = file.tell()
+        file.seek(0)  # Reset to beginning
+
+        max_size_bytes = dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT * 1024 * 1024
+        if file_size > max_size_bytes:
+            abort(
+                413,
+                f"File size exceeds maximum limit of {dify_config.ANNOTATION_IMPORT_FILE_SIZE_LIMIT}MB. "
+                f"Please reduce the file size and try again.",
+            )
+        if file_size == 0:
+            raise ValueError("The uploaded file is empty")
+
         return AppAnnotationService.batch_import_app_annotations(app_id, file)
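
The size guard added here measures the upload with seek/tell before anything reads the stream. A standalone sketch of the same idea (function name and limit are illustrative):

```python
import io

def check_upload_size(stream: io.IOBase, limit_mb: int) -> int:
    """Measure a file-like object without loading it into memory (the seek/tell pattern above)."""
    stream.seek(0, 2)     # whence=2 is SEEK_END: jump to the end of the stream
    size = stream.tell()  # byte offset at the end == total size
    stream.seek(0)        # rewind so later readers still see the full content
    if size == 0:
        raise ValueError("The uploaded file is empty")
    if size > limit_mb * 1024 * 1024:
        raise ValueError(f"File exceeds {limit_mb}MB")
    return size

print(check_upload_size(io.BytesIO(b"question,answer\n"), limit_mb=10))  # -> 16
```

Rewinding with `seek(0)` matters: forgetting it would make the subsequent CSV import see an empty stream.
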

View File

@@ -1,9 +1,12 @@
 import uuid
+from typing import Literal

-from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import Session
-from werkzeug.exceptions import BadRequest, abort
+from werkzeug.exceptions import BadRequest

 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
@@ -28,7 +31,6 @@ from fields.app_fields import (
 from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
 from libs.helper import AppIconUrlField, TimestampField
 from libs.login import current_account_with_tenant, login_required
-from libs.validators import validate_description_length
 from models import App, Workflow
 from services.app_dsl_service import AppDslService, ImportMode
 from services.app_service import AppService
@@ -36,6 +38,116 @@ from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService

 ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class AppListQuery(BaseModel):
+    page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
+    limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
+    mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
+        default="all", description="App mode filter"
+    )
+    name: str | None = Field(default=None, description="Filter by app name")
+    tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
+    is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
+
+    @field_validator("tag_ids", mode="before")
+    @classmethod
+    def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
+        if not value:
+            return None
+        if isinstance(value, str):
+            items = [item.strip() for item in value.split(",") if item.strip()]
+        elif isinstance(value, list):
+            items = [str(item).strip() for item in value if item and str(item).strip()]
+        else:
+            raise TypeError("Unsupported tag_ids type.")
+        if not items:
+            return None
+        try:
+            return [str(uuid.UUID(item)) for item in items]
+        except ValueError as exc:
+            raise ValueError("Invalid UUID format in tag_ids.") from exc
+
+
+class CreateAppPayload(BaseModel):
+    name: str = Field(..., min_length=1, description="App name")
+    description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
+    mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
+    icon_type: str | None = Field(default=None, description="Icon type")
+    icon: str | None = Field(default=None, description="Icon")
+    icon_background: str | None = Field(default=None, description="Icon background color")
+
+
+class UpdateAppPayload(BaseModel):
+    name: str = Field(..., min_length=1, description="App name")
+    description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
+    icon_type: str | None = Field(default=None, description="Icon type")
+    icon: str | None = Field(default=None, description="Icon")
+    icon_background: str | None = Field(default=None, description="Icon background color")
+    use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
+    max_active_requests: int | None = Field(default=None, description="Maximum active requests")
+
+
+class CopyAppPayload(BaseModel):
+    name: str | None = Field(default=None, description="Name for the copied app")
+    description: str | None = Field(default=None, description="Description for the copied app", max_length=400)
+    icon_type: str | None = Field(default=None, description="Icon type")
+    icon: str | None = Field(default=None, description="Icon")
+    icon_background: str | None = Field(default=None, description="Icon background color")
+
+
+class AppExportQuery(BaseModel):
+    include_secret: bool = Field(default=False, description="Include secrets in export")
+    workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
+
+
+class AppNamePayload(BaseModel):
+    name: str = Field(..., min_length=1, description="Name to check")
+
+
+class AppIconPayload(BaseModel):
+    icon: str | None = Field(default=None, description="Icon data")
+    icon_background: str | None = Field(default=None, description="Icon background color")
+
+
+class AppSiteStatusPayload(BaseModel):
+    enable_site: bool = Field(..., description="Enable or disable site")
+
+
+class AppApiStatusPayload(BaseModel):
+    enable_api: bool = Field(..., description="Enable or disable API")
+
+
+class AppTracePayload(BaseModel):
+    enabled: bool = Field(..., description="Enable or disable tracing")
+    tracing_provider: str | None = Field(default=None, description="Tracing provider")
+
+    @field_validator("tracing_provider")
+    @classmethod
+    def validate_tracing_provider(cls, value: str | None, info) -> str | None:
+        if info.data.get("enabled") and not value:
+            raise ValueError("tracing_provider is required when enabled is True")
+        return value
+
+
+def reg(cls: type[BaseModel]):
+    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(AppListQuery)
+reg(CreateAppPayload)
+reg(UpdateAppPayload)
+reg(CopyAppPayload)
+reg(AppExportQuery)
+reg(AppNamePayload)
+reg(AppIconPayload)
+reg(AppSiteStatusPayload)
+reg(AppApiStatusPayload)
+reg(AppTracePayload)

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register base models first
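
The `reg()` helper added above bridges Pydantic and flask-restx: it registers each model's JSON schema under the model's class name so `@console_ns.expect(console_ns.models[...])` can reference it, and the `ref_template` rewrites schema refs into the `#/definitions/...` form Swagger 2.0 expects. A sketch of what that call produces (hypothetical model, output shape approximate):

```python
from pydantic import BaseModel, Field

class ExamplePayload(BaseModel):  # illustrative stand-in for the models registered above
    name: str = Field(..., min_length=1)

schema = ExamplePayload.model_json_schema(ref_template="#/definitions/{model}")
# Pydantic v2 emits JSON Schema with "$defs"; ref_template rewrites any nested "$ref"
# to the "#/definitions/..." form that flask-restx's Swagger 2.0 document uses.
print(schema["properties"]["name"])  # roughly: {'minLength': 1, 'title': 'Name', 'type': 'string'}
```

In flask-restx, `Namespace.schema_model(name, schema)` stores such a raw schema in `ns.models[name]`, which is why the decorators can look models up by `cls.__name__`.
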
@@ -147,22 +259,7 @@ app_pagination_model = console_ns.model(
 class AppListApi(Resource):
     @console_ns.doc("list_apps")
     @console_ns.doc(description="Get list of applications with pagination and filtering")
-    @console_ns.expect(
-        console_ns.parser()
-        .add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
-        .add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
-        .add_argument(
-            "mode",
-            type=str,
-            location="args",
-            choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
-            default="all",
-            help="App mode filter",
-        )
-        .add_argument("name", type=str, location="args", help="Filter by app name")
-        .add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
-        .add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
-    )
+    @console_ns.expect(console_ns.models[AppListQuery.__name__])
     @console_ns.response(200, "Success", app_pagination_model)
     @setup_required
     @login_required
@@ -172,42 +269,12 @@ class AppListApi(Resource):
         """Get app list"""
         current_user, current_tenant_id = current_account_with_tenant()

-        def uuid_list(value):
-            try:
-                return [str(uuid.UUID(v)) for v in value.split(",")]
-            except ValueError:
-                abort(400, message="Invalid UUID format in tag_ids.")
-
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
-            .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
-            .add_argument(
-                "mode",
-                type=str,
-                choices=[
-                    "completion",
-                    "chat",
-                    "advanced-chat",
-                    "workflow",
-                    "agent-chat",
-                    "channel",
-                    "all",
-                ],
-                default="all",
-                location="args",
-                required=False,
-            )
-            .add_argument("name", type=str, location="args", required=False)
-            .add_argument("tag_ids", type=uuid_list, location="args", required=False)
-            .add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
-        )
-        args = parser.parse_args()
+        args = AppListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        args_dict = args.model_dump()

         # get app list
         app_service = AppService()
-        app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
+        app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
         if not app_pagination:
             return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@@ -242,10 +309,13 @@ class AppListApi(Resource):
             NodeType.TRIGGER_PLUGIN,
         }
         for workflow in draft_workflows:
-            for _, node_data in workflow.walk_nodes():
-                if node_data.get("type") in trigger_node_types:
-                    draft_trigger_app_ids.add(str(workflow.app_id))
-                    break
+            try:
+                for _, node_data in workflow.walk_nodes():
+                    if node_data.get("type") in trigger_node_types:
+                        draft_trigger_app_ids.add(str(workflow.app_id))
+                        break
+            except Exception:
+                continue

         for app in app_pagination.items:
             app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
@@ -254,19 +324,7 @@ class AppListApi(Resource):
     @console_ns.doc("create_app")
     @console_ns.doc(description="Create a new application")
-    @console_ns.expect(
-        console_ns.model(
-            "CreateAppRequest",
-            {
-                "name": fields.String(required=True, description="App name"),
-                "description": fields.String(description="App description (max 400 chars)"),
-                "mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
-                "icon_type": fields.String(description="Icon type"),
-                "icon": fields.String(description="Icon"),
-                "icon_background": fields.String(description="Icon background color"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[CreateAppPayload.__name__])
     @console_ns.response(201, "App created successfully", app_detail_model)
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(400, "Invalid request parameters")
@@ -279,22 +337,10 @@ class AppListApi(Resource):
     def post(self):
         """Create app"""
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("name", type=str, required=True, location="json")
-            .add_argument("description", type=validate_description_length, location="json")
-            .add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
-            .add_argument("icon_type", type=str, location="json")
-            .add_argument("icon", type=str, location="json")
-            .add_argument("icon_background", type=str, location="json")
-        )
-        args = parser.parse_args()
-
-        if "mode" not in args or args["mode"] is None:
-            raise BadRequest("mode is required")
+        args = CreateAppPayload.model_validate(console_ns.payload)

         app_service = AppService()
-        app = app_service.create_app(current_tenant_id, args, current_user)
+        app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)

         return app, 201
@@ -326,20 +372,7 @@ class AppApi(Resource):
     @console_ns.doc("update_app")
     @console_ns.doc(description="Update application details")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "UpdateAppRequest",
-            {
-                "name": fields.String(required=True, description="App name"),
-                "description": fields.String(description="App description (max 400 chars)"),
-                "icon_type": fields.String(description="Icon type"),
-                "icon": fields.String(description="Icon"),
-                "icon_background": fields.String(description="Icon background color"),
-                "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
-                "max_active_requests": fields.Integer(description="Maximum active requests"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
     @console_ns.response(200, "App updated successfully", app_detail_with_site_model)
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(400, "Invalid request parameters")
@@ -351,28 +384,18 @@ class AppApi(Resource):
     @marshal_with(app_detail_with_site_model)
     def put(self, app_model):
         """Update app"""
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("name", type=str, required=True, nullable=False, location="json")
-            .add_argument("description", type=validate_description_length, location="json")
-            .add_argument("icon_type", type=str, location="json")
-            .add_argument("icon", type=str, location="json")
-            .add_argument("icon_background", type=str, location="json")
-            .add_argument("use_icon_as_answer_icon", type=bool, location="json")
-            .add_argument("max_active_requests", type=int, location="json")
-        )
-        args = parser.parse_args()
+        args = UpdateAppPayload.model_validate(console_ns.payload)

         app_service = AppService()
         args_dict: AppService.ArgsDict = {
-            "name": args["name"],
-            "description": args.get("description", ""),
-            "icon_type": args.get("icon_type", ""),
-            "icon": args.get("icon", ""),
-            "icon_background": args.get("icon_background", ""),
-            "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
-            "max_active_requests": args.get("max_active_requests", 0),
+            "name": args.name,
+            "description": args.description or "",
+            "icon_type": args.icon_type or "",
+            "icon": args.icon or "",
+            "icon_background": args.icon_background or "",
+            "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
+            "max_active_requests": args.max_active_requests or 0,
         }
         app_model = app_service.update_app(app_model, args_dict)
@@ -401,18 +424,7 @@ class AppCopyApi(Resource):
     @console_ns.doc("copy_app")
     @console_ns.doc(description="Create a copy of an existing application")
     @console_ns.doc(params={"app_id": "Application ID to copy"})
-    @console_ns.expect(
-        console_ns.model(
-            "CopyAppRequest",
-            {
-                "name": fields.String(description="Name for the copied app"),
-                "description": fields.String(description="Description for the copied app"),
-                "icon_type": fields.String(description="Icon type"),
-                "icon": fields.String(description="Icon"),
-                "icon_background": fields.String(description="Icon background color"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[CopyAppPayload.__name__])
     @console_ns.response(201, "App copied successfully", app_detail_with_site_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -426,15 +438,7 @@ class AppCopyApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("name", type=str, location="json")
-            .add_argument("description", type=validate_description_length, location="json")
-            .add_argument("icon_type", type=str, location="json")
-            .add_argument("icon", type=str, location="json")
-            .add_argument("icon_background", type=str, location="json")
-        )
-        args = parser.parse_args()
+        args = CopyAppPayload.model_validate(console_ns.payload or {})

         with Session(db.engine) as session:
             import_service = AppDslService(session)
@@ -443,11 +447,11 @@ class AppCopyApi(Resource):
                 account=current_user,
                 import_mode=ImportMode.YAML_CONTENT,
                 yaml_content=yaml_content,
-                name=args.get("name"),
-                description=args.get("description"),
-                icon_type=args.get("icon_type"),
-                icon=args.get("icon"),
-                icon_background=args.get("icon_background"),
+                name=args.name,
+                description=args.description,
+                icon_type=args.icon_type,
+                icon=args.icon,
+                icon_background=args.icon_background,
             )
             session.commit()
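
Note the `console_ns.payload or {}` guard in the copy endpoint: flask-restx yields `None` when no JSON body was sent, and an all-optional model should still validate in that case. A minimal sketch (model name illustrative):

```python
from pydantic import BaseModel

class CopyPayload(BaseModel):  # all-optional, like CopyAppPayload above
    name: str | None = None
    description: str | None = None

payload = None  # what `console_ns.payload` returns for a bodyless POST
args = CopyPayload.model_validate(payload or {})  # `or {}` turns the missing body into an empty dict
print(args.name)  # None -- every field falls back to its default
```

Without the guard, `model_validate(None)` would raise a ValidationError even though every field is optional.
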
@@ -462,11 +466,7 @@ class AppExportApi(Resource):
     @console_ns.doc("export_app")
     @console_ns.doc(description="Export application configuration as DSL")
     @console_ns.doc(params={"app_id": "Application ID to export"})
-    @console_ns.expect(
-        console_ns.parser()
-        .add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
-        .add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
-    )
+    @console_ns.expect(console_ns.models[AppExportQuery.__name__])
     @console_ns.response(
         200,
         "App exported successfully",
@@ -480,30 +480,23 @@ class AppExportApi(Resource):
     @edit_permission_required
     def get(self, app_model):
         """Export app"""
-        # Add include_secret params
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("include_secret", type=inputs.boolean, default=False, location="args")
-            .add_argument("workflow_id", type=str, location="args")
-        )
-        args = parser.parse_args()
+        args = AppExportQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         return {
             "data": AppDslService.export_dsl(
-                app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
+                app_model=app_model,
+                include_secret=args.include_secret,
+                workflow_id=args.workflow_id,
             )
         }

-parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
-
 @console_ns.route("/apps/<uuid:app_id>/name")
 class AppNameApi(Resource):
     @console_ns.doc("check_app_name")
     @console_ns.doc(description="Check if app name is available")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(parser)
+    @console_ns.expect(console_ns.models[AppNamePayload.__name__])
     @console_ns.response(200, "Name availability checked")
     @setup_required
     @login_required
@@ -512,10 +505,10 @@ class AppNameApi(Resource):
     @marshal_with(app_detail_model)
     @edit_permission_required
     def post(self, app_model):
-        args = parser.parse_args()
+        args = AppNamePayload.model_validate(console_ns.payload)

         app_service = AppService()
-        app_model = app_service.update_app_name(app_model, args["name"])
+        app_model = app_service.update_app_name(app_model, args.name)

         return app_model
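
Query-string endpoints in this commit validate `request.args.to_dict(flat=True)` directly, so Pydantic handles the string-to-type coercion that reqparse needed `inputs.boolean` for. A sketch of why that matters (model name illustrative):

```python
from pydantic import BaseModel, Field

class ExportQuery(BaseModel):  # same shape as AppExportQuery above
    include_secret: bool = Field(default=False)
    workflow_id: str | None = None

# Query strings arrive as text; Pydantic's lax mode coerces "true"/"false", "1"/"0", etc.
args = ExportQuery.model_validate({"include_secret": "true"})
print(args.include_secret)  # True

args = ExportQuery.model_validate({"include_secret": "false"})
print(args.include_secret)  # False -- unlike Python's bool("false"), which is True
```

The naive `type=bool` in the old Swagger-only parsers had exactly that `bool("false") == True` pitfall, which is one reason the real parsers used `inputs.boolean`.
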
@@ -525,16 +518,7 @@ class AppIconApi(Resource):
     @console_ns.doc("update_app_icon")
     @console_ns.doc(description="Update application icon")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "AppIconRequest",
-            {
-                "icon": fields.String(required=True, description="Icon data"),
-                "icon_type": fields.String(description="Icon type"),
-                "icon_background": fields.String(description="Icon background color"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[AppIconPayload.__name__])
     @console_ns.response(200, "Icon updated successfully")
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -544,15 +528,10 @@ class AppIconApi(Resource):
     @marshal_with(app_detail_model)
     @edit_permission_required
     def post(self, app_model):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("icon", type=str, location="json")
-            .add_argument("icon_background", type=str, location="json")
-        )
-        args = parser.parse_args()
+        args = AppIconPayload.model_validate(console_ns.payload or {})

         app_service = AppService()
-        app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
+        app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")

         return app_model
@@ -562,11 +541,7 @@ class AppSiteStatus(Resource):
     @console_ns.doc("update_app_site_status")
     @console_ns.doc(description="Enable or disable app site")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
-        )
-    )
+    @console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
     @console_ns.response(200, "Site status updated successfully", app_detail_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -576,11 +551,10 @@ class AppSiteStatus(Resource):
     @marshal_with(app_detail_model)
     @edit_permission_required
     def post(self, app_model):
-        parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
-        args = parser.parse_args()
+        args = AppSiteStatusPayload.model_validate(console_ns.payload)

         app_service = AppService()
-        app_model = app_service.update_app_site_status(app_model, args["enable_site"])
+        app_model = app_service.update_app_site_status(app_model, args.enable_site)

         return app_model
@@ -590,11 +564,7 @@ class AppApiStatus(Resource):
     @console_ns.doc("update_app_api_status")
     @console_ns.doc(description="Enable or disable app API")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
-        )
-    )
+    @console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
     @console_ns.response(200, "API status updated successfully", app_detail_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -604,11 +574,10 @@ class AppApiStatus(Resource):
     @get_app_model
     @marshal_with(app_detail_model)
     def post(self, app_model):
-        parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
-        args = parser.parse_args()
+        args = AppApiStatusPayload.model_validate(console_ns.payload)

         app_service = AppService()
-        app_model = app_service.update_app_api_status(app_model, args["enable_api"])
+        app_model = app_service.update_app_api_status(app_model, args.enable_api)

         return app_model
@@ -631,15 +600,7 @@ class AppTraceApi(Resource):
     @console_ns.doc("update_app_trace")
     @console_ns.doc(description="Update app tracing configuration")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "AppTraceRequest",
-            {
-                "enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
-                "tracing_provider": fields.String(required=True, description="Tracing provider"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[AppTracePayload.__name__])
     @console_ns.response(200, "Trace configuration updated successfully")
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -648,17 +609,12 @@ class AppTraceApi(Resource):
     @edit_permission_required
     def post(self, app_id):
         # add app trace
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("enabled", type=bool, required=True, location="json")
-            .add_argument("tracing_provider", type=str, required=True, location="json")
-        )
-        args = parser.parse_args()
+        args = AppTracePayload.model_validate(console_ns.payload)

         OpsTraceManager.update_app_tracing_config(
             app_id=app_id,
-            enabled=args["enabled"],
-            tracing_provider=args["tracing_provider"],
+            enabled=args.enabled,
+            tracing_provider=args.tracing_provider,
         )

         return {"result": "success"}

View File

@@ -1,4 +1,5 @@
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session

 from controllers.console.app.wraps import get_app_model
@@ -35,23 +36,29 @@ app_import_check_dependencies_model = console_ns.model(
     "AppImportCheckDependencies", app_import_check_dependencies_fields_copy
 )

-parser = (
-    reqparse.RequestParser()
-    .add_argument("mode", type=str, required=True, location="json")
-    .add_argument("yaml_content", type=str, location="json")
-    .add_argument("yaml_url", type=str, location="json")
-    .add_argument("name", type=str, location="json")
-    .add_argument("description", type=str, location="json")
-    .add_argument("icon_type", type=str, location="json")
-    .add_argument("icon", type=str, location="json")
-    .add_argument("icon_background", type=str, location="json")
-    .add_argument("app_id", type=str, location="json")
-)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class AppImportPayload(BaseModel):
+    mode: str = Field(..., description="Import mode")
+    yaml_content: str | None = None
+    yaml_url: str | None = None
+    name: str | None = None
+    description: str | None = None
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+    app_id: str | None = None
+
+
+console_ns.schema_model(
+    AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)

 @console_ns.route("/apps/imports")
 class AppImportApi(Resource):
-    @console_ns.expect(parser)
+    @console_ns.expect(console_ns.models[AppImportPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -61,7 +68,7 @@ class AppImportApi(Resource):
     def post(self):
         # Check user role first
         current_user, _ = current_account_with_tenant()
-        args = parser.parse_args()
+        args = AppImportPayload.model_validate(console_ns.payload)

         # Create service with session
         with Session(db.engine) as session:
@@ -70,15 +77,15 @@ class AppImportApi(Resource):
             account = current_user
             result = import_service.import_app(
                 account=account,
-                import_mode=args["mode"],
-                yaml_content=args.get("yaml_content"),
-                yaml_url=args.get("yaml_url"),
-                name=args.get("name"),
-                description=args.get("description"),
-                icon_type=args.get("icon_type"),
-                icon=args.get("icon"),
-                icon_background=args.get("icon_background"),
-                app_id=args.get("app_id"),
+                import_mode=args.mode,
+                yaml_content=args.yaml_content,
+                yaml_url=args.yaml_url,
+                name=args.name,
+                description=args.description,
+                icon_type=args.icon_type,
+                icon=args.icon,
+                icon_background=args.icon_background,
+                app_id=args.app_id,
             )
             session.commit()

         if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@@ -1,7 +1,8 @@
 import logging

 from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError

 import services
@@ -32,6 +33,27 @@ from services.errors.audio import (
 )

 logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class TextToSpeechPayload(BaseModel):
+    message_id: str | None = Field(default=None, description="Message ID")
+    text: str = Field(..., description="Text to convert")
+    voice: str | None = Field(default=None, description="Voice name")
+    streaming: bool | None = Field(default=None, description="Whether to stream audio")
+
+
+class TextToSpeechVoiceQuery(BaseModel):
+    language: str = Field(..., description="Language code")
+
+
+console_ns.schema_model(
+    TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+console_ns.schema_model(
+    TextToSpeechVoiceQuery.__name__,
+    TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)

 @console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@@ -92,17 +114,7 @@ class ChatMessageTextApi(Resource):
     @console_ns.doc("chat_message_text_to_speech")
     @console_ns.doc(description="Convert text to speech for chat messages")
     @console_ns.doc(params={"app_id": "App ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "TextToSpeechRequest",
-            {
-                "message_id": fields.String(description="Message ID"),
-                "text": fields.String(required=True, description="Text to convert to speech"),
-                "voice": fields.String(description="Voice to use for TTS"),
-                "streaming": fields.Boolean(description="Whether to stream the audio"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[TextToSpeechPayload.__name__])
     @console_ns.response(200, "Text to speech conversion successful")
     @console_ns.response(400, "Bad request - Invalid parameters")
     @get_app_model
@@ -111,21 +123,14 @@ class ChatMessageTextApi(Resource):
     @account_initialization_required
     def post(self, app_model: App):
         try:
-            parser = (
-                reqparse.RequestParser()
-                .add_argument("message_id", type=str, location="json")
-                .add_argument("text", type=str, location="json")
-                .add_argument("voice", type=str, location="json")
-                .add_argument("streaming", type=bool, location="json")
-            )
-            args = parser.parse_args()
-
-            message_id = args.get("message_id", None)
-            text = args.get("text", None)
-            voice = args.get("voice", None)
+            payload = TextToSpeechPayload.model_validate(console_ns.payload)
             response = AudioService.transcript_tts(
-                app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
+                app_model=app_model,
+                text=payload.text,
+                voice=payload.voice,
+                message_id=payload.message_id,
+                is_draft=True,
             )
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -159,9 +164,7 @@ class TextModesApi(Resource):
     @console_ns.doc("get_text_to_speech_voices")
     @console_ns.doc(description="Get available TTS voices for a specific language")
     @console_ns.doc(params={"app_id": "App ID"})
-    @console_ns.expect(
-        console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
-    )
+    @console_ns.expect(console_ns.models[TextToSpeechVoiceQuery.__name__])
     @console_ns.response(
         200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
     )
@@ -172,12 +175,11 @@ class TextModesApi(Resource):
     @account_initialization_required
     def get(self, app_model):
         try:
-            parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
-            args = parser.parse_args()
+            args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

             response = AudioService.transcript_tts_voices(
                 tenant_id=app_model.tenant_id,
-                language=args["language"],
+                language=args.language,
             )

             return response

View File

@@ -1,7 +1,9 @@
 import logging
+from typing import Any, Literal

 from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
@@ -35,6 +37,41 @@ from services.app_task_service import AppTaskService
 from services.errors.llm import InvokeRateLimitError

 logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class BaseMessagePayload(BaseModel):
+    inputs: dict[str, Any]
+    model_config_data: dict[str, Any] = Field(..., alias="model_config")
+    files: list[Any] | None = Field(default=None, description="Uploaded files")
+    response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
+    retriever_from: str = Field(default="dev", description="Retriever source")
+
+
+class CompletionMessagePayload(BaseMessagePayload):
+    query: str = Field(default="", description="Query text")
+
+
+class ChatMessagePayload(BaseMessagePayload):
+    query: str = Field(..., description="User query")
+    conversation_id: str | None = Field(default=None, description="Conversation ID")
+    parent_message_id: str | None = Field(default=None, description="Parent message ID")
+
+    @field_validator("conversation_id", "parent_message_id")
+    @classmethod
+    def validate_uuid(cls, value: str | None) -> str | None:
+        if value is None:
+            return value
+        return uuid_value(value)
+
+
+console_ns.schema_model(
+    CompletionMessagePayload.__name__,
+    CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+    ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)

 # define completion message api for user
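
The `model_config_data` field above deserves a note: `model_config` is a reserved name on Pydantic v2 models (it holds the model's own configuration), so the wire-format key has to live under a different attribute with an alias. A minimal sketch of the trick (model name illustrative):

```python
from typing import Any
from pydantic import BaseModel, Field

class MessagePayload(BaseModel):  # same aliasing trick as BaseMessagePayload above
    # "model_config" is reserved on Pydantic v2 models, so the field gets a
    # different attribute name plus an alias pointing at the JSON key
    model_config_data: dict[str, Any] = Field(..., alias="model_config")

payload = MessagePayload.model_validate({"model_config": {"provider": "openai"}})
print(payload.model_config_data)          # {'provider': 'openai'}
print(payload.model_dump(by_alias=True))  # {'model_config': {'provider': 'openai'}}
```

Validation populates by alias by default, so incoming requests keep sending `model_config` unchanged; `by_alias=True` on the dump restores that key for downstream consumers.
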
@@ -43,19 +80,7 @@ class CompletionMessageApi(Resource):
     @console_ns.doc("create_completion_message")
     @console_ns.doc(description="Generate completion message for debugging")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "CompletionMessageRequest",
-            {
-                "inputs": fields.Raw(required=True, description="Input variables"),
-                "query": fields.String(description="Query text", default=""),
-                "files": fields.List(fields.Raw(), description="Uploaded files"),
-                "model_config": fields.Raw(required=True, description="Model configuration"),
-                "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
-                "retriever_from": fields.String(default="dev", description="Retriever source"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
     @console_ns.response(200, "Completion generated successfully")
     @console_ns.response(400, "Invalid request parameters")
     @console_ns.response(404, "App not found")
@@ -64,18 +89,10 @@ class CompletionMessageApi(Resource):
     @account_initialization_required
     @get_app_model(mode=AppMode.COMPLETION)
     def post(self, app_model):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, required=True, location="json")
-            .add_argument("query", type=str, location="json", default="")
-            .add_argument("files", type=list, required=False, location="json")
-            .add_argument("model_config", type=dict, required=True, location="json")
-            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
-            .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
-        )
-        args = parser.parse_args()
+        args_model = CompletionMessagePayload.model_validate(console_ns.payload)
+        args = args_model.model_dump(exclude_none=True, by_alias=True)

-        streaming = args["response_mode"] != "blocking"
+        streaming = args_model.response_mode != "blocking"
         args["auto_generate_name"] = False

         try:
@@ -137,21 +154,7 @@ class ChatMessageApi(Resource):
     @console_ns.doc("create_chat_message")
     @console_ns.doc(description="Generate chat message for debugging")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "ChatMessageRequest",
-            {
-                "inputs": fields.Raw(required=True, description="Input variables"),
-                "query": fields.String(required=True, description="User query"),
-                "files": fields.List(fields.Raw(), description="Uploaded files"),
-                "model_config": fields.Raw(required=True, description="Model configuration"),
-                "conversation_id": fields.String(description="Conversation ID"),
-                "parent_message_id": fields.String(description="Parent message ID"),
-                "response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
-                "retriever_from": fields.String(default="dev", description="Retriever source"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
     @console_ns.response(200, "Chat message generated successfully")
     @console_ns.response(400, "Invalid request parameters")
     @console_ns.response(404, "App or conversation not found")
@@ -161,20 +164,10 @@ class ChatMessageApi(Resource):
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
     @edit_permission_required
     def post(self, app_model):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, required=True, location="json")
-            .add_argument("query", type=str, required=True, location="json")
-            .add_argument("files", type=list, required=False, location="json")
-            .add_argument("model_config", type=dict, required=True, location="json")
-            .add_argument("conversation_id", type=uuid_value, location="json")
-            .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
-            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
-            .add_argument("retriever_from", type=str, required=False, default="dev", location="json")
-        )
-        args = parser.parse_args()
+        args_model = ChatMessagePayload.model_validate(console_ns.payload)
+        args = args_model.model_dump(exclude_none=True, by_alias=True)

-        streaming = args["response_mode"] != "blocking"
+        streaming = args_model.response_mode != "blocking"
         args["auto_generate_name"] = False

         external_trace_id = get_external_trace_id(request)
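
Both message handlers keep two views of the request: the validated model for typed access (`args_model.response_mode`) and a plain dict for the generation pipeline, which still expects `args["model_config"]` and extra mutable flags. A sketch of that double bookkeeping (reduced model, data illustrative):

```python
from typing import Any, Literal
from pydantic import BaseModel, Field

class ChatPayload(BaseModel):  # reduced sketch of ChatMessagePayload above
    query: str
    model_config_data: dict[str, Any] = Field(..., alias="model_config")
    conversation_id: str | None = None
    response_mode: Literal["blocking", "streaming"] = "blocking"

args_model = ChatPayload.model_validate({"query": "hi", "model_config": {}})
args = args_model.model_dump(exclude_none=True, by_alias=True)
# conversation_id (None) is dropped; the alias restores the "model_config" key
print(args)  # {'query': 'hi', 'model_config': {}, 'response_mode': 'blocking'}

# the dict stays mutable for flags the downstream pipeline expects:
args["auto_generate_name"] = False
```

`exclude_none=True` plus `by_alias=True` together reproduce the dict shape the old `parser.parse_args()` produced, so the service layer needs no changes.
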

View File

@@ -1,7 +1,9 @@
+from typing import Literal
+
 import sqlalchemy as sa
-from flask import abort
-from flask_restx import Resource, fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import abort, request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import func, or_
 from sqlalchemy.orm import joinedload
 from werkzeug.exceptions import NotFound
@@ -14,13 +16,53 @@ from extensions.ext_database import db
 from fields.conversation_fields import MessageTextField
 from fields.raws import FilesContainedField
 from libs.datetime_utils import naive_utc_now, parse_time_range
-from libs.helper import DatetimeString, TimestampField
+from libs.helper import TimestampField
 from libs.login import current_account_with_tenant, login_required
 from models import Conversation, EndUser, Message, MessageAnnotation
 from models.model import AppMode
 from services.conversation_service import ConversationService
 from services.errors.conversation import ConversationNotExistsError

+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class BaseConversationQuery(BaseModel):
+    keyword: str | None = Field(default=None, description="Search keyword")
+    start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
+    end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
+    annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
+        default="all", description="Annotation status filter"
+    )
+    page: int = Field(default=1, ge=1, le=99999, description="Page number")
+    limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
+
+    @field_validator("start", "end", mode="before")
+    @classmethod
+    def blank_to_none(cls, value: str | None) -> str | None:
+        if value == "":
+            return None
+        return value
+
+
+class CompletionConversationQuery(BaseConversationQuery):
+    pass
+
+
+class ChatConversationQuery(BaseConversationQuery):
+    sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
+        default="-updated_at", description="Sort field and direction"
+    )
+
+
+console_ns.schema_model(
+    CompletionConversationQuery.__name__,
+    CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+    ChatConversationQuery.__name__,
+    ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register in dependency order: base models first, then dependent models
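
The `blank_to_none` before-validator above handles a quirk of query strings: browsers and clients often send `?start=&end=...`, and an empty string should mean "no filter" rather than an unparseable date. A standalone sketch (model name illustrative):

```python
from pydantic import BaseModel, Field, field_validator

class ConversationQuery(BaseModel):  # mirrors BaseConversationQuery above
    start: str | None = Field(default=None)
    end: str | None = Field(default=None)

    @field_validator("start", "end", mode="before")
    @classmethod
    def blank_to_none(cls, value):
        # mode="before" runs on the raw input, so "" can be normalized
        # before the str | None type check sees it
        return None if value == "" else value

args = ConversationQuery.model_validate({"start": "", "end": "2025-01-01 00:00"})
print(args.start, args.end)  # None 2025-01-01 00:00
```

Downstream, `parse_time_range(args.start, args.end, ...)` then treats the missing bound exactly as the old `DatetimeString`-based parser did when the argument was absent.
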
@@ -283,22 +325,7 @@ class CompletionConversationApi(Resource):
     @console_ns.doc("list_completion_conversations")
     @console_ns.doc(description="Get completion conversations with pagination and filtering")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.parser()
-        .add_argument("keyword", type=str, location="args", help="Search keyword")
-        .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
-        .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
-        .add_argument(
-            "annotation_status",
-            type=str,
-            location="args",
-            choices=["annotated", "not_annotated", "all"],
-            default="all",
-            help="Annotation status filter",
-        )
-        .add_argument("page", type=int, location="args", default=1, help="Page number")
-        .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
-    )
+    @console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
     @console_ns.response(200, "Success", conversation_pagination_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -309,32 +336,17 @@ class CompletionConversationApi(Resource):
     @edit_permission_required
     def get(self, app_model):
         current_user, _ = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("keyword", type=str, location="args")
-            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-            .add_argument(
-                "annotation_status",
-                type=str,
-                choices=["annotated", "not_annotated", "all"],
-                default="all",
-                location="args",
-            )
-            .add_argument("page", type=int_range(1, 99999), default=1, location="args")
-            .add_argument("limit", type=int_range(1, 100), default=20, location="args")
-        )
-        args = parser.parse_args()
+        args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         query = sa.select(Conversation).where(
             Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
         )

-        if args["keyword"]:
+        if args.keyword:
             query = query.join(Message, Message.conversation_id == Conversation.id).where(
                 or_(
-                    Message.query.ilike(f"%{args['keyword']}%"),
-                    Message.answer.ilike(f"%{args['keyword']}%"),
+                    Message.query.ilike(f"%{args.keyword}%"),
+                    Message.answer.ilike(f"%{args.keyword}%"),
                 )
             )
@@ -342,7 +354,7 @@ class CompletionConversationApi(Resource):
         assert account.timezone is not None

         try:
-            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+            start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
         except ValueError as e:
             abort(400, description=str(e))
@@ -354,11 +366,11 @@ class CompletionConversationApi(Resource):
             query = query.where(Conversation.created_at < end_datetime_utc)

         # FIXME, the type ignore in this file
-        if args["annotation_status"] == "annotated":
+        if args.annotation_status == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
                 MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
             )
-        elif args["annotation_status"] == "not_annotated":
+        elif args.annotation_status == "not_annotated":
             query = (
                 query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
                 .group_by(Conversation.id)
@@ -367,7 +379,7 @@ class CompletionConversationApi(Resource):
         query = query.order_by(Conversation.created_at.desc())

-        conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
+        conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)

         return conversations
@@ -419,31 +431,7 @@ class ChatConversationApi(Resource):
     @console_ns.doc("list_chat_conversations")
     @console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.parser()
-        .add_argument("keyword", type=str, location="args", help="Search keyword")
-        .add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
-        .add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
-        .add_argument(
-            "annotation_status",
-            type=str,
-            location="args",
-            choices=["annotated", "not_annotated", "all"],
-            default="all",
-            help="Annotation status filter",
-        )
-        .add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
-        .add_argument("page", type=int, location="args", default=1, help="Page number")
-        .add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
-        .add_argument(
-            "sort_by",
-            type=str,
-            location="args",
-            choices=["created_at", "-created_at", "updated_at", "-updated_at"],
-            default="-updated_at",
-            help="Sort field and direction",
-        )
-    )
+    @console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
     @console_ns.response(200, "Success", conversation_with_summary_pagination_model)
     @console_ns.response(403, "Insufficient permissions")
     @setup_required
@@ -454,31 +442,7 @@ class ChatConversationApi(Resource):
     @edit_permission_required
     def get(self, app_model):
         current_user, _ = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("keyword", type=str, location="args")
-            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-            .add_argument(
-                "annotation_status",
-                type=str,
-                choices=["annotated", "not_annotated", "all"],
-                default="all",
-                location="args",
-            )
-            .add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
-            .add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
-            .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
-            .add_argument(
-                "sort_by",
-                type=str,
-                choices=["created_at", "-created_at", "updated_at", "-updated_at"],
-                required=False,
-                default="-updated_at",
-                location="args",
-            )
-        )
-        args = parser.parse_args()
+        args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         subquery = (
             db.session.query(
@@ -490,8 +454,8 @@ class ChatConversationApi(Resource):
         query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

-        if args["keyword"]:
-            keyword_filter = f"%{args['keyword']}%"
+        if args.keyword:
+            keyword_filter = f"%{args.keyword}%"
             query = (
                 query.join(
                     Message,
@@ -514,12 +478,12 @@ class ChatConversationApi(Resource):
         assert account.timezone is not None

         try:
-            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
+            start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
         except ValueError as e:
             abort(400, description=str(e))

         if start_datetime_utc:
-            match args["sort_by"]:
+            match args.sort_by:
                 case "updated_at" | "-updated_at":
                     query = query.where(Conversation.updated_at >= start_datetime_utc)
                 case "created_at" | "-created_at" | _:
@@ -527,35 +491,27 @@ class ChatConversationApi(Resource):
         if end_datetime_utc:
             end_datetime_utc = end_datetime_utc.replace(second=59)

-            match args["sort_by"]:
+            match args.sort_by:
                 case "updated_at" | "-updated_at":
                     query = query.where(Conversation.updated_at <= end_datetime_utc)
                 case "created_at" | "-created_at" | _:
                     query = query.where(Conversation.created_at <= end_datetime_utc)

-        if args["annotation_status"] == "annotated":
+        if args.annotation_status == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(  # type: ignore
                 MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
             )
-        elif args["annotation_status"] == "not_annotated":
+        elif args.annotation_status == "not_annotated":
             query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0) .having(func.count(MessageAnnotation.id) == 0)
) )
if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = (
query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"])
)
if app_model.mode == AppMode.ADVANCED_CHAT: if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args["sort_by"]: match args.sort_by:
case "created_at": case "created_at":
query = query.order_by(Conversation.created_at.asc()) query = query.order_by(Conversation.created_at.asc())
case "-created_at": case "-created_at":
@ -567,7 +523,7 @@ class ChatConversationApi(Resource):
case _: case _:
query = query.order_by(Conversation.created_at.desc()) query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
return conversations return conversations
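The recurring change in these hunks: each handler swaps a `reqparse` chain for a Pydantic model validated from `request.args.to_dict(flat=True)`. A minimal, self-contained sketch of that pattern, assuming Pydantic v2; the model and its fields are illustrative stand-ins, not Dify's actual `ChatConversationQuery`:

```python
# Sketch: reqparse-style query parsing replaced by a Pydantic v2 model.
# Field(ge=..., le=...) stands in for int_range(), Literal for choices=[...].
from typing import Literal

from pydantic import BaseModel, Field


class ConversationQuery(BaseModel):  # hypothetical stand-in model
    keyword: str | None = None
    page: int = Field(default=1, ge=1, le=99999)
    limit: int = Field(default=20, ge=1, le=100)
    sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = "-updated_at"


# In a handler this dict would come from request.args.to_dict(flat=True);
# query-string values arrive as text and are coerced by the model.
args = ConversationQuery.model_validate({"keyword": "hi", "page": "2", "limit": "50"})
assert args.page == 2 and args.sort_by == "-updated_at"
```

One behavioral difference worth noting: out-of-range or malformed values now raise `pydantic.ValidationError`, which the API layer has to translate into a 400 response (not shown in these hunks).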
View File
@@ -1,4 +1,6 @@
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -14,6 +16,18 @@ from libs.login import login_required
from models import ConversationVariable
from models.model import AppMode
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+class ConversationVariablesQuery(BaseModel):
+conversation_id: str = Field(..., description="Conversation ID to filter variables")
+console_ns.schema_model(
+ConversationVariablesQuery.__name__,
+ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
@@ -33,11 +47,7 @@ class ConversationVariablesApi(Resource):
@console_ns.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.parser().add_argument(
-"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
-)
-)
+@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
@setup_required
@login_required
@@ -45,18 +55,14 @@ class ConversationVariablesApi(Resource):
@get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model)
def get(self, app_model):
-parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
-args = parser.parse_args()
+args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
stmt = (
select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at)
)
-if args["conversation_id"]:
-stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
-else:
-raise ValueError("conversation_id is required")
+stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
# NOTE: This is a temporary solution to avoid performance issues.
page = 1
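The `DEFAULT_REF_TEMPLATE_SWAGGER_2_0` constant that keeps reappearing exists because Pydantic v2 emits JSON Schema references under `#/$defs/{model}`, while the Swagger 2.0 documents that flask-restx serves resolve references under `#/definitions/{model}`. A small sketch of what the `ref_template` argument changes (illustrative classes, assuming Pydantic v2):

```python
# Sketch: ref_template only rewrites $ref targets in the emitted schema.
from pydantic import BaseModel


class Inner(BaseModel):
    value: str


class Outer(BaseModel):
    inner: Inner


default = Outer.model_json_schema()
swagger2 = Outer.model_json_schema(ref_template="#/definitions/{model}")

assert default["properties"]["inner"]["$ref"] == "#/$defs/Inner"
assert swagger2["properties"]["inner"]["$ref"] == "#/definitions/Inner"
```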
View File
@@ -1,6 +1,8 @@
from collections.abc import Sequence
+from typing import Any
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.app.error import (
@@ -21,21 +23,54 @@ from libs.login import current_account_with_tenant, login_required
from models import App
from services.workflow_service import WorkflowService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+class RuleGeneratePayload(BaseModel):
+instruction: str = Field(..., description="Rule generation instruction")
+model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+no_variable: bool = Field(default=False, description="Whether to exclude variables")
+class RuleCodeGeneratePayload(RuleGeneratePayload):
+code_language: str = Field(default="javascript", description="Programming language for code generation")
+class RuleStructuredOutputPayload(BaseModel):
+instruction: str = Field(..., description="Structured output generation instruction")
+model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+class InstructionGeneratePayload(BaseModel):
+flow_id: str = Field(..., description="Workflow/Flow ID")
+node_id: str = Field(default="", description="Node ID for workflow context")
+current: str = Field(default="", description="Current instruction text")
+language: str = Field(default="javascript", description="Programming language (javascript/python)")
+instruction: str = Field(..., description="Instruction for generation")
+model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
+ideal_output: str = Field(default="", description="Expected ideal output")
+class InstructionTemplatePayload(BaseModel):
+type: str = Field(..., description="Instruction template type")
+def reg(cls: type[BaseModel]):
+console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+reg(RuleGeneratePayload)
+reg(RuleCodeGeneratePayload)
+reg(RuleStructuredOutputPayload)
+reg(InstructionGeneratePayload)
+reg(InstructionTemplatePayload)
@console_ns.route("/rule-generate")
class RuleGenerateApi(Resource):
@console_ns.doc("generate_rule_config")
@console_ns.doc(description="Generate rule configuration using LLM")
-@console_ns.expect(
-console_ns.model(
-"RuleGenerateRequest",
-{
-"instruction": fields.String(required=True, description="Rule generation instruction"),
-"model_config": fields.Raw(required=True, description="Model configuration"),
-"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
-},
-)
-)
+@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
@console_ns.response(200, "Rule configuration generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@@ -43,21 +78,15 @@ class RuleGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
-parser = (
-reqparse.RequestParser()
-.add_argument("instruction", type=str, required=True, nullable=False, location="json")
-.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
-.add_argument("no_variable", type=bool, required=True, default=False, location="json")
-)
-args = parser.parse_args()
+args = RuleGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
-instruction=args["instruction"],
-model_config=args["model_config"],
-no_variable=args["no_variable"],
+instruction=args.instruction,
+model_config=args.model_config_data,
+no_variable=args.no_variable,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -75,19 +104,7 @@ class RuleGenerateApi(Resource):
class RuleCodeGenerateApi(Resource):
@console_ns.doc("generate_rule_code")
@console_ns.doc(description="Generate code rules using LLM")
-@console_ns.expect(
-console_ns.model(
-"RuleCodeGenerateRequest",
-{
-"instruction": fields.String(required=True, description="Code generation instruction"),
-"model_config": fields.Raw(required=True, description="Model configuration"),
-"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
-"code_language": fields.String(
-default="javascript", description="Programming language for code generation"
-),
-},
-)
-)
+@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
@console_ns.response(200, "Code rules generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@@ -95,22 +112,15 @@ class RuleCodeGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
-parser = (
-reqparse.RequestParser()
-.add_argument("instruction", type=str, required=True, nullable=False, location="json")
-.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
-.add_argument("no_variable", type=bool, required=True, default=False, location="json")
-.add_argument("code_language", type=str, required=False, default="javascript", location="json")
-)
-args = parser.parse_args()
+args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
-instruction=args["instruction"],
-model_config=args["model_config"],
-code_language=args["code_language"],
+instruction=args.instruction,
+model_config=args.model_config_data,
+code_language=args.code_language,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -128,15 +138,7 @@ class RuleCodeGenerateApi(Resource):
class RuleStructuredOutputGenerateApi(Resource):
@console_ns.doc("generate_structured_output")
@console_ns.doc(description="Generate structured output rules using LLM")
-@console_ns.expect(
-console_ns.model(
-"StructuredOutputGenerateRequest",
-{
-"instruction": fields.String(required=True, description="Structured output generation instruction"),
-"model_config": fields.Raw(required=True, description="Model configuration"),
-},
-)
-)
+@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
@console_ns.response(200, "Structured output generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@@ -144,19 +146,14 @@ class RuleStructuredOutputGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
-parser = (
-reqparse.RequestParser()
-.add_argument("instruction", type=str, required=True, nullable=False, location="json")
-.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
-)
-args = parser.parse_args()
+args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
-instruction=args["instruction"],
-model_config=args["model_config"],
+instruction=args.instruction,
+model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -174,20 +171,7 @@ class RuleStructuredOutputGenerateApi(Resource):
class InstructionGenerateApi(Resource):
@console_ns.doc("generate_instruction")
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
-@console_ns.expect(
-console_ns.model(
-"InstructionGenerateRequest",
-{
-"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
-"node_id": fields.String(description="Node ID for workflow context"),
-"current": fields.String(description="Current instruction text"),
-"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
-"instruction": fields.String(required=True, description="Instruction for generation"),
-"model_config": fields.Raw(required=True, description="Model configuration"),
-"ideal_output": fields.String(description="Expected ideal output"),
-},
-)
-)
+@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
@console_ns.response(200, "Instruction generated successfully")
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@@ -195,79 +179,69 @@ class InstructionGenerateApi(Resource):
@login_required
@account_initialization_required
def post(self):
-parser = (
-reqparse.RequestParser()
-.add_argument("flow_id", type=str, required=True, default="", location="json")
-.add_argument("node_id", type=str, required=False, default="", location="json")
-.add_argument("current", type=str, required=False, default="", location="json")
-.add_argument("language", type=str, required=False, default="javascript", location="json")
-.add_argument("instruction", type=str, required=True, nullable=False, location="json")
-.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
-.add_argument("ideal_output", type=str, required=False, default="", location="json")
-)
-args = parser.parse_args()
+args = InstructionGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
-(p for p in providers if p.is_accept_language(args["language"])), None
+(p for p in providers if p.is_accept_language(args.language)), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
-if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
-app = db.session.query(App).where(App.id == args["flow_id"]).first()
+if (args.current in (code_template, "")) and args.node_id != "":
+app = db.session.query(App).where(App.id == args.flow_id).first()
if not app:
-return {"error": f"app {args['flow_id']} not found"}, 400
+return {"error": f"app {args.flow_id} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
-return {"error": f"workflow {args['flow_id']} not found"}, 400
+return {"error": f"workflow {args.flow_id} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"]
-node = [node for node in nodes if node["id"] == args["node_id"]]
+node = [node for node in nodes if node["id"] == args.node_id]
if len(node) == 0:
-return {"error": f"node {args['node_id']} not found"}, 400
+return {"error": f"node {args.node_id} not found"}, 400
node_type = node[0]["data"]["type"]
match node_type:
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
-instruction=args["instruction"],
-model_config=args["model_config"],
+instruction=args.instruction,
+model_config=args.model_config_data,
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
-instruction=args["instruction"],
-model_config=args["model_config"],
+instruction=args.instruction,
+model_config=args.model_config_data,
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
-instruction=args["instruction"],
-model_config=args["model_config"],
-code_language=args["language"],
+instruction=args.instruction,
+model_config=args.model_config_data,
+code_language=args.language,
)
case _:
return {"error": f"invalid node type: {node_type}"}
-if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
+if args.node_id == "" and args.current != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id,
-flow_id=args["flow_id"],
-current=args["current"],
-instruction=args["instruction"],
-model_config=args["model_config"],
-ideal_output=args["ideal_output"],
+flow_id=args.flow_id,
+current=args.current,
+instruction=args.instruction,
+model_config=args.model_config_data,
+ideal_output=args.ideal_output,
)
-if args["node_id"] != "" and args["current"] != "": # For workflow node
+if args.node_id != "" and args.current != "": # For workflow node
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id,
-flow_id=args["flow_id"],
-node_id=args["node_id"],
-current=args["current"],
-instruction=args["instruction"],
-model_config=args["model_config"],
-ideal_output=args["ideal_output"],
+flow_id=args.flow_id,
+node_id=args.node_id,
+current=args.current,
+instruction=args.instruction,
+model_config=args.model_config_data,
+ideal_output=args.ideal_output,
workflow_service=WorkflowService(),
)
return {"error": "incompatible parameters"}, 400
@@ -285,24 +259,15 @@ class InstructionGenerateApi(Resource):
class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template")
@console_ns.doc(description="Get instruction generation template")
-@console_ns.expect(
-console_ns.model(
-"InstructionTemplateRequest",
-{
-"instruction": fields.String(required=True, description="Template instruction"),
-"ideal_output": fields.String(description="Expected ideal output"),
-},
-)
-)
+@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
@console_ns.response(200, "Template retrieved successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@account_initialization_required
def post(self):
-parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
-args = parser.parse_args()
-match args["type"]:
+args = InstructionTemplatePayload.model_validate(console_ns.payload)
+match args.type:
case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
@@ -312,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
-raise ValueError(f"Invalid type: {args['type']}")
+raise ValueError(f"Invalid type: {args.type}")
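A note on the `model_config_data` fields above: `model_config` is a reserved attribute on Pydantic v2 models (it holds the model's own configuration), so the request key is stored under a different name and mapped back with `alias="model_config"`. A sketch with a stand-in class; the `ConfigDict(protected_namespaces=())` line is one way to silence Pydantic's `model_`-prefix warning and is an assumption on my part, not something these hunks show:

```python
from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class RulePayload(BaseModel):  # hypothetical stand-in
    # Without this, Pydantic warns that model_config_data touches the
    # protected "model_" namespace.
    model_config = ConfigDict(protected_namespaces=())

    instruction: str
    model_config_data: dict[str, Any] = Field(..., alias="model_config")
    no_variable: bool = False


# Validation accepts the wire name ("model_config"); the attribute keeps
# the non-reserved name (model_config_data).
p = RulePayload.model_validate({"instruction": "x", "model_config": {"model": "gpt-4o"}})
assert p.model_config_data == {"model": "gpt-4o"}
```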
View File
@@ -1,7 +1,8 @@
import json
from enum import StrEnum
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@@ -12,6 +13,8 @@ from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
@@ -21,6 +24,22 @@ class AppMCPServerStatus(StrEnum):
INACTIVE = "inactive"
+class MCPServerCreatePayload(BaseModel):
+description: str | None = Field(default=None, description="Server description")
+parameters: dict = Field(..., description="Server parameters configuration")
+class MCPServerUpdatePayload(BaseModel):
+id: str = Field(..., description="Server ID")
+description: str | None = Field(default=None, description="Server description")
+parameters: dict = Field(..., description="Server parameters configuration")
+status: str | None = Field(default=None, description="Server status")
+for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
+console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@@ -39,15 +58,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.model(
-"MCPServerCreateRequest",
-{
-"description": fields.String(description="Server description"),
-"parameters": fields.Raw(required=True, description="Server parameters configuration"),
-},
-)
-)
+@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@@ -58,21 +69,16 @@ class AppMCPServerController(Resource):
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
-parser = (
-reqparse.RequestParser()
-.add_argument("description", type=str, required=False, location="json")
-.add_argument("parameters", type=dict, required=True, location="json")
-)
-args = parser.parse_args()
-description = args.get("description")
+payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
+description = payload.description
if not description:
description = app_model.description or ""
server = AppMCPServer(
name=app_model.name,
description=description,
-parameters=json.dumps(args["parameters"], ensure_ascii=False),
+parameters=json.dumps(payload.parameters, ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_tenant_id,
@@ -85,17 +91,7 @@ class AppMCPServerController(Resource):
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.model(
-"MCPServerUpdateRequest",
-{
-"id": fields.String(required=True, description="Server ID"),
-"description": fields.String(description="Server description"),
-"parameters": fields.Raw(required=True, description="Server parameters configuration"),
-"status": fields.String(description="Server status"),
-},
-)
-)
+@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@@ -106,19 +102,12 @@ class AppMCPServerController(Resource):
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
-parser = (
-reqparse.RequestParser()
-.add_argument("id", type=str, required=True, location="json")
-.add_argument("description", type=str, required=False, location="json")
-.add_argument("parameters", type=dict, required=True, location="json")
-.add_argument("status", type=str, required=False, location="json")
-)
-args = parser.parse_args()
-server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
+payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
+server = db.session.query(AppMCPServer).where(AppMCPServer.id == payload.id).first()
if not server:
raise NotFound()
-description = args.get("description")
+description = payload.description
if description is None:
pass
elif not description:
@@ -126,11 +115,11 @@ class AppMCPServerController(Resource):
else:
server.description = description
-server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
-if args["status"]:
-if args["status"] not in [status.value for status in AppMCPServerStatus]:
+server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
+if payload.status:
+if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
-server.status = args["status"]
+server.status = payload.status
db.session.commit()
return server
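The `put` handler above still checks `payload.status` against `AppMCPServerStatus` by hand. An alternative (not what this diff does) would be typing the field with the `StrEnum` itself so Pydantic rejects bad values before the handler runs; a sketch with stand-in classes:

```python
from enum import StrEnum  # Python 3.11+

from pydantic import BaseModel, ValidationError


class ServerStatus(StrEnum):  # stand-in for AppMCPServerStatus
    ACTIVE = "active"
    INACTIVE = "inactive"


class UpdatePayload(BaseModel):  # stand-in for MCPServerUpdatePayload
    id: str
    status: ServerStatus | None = None


ok = UpdatePayload.model_validate({"id": "1", "status": "active"})
assert ok.status is ServerStatus.ACTIVE

try:
    UpdatePayload.model_validate({"id": "1", "status": "paused"})
except ValidationError:
    print("rejected before reaching the handler")
```

The trade-off is error shape: the manual check raises `ValueError("Invalid status")` inside the handler, while the typed field fails during validation with a Pydantic error.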
View File
@@ -1,7 +1,9 @@
import logging
+from typing import Literal
-from flask_restx import Resource, fields, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
@@ -33,6 +35,68 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+class ChatMessagesQuery(BaseModel):
+conversation_id: str = Field(..., description="Conversation ID")
+first_id: str | None = Field(default=None, description="First message ID for pagination")
+limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
+@field_validator("first_id", mode="before")
+@classmethod
+def empty_to_none(cls, value: str | None) -> str | None:
+if value == "":
+return None
+return value
+@field_validator("conversation_id", "first_id")
+@classmethod
+def validate_uuid(cls, value: str | None) -> str | None:
+if value is None:
+return value
+return uuid_value(value)
+class MessageFeedbackPayload(BaseModel):
+message_id: str = Field(..., description="Message ID")
+rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
+content: str | None = Field(default=None, description="Feedback content")
+@field_validator("message_id")
+@classmethod
+def validate_message_id(cls, value: str) -> str:
+return uuid_value(value)
+class FeedbackExportQuery(BaseModel):
+from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
+rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
+has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
+start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
+end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
+format: Literal["csv", "json"] = Field(default="csv", description="Export format")
+@field_validator("has_comment", mode="before")
+@classmethod
+def parse_bool(cls, value: bool | str | None) -> bool | None:
+if isinstance(value, bool) or value is None:
+return value
+lowered = value.lower()
+if lowered in {"true", "1", "yes", "on"}:
+return True
+if lowered in {"false", "0", "no", "off"}:
+return False
+raise ValueError("has_comment must be a boolean value")
+def reg(cls: type[BaseModel]):
+console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+reg(ChatMessagesQuery)
+reg(MessageFeedbackPayload)
+reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -157,12 +221,7 @@ class ChatMessageListApi(Resource):
@console_ns.doc("list_chat_messages")
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.parser()
-.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
-.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
-.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
-)
+@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found")
@login_required
@@ -172,27 +231,21 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
-parser = (
-reqparse.RequestParser()
-.add_argument("conversation_id", required=True, type=uuid_value, location="args")
-.add_argument("first_id", type=uuid_value, location="args")
-.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
-)
-args = parser.parse_args()
+args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
conversation = (
db.session.query(Conversation)
-.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
+.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
-if args["first_id"]:
+if args.first_id:
first_message = (
db.session.query(Message)
-.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
+.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
)
@@ -207,7 +260,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
-.limit(args["limit"])
+.limit(args.limit)
.all()
)
else:
@@ -215,12 +268,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
-.limit(args["limit"])
+.limit(args.limit)
.all()
)
# Initialize has_more based on whether we have a full page
-if len(history_messages) == args["limit"]:
+if len(history_messages) == args.limit:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
@@ -238,7 +291,7 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
-return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
+return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@@ -246,15 +299,7 @@ class MessageFeedbackApi(Resource):
@console_ns.doc("create_message_feedback")
@console_ns.doc(description="Create or update message feedback (like/dislike)")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.model(
-"MessageFeedbackRequest",
-{
-"message_id": fields.String(required=True, description="Message ID"),
-"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
-},
-)
-)
+@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.response(200, "Feedback updated successfully")
@console_ns.response(404, "Message not found")
@console_ns.response(403, "Insufficient permissions")
@@ -265,14 +310,9 @@ class MessageFeedbackApi(Resource):
def post(self, app_model):
current_user, _ = current_account_with_tenant()
-parser = (
-reqparse.RequestParser()
-.add_argument("message_id", required=True, type=uuid_value, location="json")
-.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
-)
-args = parser.parse_args()
+args = MessageFeedbackPayload.model_validate(console_ns.payload)
-message_id = str(args["message_id"])
+message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@@ -281,18 +321,23 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback
-if not args["rating"] and feedback:
+if not args.rating and feedback:
db.session.delete(feedback)
-elif args["rating"] and feedback:
-feedback.rating = args["rating"]
-elif not args["rating"] and not feedback:
+elif args.rating and feedback:
+feedback.rating = args.rating
+feedback.content = args.content
+elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
+rating_value = args.rating
+if rating_value is None:
+raise ValueError("rating is required to create feedback")
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
-rating=args["rating"],
+rating=rating_value,
+content=args.content,
from_source="admin",
from_account_id=current_user.id,
)
@@ -369,24 +414,12 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
-# Shared parser for feedback export (used for both documentation and runtime parsing)
-feedback_export_parser = (
-console_ns.parser()
-.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
-.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
-.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
-.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
-.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
-.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
-)
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(feedback_export_parser)
+@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@@ -395,7 +428,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
-args = feedback_export_parser.parse_args()
+args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Import the service function
from services.feedback_service import FeedbackService
@@ -403,12 +436,12 @@ class MessageFeedbackExportApi(Resource):
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
-from_source=args.get("from_source"),
-rating=args.get("rating"),
-has_comment=args.get("has_comment"),
-start_date=args.get("start_date"),
-end_date=args.get("end_date"),
-format_type=args.get("format", "csv"),
+from_source=args.from_source,
+rating=args.rating,
+has_comment=args.has_comment,
+start_date=args.start_date,
+end_date=args.end_date,
+format_type=args.format,
)
return export_data
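The `mode="before"` validators in this file do the coercion that reqparse's typed arguments used to do for query strings: everything arrives as text, so empty strings become `None` and boolean words are parsed explicitly. A trimmed, runnable sketch combining the two validators shown above (the model is a stand-in, not the full `FeedbackExportQuery`):

```python
from pydantic import BaseModel, field_validator


class QuerySketch(BaseModel):  # stand-in combining the validators above
    first_id: str | None = None
    has_comment: bool | None = None

    @field_validator("first_id", mode="before")
    @classmethod
    def empty_to_none(cls, value: str | None) -> str | None:
        # "?first_id=" arrives as "", which should mean "not provided".
        return None if value == "" else value

    @field_validator("has_comment", mode="before")
    @classmethod
    def parse_bool(cls, value: bool | str | None) -> bool | None:
        if isinstance(value, bool) or value is None:
            return value
        lowered = value.lower()
        if lowered in {"true", "1", "yes", "on"}:
            return True
        if lowered in {"false", "0", "no", "off"}:
            return False
        raise ValueError("has_comment must be a boolean value")


args = QuerySketch.model_validate({"first_id": "", "has_comment": "yes"})
assert args.first_id is None and args.has_comment is True
```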
View File
@@ -1,4 +1,8 @@
-from flask_restx import Resource, fields, reqparse
+from typing import Any
+from flask import request
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns
@@ -7,6 +11,26 @@ from controllers.console.wraps import account_initialization_required, setup_req
from libs.login import login_required
from services.ops_service import OpsService
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+class TraceProviderQuery(BaseModel):
+tracing_provider: str = Field(..., description="Tracing provider name")
+class TraceConfigPayload(BaseModel):
+tracing_provider: str = Field(..., description="Tracing provider name")
+tracing_config: dict[str, Any] = Field(..., description="Tracing configuration data")
+console_ns.schema_model(
+TraceProviderQuery.__name__,
+TraceProviderQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+TraceConfigPayload.__name__, TraceConfigPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
@console_ns.route("/apps/<uuid:app_id>/trace-config")
class TraceAppConfigApi(Resource):
@@ -17,11 +41,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("get_trace_app_config")
@console_ns.doc(description="Get tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.parser().add_argument(
-"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
-)
-)
+@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
)
@@ -30,11 +50,10 @@ class TraceAppConfigApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id):
-parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
-args = parser.parse_args()
+args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
-trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
+trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not trace_config:
return {"has_not_configured": True}
return trace_config
@@ -44,15 +63,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("create_trace_app_config")
@console_ns.doc(description="Create a new tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.model(
-"TraceConfigCreateRequest",
-{
-"tracing_provider": fields.String(required=True, description="Tracing provider name"),
-"tracing_config": fields.Raw(required=True, description="Tracing configuration data"),
-},
-)
-)
+@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
)
@@ -62,16 +73,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def post(self, app_id):
"""Create a new trace app configuration"""
-parser = (
-reqparse.RequestParser()
-.add_argument("tracing_provider", type=str, required=True, location="json")
-.add_argument("tracing_config", type=dict, required=True, location="json")
-)
-args = parser.parse_args()
+args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.create_tracing_app_config(
-app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
+app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigIsExist()
@@ -84,15 +90,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("update_trace_app_config")
@console_ns.doc(description="Update an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.model(
-"TraceConfigUpdateRequest",
-{
-"tracing_provider": fields.String(required=True, description="Tracing provider name"),
-"tracing_config": fields.Raw(required=True, description="Updated tracing configuration data"),
-},
-)
-)
+@console_ns.expect(console_ns.models[TraceConfigPayload.__name__])
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@@ -100,16 +98,11 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def patch(self, app_id):
"""Update an existing trace app configuration"""
-parser = (
-reqparse.RequestParser()
-.add_argument("tracing_provider", type=str, required=True, location="json")
-.add_argument("tracing_config", type=dict, required=True, location="json")
-)
-args = parser.parse_args()
+args = TraceConfigPayload.model_validate(console_ns.payload)
try:
result = OpsService.update_tracing_app_config(
-app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"]
+app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
)
if not result:
raise TracingConfigNotExist()
@@ -120,11 +113,7 @@ class TraceAppConfigApi(Resource):
@console_ns.doc("delete_trace_app_config")
@console_ns.doc(description="Delete an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
-@console_ns.expect(
-console_ns.parser().add_argument(
-"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
-)
-)
+@console_ns.expect(console_ns.models[TraceProviderQuery.__name__])
@console_ns.response(204, "Tracing configuration deleted successfully")
@console_ns.response(400, "Invalid request parameters or configuration not found")
@setup_required
@@ -132,11 +121,10 @@ class TraceAppConfigApi(Resource):
@account_initialization_required
def delete(self, app_id):
"""Delete an existing trace app configuration"""
-parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
-args = parser.parse_args()
+args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
-result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
+result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204
View File
@@ -1,4 +1,7 @@
-from flask_restx import Resource, fields, marshal_with, reqparse
+from typing import Literal
+
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import NotFound

 from constants.languages import supported_language
@@ -16,69 +19,50 @@ from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import Site

+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class AppSiteUpdatePayload(BaseModel):
+    title: str | None = Field(default=None)
+    icon_type: str | None = Field(default=None)
+    icon: str | None = Field(default=None)
+    icon_background: str | None = Field(default=None)
+    description: str | None = Field(default=None)
+    default_language: str | None = Field(default=None)
+    chat_color_theme: str | None = Field(default=None)
+    chat_color_theme_inverted: bool | None = Field(default=None)
+    customize_domain: str | None = Field(default=None)
+    copyright: str | None = Field(default=None)
+    privacy_policy: str | None = Field(default=None)
+    custom_disclaimer: str | None = Field(default=None)
+    customize_token_strategy: Literal["must", "allow", "not_allow"] | None = Field(default=None)
+    prompt_public: bool | None = Field(default=None)
+    show_workflow_steps: bool | None = Field(default=None)
+    use_icon_as_answer_icon: bool | None = Field(default=None)
+
+    @field_validator("default_language")
+    @classmethod
+    def validate_language(cls, value: str | None) -> str | None:
+        if value is None:
+            return value
+        return supported_language(value)
+
+
+console_ns.schema_model(
+    AppSiteUpdatePayload.__name__,
+    AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+
 # Register model for flask_restx to avoid dict type issues in Swagger
 app_site_model = console_ns.model("AppSite", app_site_fields)

-def parse_app_site_args():
-    parser = (
-        reqparse.RequestParser()
-        .add_argument("title", type=str, required=False, location="json")
-        .add_argument("icon_type", type=str, required=False, location="json")
-        .add_argument("icon", type=str, required=False, location="json")
-        .add_argument("icon_background", type=str, required=False, location="json")
-        .add_argument("description", type=str, required=False, location="json")
-        .add_argument("default_language", type=supported_language, required=False, location="json")
-        .add_argument("chat_color_theme", type=str, required=False, location="json")
-        .add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
-        .add_argument("customize_domain", type=str, required=False, location="json")
-        .add_argument("copyright", type=str, required=False, location="json")
-        .add_argument("privacy_policy", type=str, required=False, location="json")
-        .add_argument("custom_disclaimer", type=str, required=False, location="json")
-        .add_argument(
-            "customize_token_strategy",
-            type=str,
-            choices=["must", "allow", "not_allow"],
-            required=False,
-            location="json",
-        )
-        .add_argument("prompt_public", type=bool, required=False, location="json")
-        .add_argument("show_workflow_steps", type=bool, required=False, location="json")
-        .add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
-    )
-    return parser.parse_args()
-
 @console_ns.route("/apps/<uuid:app_id>/site")
 class AppSite(Resource):
     @console_ns.doc("update_app_site")
     @console_ns.doc(description="Update application site configuration")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "AppSiteRequest",
-            {
-                "title": fields.String(description="Site title"),
-                "icon_type": fields.String(description="Icon type"),
-                "icon": fields.String(description="Icon"),
-                "icon_background": fields.String(description="Icon background color"),
-                "description": fields.String(description="Site description"),
-                "default_language": fields.String(description="Default language"),
-                "chat_color_theme": fields.String(description="Chat color theme"),
-                "chat_color_theme_inverted": fields.Boolean(description="Inverted chat color theme"),
-                "customize_domain": fields.String(description="Custom domain"),
-                "copyright": fields.String(description="Copyright text"),
-                "privacy_policy": fields.String(description="Privacy policy"),
-                "custom_disclaimer": fields.String(description="Custom disclaimer"),
-                "customize_token_strategy": fields.String(
-                    enum=["must", "allow", "not_allow"], description="Token strategy"
-                ),
-                "prompt_public": fields.Boolean(description="Make prompt public"),
-                "show_workflow_steps": fields.Boolean(description="Show workflow steps"),
-                "use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
     @console_ns.response(200, "Site configuration updated successfully", app_site_model)
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(404, "App not found")
@@ -89,7 +73,7 @@ class AppSite(Resource):
     @get_app_model
     @marshal_with(app_site_model)
     def post(self, app_model):
-        args = parse_app_site_args()
+        args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
         current_user, _ = current_account_with_tenant()
         site = db.session.query(Site).where(Site.app_id == app_model.id).first()
         if not site:
@@ -113,7 +97,7 @@ class AppSite(Resource):
             "show_workflow_steps",
             "use_icon_as_answer_icon",
         ]:
-            value = args.get(attr_name)
+            value = getattr(args, attr_name)
             if value is not None:
                 setattr(site, attr_name, value)
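
Because every field on `AppSiteUpdatePayload` defaults to `None`, the handler can keep its partial-update idiom: only fields the client actually sent overwrite the row. A reduced illustration with hypothetical names, not the commit's own classes:

```python
from pydantic import BaseModel


class SitePatch(BaseModel):
    title: str | None = None
    copyright: str | None = None


class Site:  # stand-in for the ORM row
    title = "Old title"
    copyright = "Old copyright"


site = Site()
patch = SitePatch.model_validate({"title": "New title"})

# Only non-None fields (i.e. fields actually provided) overwrite the row.
for attr_name in ["title", "copyright"]:
    value = getattr(patch, attr_name)
    if value is not None:
        setattr(site, attr_name, value)

assert site.title == "New title" and site.copyright == "Old copyright"
```

One limitation worth noting: with this idiom a client cannot explicitly set a field to null; Pydantic v2's `model_fields_set` would be the way to distinguish "omitted" from "sent as null" if that ever matters.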

View File

@ -1,8 +1,9 @@
from decimal import Decimal from decimal import Decimal
import sqlalchemy as sa import sqlalchemy as sa
from flask import abort, jsonify from flask import abort, jsonify, request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -10,21 +11,37 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import parse_time_range from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString, convert_datetime_to_date from libs.helper import convert_datetime_to_date
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import AppMode from models import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def empty_string_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
StatisticTimeRangeQuery.__name__,
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource): class DailyMessageStatistic(Resource):
@console_ns.doc("get_daily_message_statistics") @console_ns.doc("get_daily_message_statistics")
@console_ns.doc(description="Get daily message statistics for an application") @console_ns.doc(description="Get daily message statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
console_ns.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.response( @console_ns.response(
200, 200,
"Daily message statistics retrieved successfully", "Daily message statistics retrieved successfully",
@ -37,12 +54,7 @@ class DailyMessageStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
parser = ( args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -57,7 +69,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -81,19 +93,12 @@ WHERE
return jsonify({"data": response_data}) return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
@console_ns.doc("get_daily_conversation_statistics") @console_ns.doc("get_daily_conversation_statistics")
@console_ns.doc(description="Get daily conversation statistics for an application") @console_ns.doc(description="Get daily conversation statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Daily conversation statistics retrieved successfully", "Daily conversation statistics retrieved successfully",
@ -106,7 +111,7 @@ class DailyConversationStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -121,7 +126,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -149,7 +154,7 @@ class DailyTerminalsStatistic(Resource):
@console_ns.doc("get_daily_terminals_statistics") @console_ns.doc("get_daily_terminals_statistics")
@console_ns.doc(description="Get daily terminal/end-user statistics for an application") @console_ns.doc(description="Get daily terminal/end-user statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Daily terminal statistics retrieved successfully", "Daily terminal statistics retrieved successfully",
@ -162,7 +167,7 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -177,7 +182,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -206,7 +211,7 @@ class DailyTokenCostStatistic(Resource):
@console_ns.doc("get_daily_token_cost_statistics") @console_ns.doc("get_daily_token_cost_statistics")
@console_ns.doc(description="Get daily token cost statistics for an application") @console_ns.doc(description="Get daily token cost statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Daily token cost statistics retrieved successfully", "Daily token cost statistics retrieved successfully",
@ -219,7 +224,7 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -235,7 +240,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -266,7 +271,7 @@ class AverageSessionInteractionStatistic(Resource):
@console_ns.doc("get_average_session_interaction_statistics") @console_ns.doc("get_average_session_interaction_statistics")
@console_ns.doc(description="Get average session interaction statistics for an application") @console_ns.doc(description="Get average session interaction statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Average session interaction statistics retrieved successfully", "Average session interaction statistics retrieved successfully",
@ -279,7 +284,7 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("c.created_at") converted_created_at = convert_datetime_to_date("c.created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -302,7 +307,7 @@ FROM
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -342,7 +347,7 @@ class UserSatisfactionRateStatistic(Resource):
@console_ns.doc("get_user_satisfaction_rate_statistics") @console_ns.doc("get_user_satisfaction_rate_statistics")
@console_ns.doc(description="Get user satisfaction rate statistics for an application") @console_ns.doc(description="Get user satisfaction rate statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"User satisfaction rate statistics retrieved successfully", "User satisfaction rate statistics retrieved successfully",
@ -355,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("m.created_at") converted_created_at = convert_datetime_to_date("m.created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -374,7 +379,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -408,7 +413,7 @@ class AverageResponseTimeStatistic(Resource):
@console_ns.doc("get_average_response_time_statistics") @console_ns.doc("get_average_response_time_statistics")
@console_ns.doc(description="Get average response time statistics for an application") @console_ns.doc(description="Get average response time statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Average response time statistics retrieved successfully", "Average response time statistics retrieved successfully",
@ -421,7 +426,7 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -436,7 +441,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -465,7 +470,7 @@ class TokensPerSecondStatistic(Resource):
@console_ns.doc("get_tokens_per_second_statistics") @console_ns.doc("get_tokens_per_second_statistics")
@console_ns.doc(description="Get tokens per second statistics for an application") @console_ns.doc(description="Get tokens per second statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
@console_ns.response( @console_ns.response(
200, 200,
"Tokens per second statistics retrieved successfully", "Tokens per second statistics retrieved successfully",
@ -477,7 +482,7 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = parser.parse_args() args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at") converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT sql_query = f"""SELECT
@ -495,7 +500,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))

View File

@@ -1,10 +1,11 @@
 import json
 import logging
 from collections.abc import Sequence
-from typing import cast
+from typing import Any

 from flask import abort, request
-from flask_restx import Resource, fields, inputs, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -49,6 +50,7 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
 logger = logging.getLogger(__name__)

 LISTENING_RETRY_IN = 2000
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

 # Register models for flask_restx to avoid dict type issues in Swagger
 # Register in dependency order: base models first, then dependent models
@@ -107,6 +109,104 @@ if workflow_run_node_execution_model is None:
     workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)

+
+class SyncDraftWorkflowPayload(BaseModel):
+    graph: dict[str, Any]
+    features: dict[str, Any]
+    hash: str | None = None
+    environment_variables: list[dict[str, Any]] = Field(default_factory=list)
+    conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
+
+
+class BaseWorkflowRunPayload(BaseModel):
+    files: list[dict[str, Any]] | None = None
+
+
+class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
+    inputs: dict[str, Any] | None = None
+    query: str = ""
+    conversation_id: str | None = None
+    parent_message_id: str | None = None
+
+    @field_validator("conversation_id", "parent_message_id")
+    @classmethod
+    def validate_uuid(cls, value: str | None) -> str | None:
+        if value is None:
+            return value
+        return uuid_value(value)
+
+
+class IterationNodeRunPayload(BaseModel):
+    inputs: dict[str, Any] | None = None
+
+
+class LoopNodeRunPayload(BaseModel):
+    inputs: dict[str, Any] | None = None
+
+
+class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
+    inputs: dict[str, Any]
+
+
+class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
+    inputs: dict[str, Any]
+    query: str = ""
+
+
+class PublishWorkflowPayload(BaseModel):
+    marked_name: str | None = Field(default=None, max_length=20)
+    marked_comment: str | None = Field(default=None, max_length=100)
+
+
+class DefaultBlockConfigQuery(BaseModel):
+    q: str | None = None
+
+
+class ConvertToWorkflowPayload(BaseModel):
+    name: str | None = None
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+
+
+class WorkflowListQuery(BaseModel):
+    page: int = Field(default=1, ge=1, le=99999)
+    limit: int = Field(default=10, ge=1, le=100)
+    user_id: str | None = None
+    named_only: bool = False
+
+
+class WorkflowUpdatePayload(BaseModel):
+    marked_name: str | None = Field(default=None, max_length=20)
+    marked_comment: str | None = Field(default=None, max_length=100)
+
+
+class DraftWorkflowTriggerRunPayload(BaseModel):
+    node_id: str
+
+
+class DraftWorkflowTriggerRunAllPayload(BaseModel):
+    node_ids: list[str]
+
+
+def reg(cls: type[BaseModel]):
+    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(SyncDraftWorkflowPayload)
+reg(AdvancedChatWorkflowRunPayload)
+reg(IterationNodeRunPayload)
+reg(LoopNodeRunPayload)
+reg(DraftWorkflowRunPayload)
+reg(DraftWorkflowNodeRunPayload)
+reg(PublishWorkflowPayload)
+reg(DefaultBlockConfigQuery)
+reg(ConvertToWorkflowPayload)
+reg(WorkflowListQuery)
+reg(WorkflowUpdatePayload)
+reg(DraftWorkflowTriggerRunPayload)
+reg(DraftWorkflowTriggerRunAllPayload)
+
 # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
 # at the controller level rather than in the workflow logic. This would improve separation
 # of concerns and make the code more maintainable.
@@ -158,18 +258,7 @@ class DraftWorkflowApi(Resource):
     @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
     @console_ns.doc("sync_draft_workflow")
     @console_ns.doc(description="Sync draft workflow configuration")
-    @console_ns.expect(
-        console_ns.model(
-            "SyncDraftWorkflowRequest",
-            {
-                "graph": fields.Raw(required=True, description="Workflow graph configuration"),
-                "features": fields.Raw(required=True, description="Workflow features configuration"),
-                "hash": fields.String(description="Workflow hash for validation"),
-                "environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
-                "conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
     @console_ns.response(
         200,
         "Draft workflow synced successfully",
@@ -193,36 +282,23 @@ class DraftWorkflowApi(Resource):
         content_type = request.headers.get("Content-Type", "")

+        payload_data: dict[str, Any] | None = None
         if "application/json" in content_type:
-            parser = (
-                reqparse.RequestParser()
-                .add_argument("graph", type=dict, required=True, nullable=False, location="json")
-                .add_argument("features", type=dict, required=True, nullable=False, location="json")
-                .add_argument("hash", type=str, required=False, location="json")
-                .add_argument("environment_variables", type=list, required=True, location="json")
-                .add_argument("conversation_variables", type=list, required=False, location="json")
-            )
-            args = parser.parse_args()
+            payload_data = request.get_json(silent=True)
+            if not isinstance(payload_data, dict):
+                return {"message": "Invalid JSON data"}, 400
         elif "text/plain" in content_type:
             try:
-                data = json.loads(request.data.decode("utf-8"))
-                if "graph" not in data or "features" not in data:
-                    raise ValueError("graph or features not found in data")
-                if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
-                    raise ValueError("graph or features is not a dict")
-                args = {
-                    "graph": data.get("graph"),
-                    "features": data.get("features"),
-                    "hash": data.get("hash"),
-                    "environment_variables": data.get("environment_variables"),
-                    "conversation_variables": data.get("conversation_variables"),
-                }
+                payload_data = json.loads(request.data.decode("utf-8"))
             except json.JSONDecodeError:
                 return {"message": "Invalid JSON data"}, 400
+            if not isinstance(payload_data, dict):
+                return {"message": "Invalid JSON data"}, 400
         else:
             abort(415)

+        args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
+        args = args_model.model_dump()
+
         workflow_service = WorkflowService()

         try:
@@ -258,17 +334,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
     @console_ns.doc("run_advanced_chat_draft_workflow")
     @console_ns.doc(description="Run draft workflow for advanced chat application")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "AdvancedChatWorkflowRunRequest",
-            {
-                "query": fields.String(required=True, description="User query"),
-                "inputs": fields.Raw(description="Input variables"),
-                "files": fields.List(fields.Raw, description="File uploads"),
-                "conversation_id": fields.String(description="Conversation ID"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__])
     @console_ns.response(200, "Workflow run started successfully")
     @console_ns.response(400, "Invalid request parameters")
     @console_ns.response(403, "Permission denied")
@@ -283,16 +349,8 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
         """
         current_user, _ = current_account_with_tenant()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, location="json")
-            .add_argument("query", type=str, required=True, location="json", default="")
-            .add_argument("files", type=list, location="json")
-            .add_argument("conversation_id", type=uuid_value, location="json")
-            .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
-        )
-        args = parser.parse_args()
+        args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
+        args = args_model.model_dump(exclude_none=True)

         external_trace_id = get_external_trace_id(request)
         if external_trace_id:
@@ -322,15 +380,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
     @console_ns.doc("run_advanced_chat_draft_iteration_node")
     @console_ns.doc(description="Run draft workflow iteration node for advanced chat")
     @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "IterationNodeRunRequest",
-            {
-                "task_id": fields.String(required=True, description="Task ID"),
-                "inputs": fields.Raw(description="Input variables"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
     @console_ns.response(200, "Iteration node run started successfully")
     @console_ns.response(403, "Permission denied")
     @console_ns.response(404, "Node not found")
@@ -344,8 +394,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
         Run draft workflow iteration node
         """
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
-        args = parser.parse_args()
+        args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)

         try:
             response = AppGenerateService.generate_single_iteration(
@@ -369,15 +418,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
     @console_ns.doc("run_workflow_draft_iteration_node")
     @console_ns.doc(description="Run draft workflow iteration node")
     @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "WorkflowIterationNodeRunRequest",
-            {
-                "task_id": fields.String(required=True, description="Task ID"),
-                "inputs": fields.Raw(description="Input variables"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__])
     @console_ns.response(200, "Workflow iteration node run started successfully")
     @console_ns.response(403, "Permission denied")
     @console_ns.response(404, "Node not found")
@@ -391,8 +432,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
         Run draft workflow iteration node
         """
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
-        args = parser.parse_args()
+        args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)

         try:
             response = AppGenerateService.generate_single_iteration(
@@ -416,15 +456,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
     @console_ns.doc("run_advanced_chat_draft_loop_node")
     @console_ns.doc(description="Run draft workflow loop node for advanced chat")
     @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "LoopNodeRunRequest",
-            {
-                "task_id": fields.String(required=True, description="Task ID"),
-                "inputs": fields.Raw(description="Input variables"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
     @console_ns.response(200, "Loop node run started successfully")
     @console_ns.response(403, "Permission denied")
     @console_ns.response(404, "Node not found")
@@ -438,8 +470,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
         Run draft workflow loop node
         """
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
-        args = parser.parse_args()
+        args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)

         try:
             response = AppGenerateService.generate_single_loop(
@@ -463,15 +494,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
     @console_ns.doc("run_workflow_draft_loop_node")
     @console_ns.doc(description="Run draft workflow loop node")
     @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "WorkflowLoopNodeRunRequest",
-            {
-                "task_id": fields.String(required=True, description="Task ID"),
-                "inputs": fields.Raw(description="Input variables"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__])
     @console_ns.response(200, "Workflow loop node run started successfully")
     @console_ns.response(403, "Permission denied")
     @console_ns.response(404, "Node not found")
@@ -485,8 +508,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
         Run draft workflow loop node
         """
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
-        args = parser.parse_args()
+        args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)

         try:
             response = AppGenerateService.generate_single_loop(
@@ -510,15 +532,7 @@ class DraftWorkflowRunApi(Resource):
     @console_ns.doc("run_draft_workflow")
     @console_ns.doc(description="Run draft workflow")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "DraftWorkflowRunRequest",
-            {
-                "inputs": fields.Raw(required=True, description="Input variables"),
-                "files": fields.List(fields.Raw, description="File uploads"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
     @console_ns.response(200, "Draft workflow run started successfully")
     @console_ns.response(403, "Permission denied")
     @setup_required
@@ -531,12 +545,7 @@ class DraftWorkflowRunApi(Resource):
         Run draft workflow
         """
         current_user, _ = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-            .add_argument("files", type=list, required=False, location="json")
-        )
-        args = parser.parse_args()
+        args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)

         external_trace_id = get_external_trace_id(request)
         if external_trace_id:
@@ -588,14 +597,7 @@ class DraftWorkflowNodeRunApi(Resource):
     @console_ns.doc("run_draft_workflow_node")
     @console_ns.doc(description="Run draft workflow node")
     @console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "DraftWorkflowNodeRunRequest",
-            {
-                "inputs": fields.Raw(description="Input variables"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__])
     @console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model)
     @console_ns.response(403, "Permission denied")
     @console_ns.response(404, "Node not found")
@@ -610,15 +612,10 @@ class DraftWorkflowNodeRunApi(Resource):
         Run draft workflow node
         """
         current_user, _ = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-            .add_argument("query", type=str, required=False, location="json", default="")
-            .add_argument("files", type=list, location="json", default=[])
-        )
-        args = parser.parse_args()
+        args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
+        args = args_model.model_dump(exclude_none=True)

-        user_inputs = args.get("inputs")
+        user_inputs = args_model.inputs
         if user_inputs is None:
             raise ValueError("missing inputs")
@@ -643,13 +640,6 @@ class DraftWorkflowNodeRunApi(Resource):
         return workflow_node_execution

-parser_publish = (
-    reqparse.RequestParser()
-    .add_argument("marked_name", type=str, required=False, default="", location="json")
-    .add_argument("marked_comment", type=str, required=False, default="", location="json")
-)
-
 @console_ns.route("/apps/<uuid:app_id>/workflows/publish")
 class PublishedWorkflowApi(Resource):
     @console_ns.doc("get_published_workflow")
@@ -674,7 +664,7 @@ class PublishedWorkflowApi(Resource):
         # return workflow, if not found, return None
         return workflow

-    @console_ns.expect(parser_publish)
+    @console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -686,13 +676,7 @@ class PublishedWorkflowApi(Resource):
         """
         current_user, _ = current_account_with_tenant()

-        args = parser_publish.parse_args()
-
-        # Validate name and comment length
-        if args.marked_name and len(args.marked_name) > 20:
-            raise ValueError("Marked name cannot exceed 20 characters")
-        if args.marked_comment and len(args.marked_comment) > 100:
-            raise ValueError("Marked comment cannot exceed 100 characters")
+        args = PublishWorkflowPayload.model_validate(console_ns.payload or {})

         workflow_service = WorkflowService()
         with Session(db.engine) as session:
@@ -741,9 +725,6 @@ class DefaultBlockConfigsApi(Resource):
         return workflow_service.get_default_block_configs()

-parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
-
 @console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
 class DefaultBlockConfigApi(Resource):
     @console_ns.doc("get_default_block_config")
@@ -751,7 +732,7 @@ class DefaultBlockConfigApi(Resource):
     @console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"})
     @console_ns.response(200, "Default block configuration retrieved successfully")
     @console_ns.response(404, "Block type not found")
-    @console_ns.expect(parser_block)
+    @console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -761,14 +742,12 @@ class DefaultBlockConfigApi(Resource):
         """
         Get default block config
         """
-        args = parser_block.parse_args()
-
-        q = args.get("q")
+        args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         filters = None
-        if q:
+        if args.q:
             try:
-                filters = json.loads(args.get("q", ""))
+                filters = json.loads(args.q)
             except json.JSONDecodeError:
                 raise ValueError("Invalid filters")
@@ -777,18 +756,9 @@ class DefaultBlockConfigApi(Resource):
         return workflow_service.get_default_block_config(node_type=block_type, filters=filters)

-parser_convert = (
-    reqparse.RequestParser()
-    .add_argument("name", type=str, required=False, nullable=True, location="json")
-    .add_argument("icon_type", type=str, required=False, nullable=True, location="json")
-    .add_argument("icon", type=str, required=False, nullable=True, location="json")
-    .add_argument("icon_background", type=str, required=False, nullable=True, location="json")
-)
-
 @console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
 class ConvertToWorkflowApi(Resource):
-    @console_ns.expect(parser_convert)
+    @console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__])
     @console_ns.doc("convert_to_workflow")
     @console_ns.doc(description="Convert application to workflow mode")
     @console_ns.doc(params={"app_id": "Application ID"})
@@ -808,10 +778,8 @@ class ConvertToWorkflowApi(Resource):
         """
         current_user, _ = current_account_with_tenant()

-        if request.data:
-            args = parser_convert.parse_args()
-        else:
-            args = {}
+        payload = console_ns.payload or {}
+        args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)

         # convert to workflow mode
         workflow_service = WorkflowService()
@@ -823,18 +791,9 @@ class ConvertToWorkflowApi(Resource):
         }

-parser_workflows = (
-    reqparse.RequestParser()
-    .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
-    .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
-    .add_argument("user_id", type=str, required=False, location="args")
-    .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
-)
-
 @console_ns.route("/apps/<uuid:app_id>/workflows")
 class PublishedAllWorkflowApi(Resource):
-    @console_ns.expect(parser_workflows)
+    @console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
     @console_ns.doc("get_all_published_workflows")
     @console_ns.doc(description="Get all published workflows for an application")
     @console_ns.doc(params={"app_id": "Application ID"})
@@ -851,16 +810,15 @@ class PublishedAllWorkflowApi(Resource):
         """
         current_user, _ = current_account_with_tenant()

-        args = parser_workflows.parse_args()
-        page = args["page"]
-        limit = args["limit"]
-        user_id = args.get("user_id")
-        named_only = args.get("named_only", False)
+        args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        page = args.page
+        limit = args.limit
+        user_id = args.user_id
+        named_only = args.named_only

         if user_id:
             if user_id != current_user.id:
                 raise Forbidden()
-            user_id = cast(str, user_id)

         workflow_service = WorkflowService()
         with Session(db.engine) as session:
@@ -886,15 +844,7 @@ class WorkflowByIdApi(Resource):
     @console_ns.doc("update_workflow_by_id")
     @console_ns.doc(description="Update workflow by ID")
     @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "UpdateWorkflowRequest",
-            {
-                "environment_variables": fields.List(fields.Raw, description="Environment variables"),
-                "conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
     @console_ns.response(200, "Workflow updated successfully", workflow_model)
     @console_ns.response(404, "Workflow not found")
     @console_ns.response(403, "Permission denied")
@@ -909,25 +859,14 @@ class WorkflowByIdApi(Resource):
         Update workflow attributes
         """
         current_user, _ = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("marked_name", type=str, required=False, location="json")
-            .add_argument("marked_comment", type=str, required=False, location="json")
-        )
-        args = parser.parse_args()
-
-        # Validate name and comment length
-        if args.marked_name and len(args.marked_name) > 20:
-            raise ValueError("Marked name cannot exceed 20 characters")
-        if args.marked_comment and len(args.marked_comment) > 100:
-            raise ValueError("Marked comment cannot exceed 100 characters")
+        args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})

         # Prepare update data
         update_data = {}
-        if args.get("marked_name") is not None:
-            update_data["marked_name"] = args["marked_name"]
-        if args.get("marked_comment") is not None:
-            update_data["marked_comment"] = args["marked_comment"]
+        if args.marked_name is not None:
+            update_data["marked_name"] = args.marked_name
+        if args.marked_comment is not None:
+            update_data["marked_comment"] = args.marked_comment

         if not update_data:
             return {"message": "No valid fields to update"}, 400
@@ -1040,11 +979,8 @@ class DraftWorkflowTriggerRunApi(Resource):
         Poll for trigger events and execute full workflow when event arrives
         """
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument(
-            "node_id", type=str, required=True, location="json", nullable=False
-        )
-        args = parser.parse_args()
-        node_id = args["node_id"]
+        args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
+        node_id = args.node_id

         workflow_service = WorkflowService()
         draft_workflow = workflow_service.get_draft_workflow(app_model)
         if not draft_workflow:
@@ -1172,14 +1108,7 @@ class DraftWorkflowTriggerRunAllApi(Resource):
     @console_ns.doc("draft_workflow_trigger_run_all")
     @console_ns.doc(description="Full workflow debug when the start node is a trigger")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.expect(
-        console_ns.model(
-            "DraftWorkflowTriggerRunAllRequest",
-            {
-                "node_ids": fields.List(fields.String, required=True, description="Node IDs"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__])
     @console_ns.response(200, "Workflow executed successfully")
     @console_ns.response(403, "Permission denied")
     @console_ns.response(500, "Internal server error")
@@ -1194,11 +1123,8 @@ class DraftWorkflowTriggerRunAllApi(Resource):
         """
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument(
-            "node_ids", type=list, required=True, location="json", nullable=False
-        )
-        args = parser.parse_args()
-        node_ids = args["node_ids"]
+        args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
+        node_ids = args.node_ids

         workflow_service = WorkflowService()
         draft_workflow = workflow_service.get_draft_workflow(app_model)
         if not draft_workflow:
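
One detail specific to this file: `DraftWorkflowApi.post` accepts both `application/json` and `text/plain` bodies, so the handler first normalizes either into one dict, then makes a single `model_validate` call. A reduced sketch of that control flow under the same assumptions (toy route and model, not the commit's):

```python
import json
from typing import Any

from flask import Flask, abort, request
from pydantic import BaseModel

app = Flask(__name__)


class DraftPayload(BaseModel):
    graph: dict[str, Any]
    features: dict[str, Any]
    hash: str | None = None


@app.post("/draft")
def sync_draft():
    content_type = request.headers.get("Content-Type", "")
    payload_data: dict[str, Any] | None = None
    if "application/json" in content_type:
        payload_data = request.get_json(silent=True)  # None on malformed JSON
    elif "text/plain" in content_type:
        try:
            payload_data = json.loads(request.data.decode("utf-8"))
        except json.JSONDecodeError:
            return {"message": "Invalid JSON data"}, 400
    else:
        abort(415)
    if not isinstance(payload_data, dict):
        return {"message": "Invalid JSON data"}, 400
    # Validation happens once, after both branches converge on a dict.
    args = DraftPayload.model_validate(payload_data)
    return {"hash": args.hash or "new"}
```

Note also how `Field(max_length=...)` on `PublishWorkflowPayload` and `WorkflowUpdatePayload` replaces the hand-rolled length checks the old handlers carried.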

View File

@ -1,6 +1,9 @@
from datetime import datetime
from dateutil.parser import isoparse from dateutil.parser import isoparse
-from flask_restx import Resource, marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import Session

 from controllers.console import console_ns
@@ -14,6 +17,48 @@ from models import App
 from models.model import AppMode
 from services.workflow_app_service import WorkflowAppService

+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkflowAppLogQuery(BaseModel):
+    keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
+    status: WorkflowExecutionStatus | None = Field(
+        default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
+    )
+    created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
+    created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
+    created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
+    created_by_account: str | None = Field(default=None, description="Filter by account")
+    detail: bool = Field(default=False, description="Whether to return detailed logs")
+    page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
+    limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
+
+    @field_validator("created_at__before", "created_at__after", mode="before")
+    @classmethod
+    def parse_datetime(cls, value: str | None) -> datetime | None:
+        if value in (None, ""):
+            return None
+        return isoparse(value)  # type: ignore
+
+    @field_validator("detail", mode="before")
+    @classmethod
+    def parse_bool(cls, value: bool | str | None) -> bool:
+        if isinstance(value, bool):
+            return value
+        if value is None:
+            return False
+        lowered = value.lower()
+        if lowered in {"1", "true", "yes", "on"}:
+            return True
+        if lowered in {"0", "false", "no", "off"}:
+            return False
+        raise ValueError("Invalid boolean value for detail")
+
+
+console_ns.schema_model(
+    WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
+)
+
 # Register model for flask_restx to avoid dict type issues in Swagger
 workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
@@ -23,19 +68,7 @@ class WorkflowAppLogApi(Resource):
     @console_ns.doc("get_workflow_app_logs")
     @console_ns.doc(description="Get workflow application execution logs")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.doc(
-        params={
-            "keyword": "Search keyword for filtering logs",
-            "status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
-            "created_at__before": "Filter logs created before this timestamp",
-            "created_at__after": "Filter logs created after this timestamp",
-            "created_by_end_user_session_id": "Filter by end user session ID",
-            "created_by_account": "Filter by account",
-            "detail": "Whether to return detailed logs",
-            "page": "Page number (1-99999)",
-            "limit": "Number of items per page (1-100)",
-        }
-    )
+    @console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
     @console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
     @setup_required
     @login_required
@@ -46,44 +79,7 @@ class WorkflowAppLogApi(Resource):
         """
         Get workflow app logs
         """
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("keyword", type=str, location="args")
-            .add_argument(
-                "status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
-            )
-            .add_argument(
-                "created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
-            )
-            .add_argument(
-                "created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
-            )
-            .add_argument(
-                "created_by_end_user_session_id",
-                type=str,
-                location="args",
-                required=False,
-                default=None,
-            )
-            .add_argument(
-                "created_by_account",
-                type=str,
-                location="args",
-                required=False,
-                default=None,
-            )
-            .add_argument("detail", type=bool, location="args", required=False, default=False)
-            .add_argument("page", type=int_range(1, 99999), default=1, location="args")
-            .add_argument("limit", type=int_range(1, 100), default=20, location="args")
-        )
-        args = parser.parse_args()
-        args.status = WorkflowExecutionStatus(args.status) if args.status else None
-        if args.created_at__before:
-            args.created_at__before = isoparse(args.created_at__before)
-        if args.created_at__after:
-            args.created_at__after = isoparse(args.created_at__after)
+        args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         # get paginate workflow app logs
         workflow_app_service = WorkflowAppService()
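The pattern in this hunk repeats throughout the commit: a `reqparse.RequestParser` chain plus manual post-parsing is collapsed into one Pydantic model validated against the flat query dict. A minimal standalone sketch of that round trip (the endpoint and model here are illustrative, not Dify's):

```python
from flask import Flask, request
from pydantic import BaseModel, Field, ValidationError, field_validator

app = Flask(__name__)


class LogQuery(BaseModel):
    # Query-string values arrive as text, so booleans need explicit coercion.
    keyword: str | None = Field(default=None)
    detail: bool = Field(default=False)
    page: int = Field(default=1, ge=1)

    @field_validator("detail", mode="before")
    @classmethod
    def parse_bool(cls, value: bool | str | None) -> bool:
        if isinstance(value, bool):
            return value
        return str(value).lower() in {"1", "true", "yes", "on"}


@app.get("/logs")
def list_logs():
    try:
        # flat=True keeps only the first value for repeated query keys.
        args = LogQuery.model_validate(request.args.to_dict(flat=True))
    except ValidationError as e:
        return {"errors": e.errors()}, 400
    return {"page": args.page, "detail": args.detail, "keyword": args.keyword}
```

One consequence worth noting: `ge`/`le` bounds and enum coercion now fail as a Pydantic `ValidationError` rather than the 400 that reqparse produced, so error handling for malformed queries shifts to wherever that exception is caught.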

View File

@@ -1,10 +1,11 @@
 import logging
 from collections.abc import Callable
 from functools import wraps
-from typing import NoReturn, ParamSpec, TypeVar
+from typing import Any, NoReturn, ParamSpec, TypeVar

-from flask import Response
-from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask import Response, request
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session

 from controllers.console import console_ns
@@ -29,6 +30,27 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
 from services.workflow_service import WorkflowService

 logger = logging.getLogger(__name__)

+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkflowDraftVariableListQuery(BaseModel):
+    page: int = Field(default=1, ge=1, le=100_000, description="Page number")
+    limit: int = Field(default=20, ge=1, le=100, description="Items per page")
+
+
+class WorkflowDraftVariableUpdatePayload(BaseModel):
+    name: str | None = Field(default=None, description="Variable name")
+    value: Any | None = Field(default=None, description="Variable value")
+
+
+console_ns.schema_model(
+    WorkflowDraftVariableListQuery.__name__,
+    WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+console_ns.schema_model(
+    WorkflowDraftVariableUpdatePayload.__name__,
+    WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)

 def _convert_values_to_json_serializable_object(value: Segment):
@@ -57,22 +79,6 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
     return _convert_values_to_json_serializable_object(value)


-def _create_pagination_parser():
-    parser = (
-        reqparse.RequestParser()
-        .add_argument(
-            "page",
-            type=inputs.int_range(1, 100_000),
-            required=False,
-            default=1,
-            location="args",
-            help="the page of data requested",
-        )
-        .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
-    )
-    return parser
-
-
 def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
     value_type = workflow_draft_var.value_type
     return value_type.exposed_type().value
@@ -201,7 +207,7 @@ def _api_prerequisite(f: Callable[P, R]):
 @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
 class WorkflowVariableCollectionApi(Resource):
-    @console_ns.expect(_create_pagination_parser())
+    @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
     @console_ns.doc("get_workflow_variables")
     @console_ns.doc(description="Get draft workflow variables")
     @console_ns.doc(params={"app_id": "Application ID"})
@@ -215,8 +221,7 @@ class WorkflowVariableCollectionApi(Resource):
         """
         Get draft workflow
         """
-        parser = _create_pagination_parser()
-        args = parser.parse_args()
+        args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         # fetch draft workflow by app_model
         workflow_service = WorkflowService()
@@ -323,15 +328,7 @@ class VariableApi(Resource):
     @console_ns.doc("update_variable")
     @console_ns.doc(description="Update a workflow variable")
-    @console_ns.expect(
-        console_ns.model(
-            "UpdateVariableRequest",
-            {
-                "name": fields.String(description="Variable name"),
-                "value": fields.Raw(description="Variable value"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
     @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
     @console_ns.response(404, "Variable not found")
     @_api_prerequisite
@@ -358,16 +355,10 @@ class VariableApi(Resource):
         # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
         # }
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
-            .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
-        )
         draft_var_srv = WorkflowDraftVariableService(
             session=db.session(),
         )
-        args = parser.parse_args(strict=True)
+        args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})

         variable = draft_var_srv.get_variable(variable_id=variable_id)
         if variable is None:
@@ -375,8 +366,8 @@ class VariableApi(Resource):
         if variable.app_id != app_model.id:
             raise NotFoundError(description=f"variable not found, id={variable_id}")

-        new_name = args.get(self._PATCH_NAME_FIELD, None)
-        raw_value = args.get(self._PATCH_VALUE_FIELD, None)
+        new_name = args_model.name
+        raw_value = args_model.value
         if new_name is None and raw_value is None:
             return variable
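`Namespace.schema_model` registers a raw JSON schema with flask-restx, and the Swagger 2.0 ref template keeps `$ref`s pointing at `#/definitions/...`, which is the layout the bundled Swagger UI expects. A hedged sketch of the registration step in isolation (namespace and model names are illustrative):

```python
from flask import Flask
from flask_restx import Api, Resource
from pydantic import BaseModel, Field

app = Flask(__name__)
api = Api(app)
ns = api.namespace("demo")

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class PageQuery(BaseModel):
    page: int = Field(default=1, ge=1, le=100_000)
    limit: int = Field(default=20, ge=1, le=100)


# Export the Pydantic schema with Swagger-2.0-style refs and register it
# under the class name, so @ns.expect can reference it by name later.
ns.schema_model(PageQuery.__name__, PageQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


@ns.route("/items")
class Items(Resource):
    @ns.expect(ns.models[PageQuery.__name__])
    def get(self):
        return {"ok": True}
```

Registering under `cls.__name__` is what lets the decorators use the `console_ns.models[Model.__name__]` lookup seen throughout this commit.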

View File

@ -1,7 +1,8 @@
from typing import cast from typing import Literal, cast
from flask_restx import Resource, fields, marshal_with, reqparse from flask import request
from flask_restx.inputs import int_range from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -92,70 +93,51 @@ workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy "WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
) )
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _parse_workflow_run_list_args():
"""
Parse common arguments for workflow run list endpoints.
Returns: class WorkflowRunListQuery(BaseModel):
Parsed arguments containing last_id, limit, status, and triggered_from filters last_id: str | None = Field(default=None, description="Last run ID for pagination")
""" limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
parser = ( status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
reqparse.RequestParser() default=None, description="Workflow run status filter"
.add_argument("last_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
) )
return parser.parse_args() triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
) )
return parser.parse_args()
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class WorkflowRunCountQuery(BaseModel):
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
@field_validator("time_range")
@classmethod
def validate_time_range(cls, value: str | None) -> str | None:
if value is None:
return value
return time_duration(value)
console_ns.schema_model(
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
WorkflowRunCountQuery.__name__,
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
@ -170,6 +152,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.doc( @console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
) )
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model) @console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
@setup_required @setup_required
@login_required @login_required
@ -180,12 +163,13 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
""" """
Get advanced chat app workflow run list Get advanced chat app workflow run list
""" """
args = _parse_workflow_run_list_args() args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified # Default to DEBUGGING if not specified
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from")) WorkflowRunTriggeredFrom(args_model.triggered_from)
if args.get("triggered_from") if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -217,6 +201,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
) )
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -226,12 +211,13 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
""" """
Get advanced chat workflow runs count statistics Get advanced chat workflow runs count statistics
""" """
args = _parse_workflow_run_count_args() args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified # Default to DEBUGGING if not specified
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from")) WorkflowRunTriggeredFrom(args_model.triggered_from)
if args.get("triggered_from") if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -259,6 +245,7 @@ class WorkflowRunListApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
) )
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model) @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -268,12 +255,13 @@ class WorkflowRunListApi(Resource):
""" """
Get workflow run list Get workflow run list
""" """
args = _parse_workflow_run_list_args() args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility) # Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from")) WorkflowRunTriggeredFrom(args_model.triggered_from)
if args.get("triggered_from") if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -305,6 +293,7 @@ class WorkflowRunCountApi(Resource):
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
) )
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model) @console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -314,12 +303,13 @@ class WorkflowRunCountApi(Resource):
""" """
Get workflow runs count statistics Get workflow runs count statistics
""" """
args = _parse_workflow_run_count_args() args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility) # Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args.get("triggered_from")) WorkflowRunTriggeredFrom(args_model.triggered_from)
if args.get("triggered_from") if args_model.triggered_from
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
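Two small details in this file generalize well: `typing.Literal` replaces reqparse's `choices=` so the enum check happens at validation time, and `model_dump(exclude_none=True)` reproduces the old dict-style `args` without `None` entries. A self-contained check of both behaviors (model and field names are illustrative):

```python
from typing import Literal

from pydantic import BaseModel, Field

Status = Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"]


class RunListQuery(BaseModel):
    last_id: str | None = Field(default=None)
    limit: int = Field(default=20, ge=1, le=100)
    status: Status | None = Field(default=None)


query = RunListQuery.model_validate({"limit": "50", "status": "failed"})
args = query.model_dump(exclude_none=True)
assert args == {"limit": 50, "status": "failed"}  # last_id dropped, limit coerced from str

# An unknown status value now fails at validation time, e.g.:
# RunListQuery.model_validate({"status": "paused"})  # raises ValidationError
```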

View File

@@ -1,5 +1,6 @@
-from flask import abort, jsonify
-from flask_restx import Resource, reqparse
+from flask import abort, jsonify, request
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy.orm import sessionmaker

 from controllers.console import console_ns
@@ -7,12 +8,31 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, setup_required
 from extensions.ext_database import db
 from libs.datetime_utils import parse_time_range
-from libs.helper import DatetimeString
 from libs.login import current_account_with_tenant, login_required
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from repositories.factory import DifyAPIRepositoryFactory

+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class WorkflowStatisticQuery(BaseModel):
+    start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
+    end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
+
+    @field_validator("start", "end", mode="before")
+    @classmethod
+    def blank_to_none(cls, value: str | None) -> str | None:
+        if value == "":
+            return None
+        return value
+
+
+console_ns.schema_model(
+    WorkflowStatisticQuery.__name__,
+    WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)
+

 @console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
 class WorkflowDailyRunsStatistic(Resource):
@@ -24,9 +44,7 @@ class WorkflowDailyRunsStatistic(Resource):
     @console_ns.doc("get_workflow_daily_runs_statistic")
     @console_ns.doc(description="Get workflow daily runs statistics")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.doc(
-        params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
-    )
+    @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
     @console_ns.response(200, "Daily runs statistics retrieved successfully")
     @get_app_model
     @setup_required
@@ -35,17 +53,12 @@ class WorkflowDailyRunsStatistic(Resource):
     def get(self, app_model):
         account, _ = current_account_with_tenant()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        )
-        args = parser.parse_args()
+        args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         assert account.timezone is not None

         try:
-            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+            start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
         except ValueError as e:
             abort(400, description=str(e))
@@ -71,9 +84,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
     @console_ns.doc("get_workflow_daily_terminals_statistic")
     @console_ns.doc(description="Get workflow daily terminals statistics")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.doc(
-        params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
-    )
+    @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
     @console_ns.response(200, "Daily terminals statistics retrieved successfully")
     @get_app_model
     @setup_required
@@ -82,17 +93,12 @@ class WorkflowDailyTerminalsStatistic(Resource):
     def get(self, app_model):
         account, _ = current_account_with_tenant()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        )
-        args = parser.parse_args()
+        args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         assert account.timezone is not None

         try:
-            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+            start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
         except ValueError as e:
             abort(400, description=str(e))
@@ -118,9 +124,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
     @console_ns.doc("get_workflow_daily_token_cost_statistic")
     @console_ns.doc(description="Get workflow daily token cost statistics")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.doc(
-        params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
-    )
+    @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
     @console_ns.response(200, "Daily token cost statistics retrieved successfully")
     @get_app_model
     @setup_required
@@ -129,17 +133,12 @@ class WorkflowDailyTokenCostStatistic(Resource):
     def get(self, app_model):
         account, _ = current_account_with_tenant()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        )
-        args = parser.parse_args()
+        args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         assert account.timezone is not None

         try:
-            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+            start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
         except ValueError as e:
             abort(400, description=str(e))
@@ -165,9 +164,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
     @console_ns.doc("get_workflow_average_app_interaction_statistic")
     @console_ns.doc(description="Get workflow average app interaction statistics")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.doc(
-        params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"}
-    )
+    @console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
     @console_ns.response(200, "Average app interaction statistics retrieved successfully")
     @setup_required
     @login_required
@@ -176,17 +173,12 @@ class WorkflowAverageAppInteractionStatistic(Resource):
     def get(self, app_model):
         account, _ = current_account_with_tenant()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-            .add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        )
-        args = parser.parse_args()
+        args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

         assert account.timezone is not None

         try:
-            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
+            start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
         except ValueError as e:
             abort(400, description=str(e))
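The `mode="before"` validator that maps empty strings to `None` matters here because query strings often arrive as `?start=&end=...` rather than omitting the key; without it, the downstream date parsing would receive `""`. A standalone check of that behavior (model name is illustrative):

```python
from pydantic import BaseModel, Field, field_validator


class StatQuery(BaseModel):
    start: str | None = Field(default=None)
    end: str | None = Field(default=None)

    @field_validator("start", "end", mode="before")
    @classmethod
    def blank_to_none(cls, value: str | None) -> str | None:
        # "?start=" style input: a blank value collapses to None.
        return None if value == "" else value


q = StatQuery.model_validate({"start": "", "end": "2024-01-02 00:00"})
assert q.start is None and q.end == "2024-01-02 00:00"
```

Note that the old `DatetimeString("%Y-%m-%d %H:%M")` format check is gone; the fields are now plain strings, so format errors surface from `parse_time_range` as the existing `ValueError` / 400 path.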

View File

@@ -114,7 +114,7 @@ class AppTriggersApi(Resource):

 @console_ns.route("/apps/<uuid:app_id>/trigger-enable")
 class AppTriggerEnableApi(Resource):
-    @console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
+    @console_ns.expect(console_ns.models[ParserEnable.__name__])
     @setup_required
     @login_required
     @account_initialization_required

View File

@@ -1,28 +1,53 @@
 from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator

 from constants.languages import supported_language
 from controllers.console import console_ns
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
-from libs.helper import StrLen, email, extract_remote_ip, timezone
+from libs.helper import EmailStr, timezone
 from models import AccountStatus
-from services.account_service import AccountService, RegisterService
+from services.account_service import RegisterService

-active_check_parser = (
-    reqparse.RequestParser()
-    .add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
-    .add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
-    .add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
-)
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ActivateCheckQuery(BaseModel):
+    workspace_id: str | None = Field(default=None)
+    email: EmailStr | None = Field(default=None)
+    token: str
+
+
+class ActivatePayload(BaseModel):
+    workspace_id: str | None = Field(default=None)
+    email: EmailStr | None = Field(default=None)
+    token: str
+    name: str = Field(..., max_length=30)
+    interface_language: str = Field(...)
+    timezone: str = Field(...)
+
+    @field_validator("interface_language")
+    @classmethod
+    def validate_lang(cls, value: str) -> str:
+        return supported_language(value)
+
+    @field_validator("timezone")
+    @classmethod
+    def validate_tz(cls, value: str) -> str:
+        return timezone(value)
+
+
+for model in (ActivateCheckQuery, ActivatePayload):
+    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


 @console_ns.route("/activate/check")
 class ActivateCheckApi(Resource):
     @console_ns.doc("check_activation_token")
     @console_ns.doc(description="Check if activation token is valid")
-    @console_ns.expect(active_check_parser)
+    @console_ns.expect(console_ns.models[ActivateCheckQuery.__name__])
     @console_ns.response(
         200,
         "Success",
@@ -35,11 +60,11 @@ class ActivateCheckApi(Resource):
         ),
     )
     def get(self):
-        args = active_check_parser.parse_args()
+        args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

-        workspaceId = args["workspace_id"]
-        reg_email = args["email"]
-        token = args["token"]
+        workspaceId = args.workspace_id
+        reg_email = args.email
+        token = args.token

         invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
         if invitation:
@@ -56,22 +81,11 @@ class ActivateCheckApi(Resource):
         return {"is_valid": False}


-active_parser = (
-    reqparse.RequestParser()
-    .add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
-    .add_argument("email", type=email, required=False, nullable=True, location="json")
-    .add_argument("token", type=str, required=True, nullable=False, location="json")
-    .add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
-    .add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
-    .add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
-)
-
-
 @console_ns.route("/activate")
 class ActivateApi(Resource):
     @console_ns.doc("activate_account")
     @console_ns.doc(description="Activate account with invitation token")
-    @console_ns.expect(active_parser)
+    @console_ns.expect(console_ns.models[ActivatePayload.__name__])
     @console_ns.response(
         200,
         "Account activated successfully",
@@ -79,30 +93,27 @@ class ActivateApi(Resource):
             "ActivationResponse",
             {
                 "result": fields.String(description="Operation result"),
-                "data": fields.Raw(description="Login token data"),
             },
         ),
     )
     @console_ns.response(400, "Already activated or invalid token")
     def post(self):
-        args = active_parser.parse_args()
+        args = ActivatePayload.model_validate(console_ns.payload)

-        invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"])
+        invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
         if invitation is None:
             raise AlreadyActivateError()

-        RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"])
+        RegisterService.revoke_token(args.workspace_id, args.email, args.token)

         account = invitation["account"]
-        account.name = args["name"]
-        account.interface_language = args["interface_language"]
-        account.timezone = args["timezone"]
+        account.name = args.name
+        account.interface_language = args.interface_language
+        account.timezone = args.timezone
         account.interface_theme = "light"
         account.status = AccountStatus.ACTIVE
         account.initialized_at = naive_utc_now()
         db.session.commit()

-        token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
-
-        return {"result": "success", "data": token_pair.model_dump()}
+        return {"result": "success"}
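Two things stand out in this file. First, it is a behavior change, not just a refactor: the endpoint no longer logs the activated account in, and the `data` field with the token pair is dropped from the response. Second, the `field_validator`s simply delegate to the existing `supported_language` and `timezone` helpers, so legacy reqparse `type=` callables can be reused unchanged; a rejected value now surfaces as a Pydantic `ValidationError`. A sketch of that delegation in isolation (the helper below is a stand-in for `constants.languages.supported_language`):

```python
from pydantic import BaseModel, ValidationError, field_validator

SUPPORTED = {"en-US", "zh-Hans"}


def supported_language(value: str) -> str:
    # Stand-in for Dify's helper, which raises ValueError on unknown codes.
    if value not in SUPPORTED:
        raise ValueError(f"{value!r} is not a supported language.")
    return value


class ActivatePayload(BaseModel):
    interface_language: str

    @field_validator("interface_language")
    @classmethod
    def validate_lang(cls, value: str) -> str:
        # Pydantic wraps the helper's ValueError into a ValidationError
        # that carries the offending field name.
        return supported_language(value)


try:
    ActivatePayload.model_validate({"interface_language": "xx-XX"})
except ValidationError as e:
    print(e.errors()[0]["msg"])
```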

View File

@@ -1,12 +1,26 @@
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field

-from controllers.console import console_ns
-from controllers.console.auth.error import ApiKeyAuthFailedError
-from controllers.console.wraps import is_admin_or_owner_required
 from libs.login import current_account_with_tenant, login_required
 from services.auth.api_key_auth_service import ApiKeyAuthService

-from ..wraps import account_initialization_required, setup_required
+from .. import console_ns
+from ..auth.error import ApiKeyAuthFailedError
+from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
+
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ApiKeyAuthBindingPayload(BaseModel):
+    category: str = Field(...)
+    provider: str = Field(...)
+    credentials: dict = Field(...)
+
+
+console_ns.schema_model(
+    ApiKeyAuthBindingPayload.__name__,
+    ApiKeyAuthBindingPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
+)


 @console_ns.route("/api-key-auth/data-source")
@@ -40,19 +54,15 @@ class ApiKeyAuthDataSourceBinding(Resource):
     @login_required
     @account_initialization_required
     @is_admin_or_owner_required
+    @console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
     def post(self):
         # The role of the current user in the table must be admin or owner
         _, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("category", type=str, required=True, nullable=False, location="json")
-            .add_argument("provider", type=str, required=True, nullable=False, location="json")
-            .add_argument("credentials", type=dict, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
-        ApiKeyAuthService.validate_api_key_auth_args(args)
+        payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
+        data = payload.model_dump()
+        ApiKeyAuthService.validate_api_key_auth_args(data)
         try:
-            ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
+            ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
         except Exception as e:
             raise ApiKeyAuthFailedError(str(e))
         return {"result": "success"}, 200
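Services that still expect a plain dict get one via `model_dump()`, which keeps the Pydantic model at the HTTP edge while leaving service signatures untouched. A minimal sketch of that handoff (the service function and provider values below are stand-ins, not Dify's):

```python
from pydantic import BaseModel, Field


class BindingPayload(BaseModel):
    category: str = Field(...)
    provider: str = Field(...)
    credentials: dict = Field(...)


def validate_api_key_auth_args(args: dict) -> None:
    # Stand-in for a legacy service that consumes a plain dict.
    assert {"category", "provider", "credentials"} <= args.keys()


payload = BindingPayload.model_validate(
    {"category": "search", "provider": "example", "credentials": {"api_key": "..."}}
)
validate_api_key_auth_args(payload.model_dump())
```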

View File

@@ -5,12 +5,11 @@ from flask import current_app, redirect, request
 from flask_restx import Resource, fields

 from configs import dify_config
-from controllers.console import console_ns
-from controllers.console.wraps import is_admin_or_owner_required
 from libs.login import login_required
 from libs.oauth_data_source import NotionOAuth

-from ..wraps import account_initialization_required, setup_required
+from .. import console_ns
+from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required

 logger = logging.getLogger(__name__)

View File

@@ -1,5 +1,6 @@
 from flask import request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -14,16 +15,45 @@ from controllers.console.auth.error import (
     InvalidTokenError,
     PasswordMismatchError,
 )
+from controllers.console.error import AccountInFreezeError, EmailSendIpLimitError
+from controllers.console.wraps import email_password_login_enabled, email_register_enabled, setup_required
 from extensions.ext_database import db
-from libs.helper import email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
 from libs.password import valid_password
 from models import Account
 from services.account_service import AccountService
 from services.billing_service import BillingService
 from services.errors.account import AccountNotFoundError, AccountRegisterError

-from ..error import AccountInFreezeError, EmailSendIpLimitError
-from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class EmailRegisterSendPayload(BaseModel):
+    email: EmailStr = Field(..., description="Email address")
+    language: str | None = Field(default=None, description="Language code")
+
+
+class EmailRegisterValidityPayload(BaseModel):
+    email: EmailStr = Field(...)
+    code: str = Field(...)
+    token: str = Field(...)
+
+
+class EmailRegisterResetPayload(BaseModel):
+    token: str = Field(...)
+    new_password: str = Field(...)
+    password_confirm: str = Field(...)
+
+    @field_validator("new_password", "password_confirm")
+    @classmethod
+    def validate_password(cls, value: str) -> str:
+        return valid_password(value)
+
+
+for model in (EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload):
+    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


 @console_ns.route("/email-register/send-email")
 class EmailRegisterSendEmailApi(Resource):
@@ -31,27 +61,22 @@ class EmailRegisterSendEmailApi(Resource):
     @email_password_login_enabled
     @email_register_enabled
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=email, required=True, location="json")
-            .add_argument("language", type=str, required=False, location="json")
-        )
-        args = parser.parse_args()
+        args = EmailRegisterSendPayload.model_validate(console_ns.payload)

         ip_address = extract_remote_ip(request)
         if AccountService.is_email_send_ip_limit(ip_address):
             raise EmailSendIpLimitError()

         language = "en-US"
-        if args["language"] in languages:
-            language = args["language"]
+        if args.language in languages:
+            language = args.language

-        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
+        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
             raise AccountInFreezeError()

         with Session(db.engine) as session:
-            account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+            account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
         token = None
-        token = AccountService.send_email_register_email(email=args["email"], account=account, language=language)
+        token = AccountService.send_email_register_email(email=args.email, account=account, language=language)

         return {"result": "success", "data": token}
@@ -61,40 +86,34 @@ class EmailRegisterCheckApi(Resource):
     @email_password_login_enabled
     @email_register_enabled
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=str, required=True, location="json")
-            .add_argument("code", type=str, required=True, location="json")
-            .add_argument("token", type=str, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        args = EmailRegisterValidityPayload.model_validate(console_ns.payload)

-        user_email = args["email"]
+        user_email = args.email

-        is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args["email"])
+        is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
         if is_email_register_error_rate_limit:
             raise EmailRegisterLimitError()

-        token_data = AccountService.get_email_register_data(args["token"])
+        token_data = AccountService.get_email_register_data(args.token)
         if token_data is None:
             raise InvalidTokenError()

         if user_email != token_data.get("email"):
             raise InvalidEmailError()

-        if args["code"] != token_data.get("code"):
-            AccountService.add_email_register_error_rate_limit(args["email"])
+        if args.code != token_data.get("code"):
+            AccountService.add_email_register_error_rate_limit(args.email)
             raise EmailCodeError()

         # Verified, revoke the first token
-        AccountService.revoke_email_register_token(args["token"])
+        AccountService.revoke_email_register_token(args.token)

         # Refresh token data by generating a new token
         _, new_token = AccountService.generate_email_register_token(
-            user_email, code=args["code"], additional_data={"phase": "register"}
+            user_email, code=args.code, additional_data={"phase": "register"}
         )

-        AccountService.reset_email_register_error_rate_limit(args["email"])
+        AccountService.reset_email_register_error_rate_limit(args.email)

         return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@@ -104,20 +123,14 @@ class EmailRegisterResetApi(Resource):
     @email_password_login_enabled
     @email_register_enabled
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("token", type=str, required=True, nullable=False, location="json")
-            .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
-            .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        args = EmailRegisterResetPayload.model_validate(console_ns.payload)

         # Validate passwords match
-        if args["new_password"] != args["password_confirm"]:
+        if args.new_password != args.password_confirm:
             raise PasswordMismatchError()

         # Validate token and get register data
-        register_data = AccountService.get_email_register_data(args["token"])
+        register_data = AccountService.get_email_register_data(args.token)
         if not register_data:
             raise InvalidTokenError()

         # Must use token in reset phase
@@ -125,7 +138,7 @@ class EmailRegisterResetApi(Resource):
             raise InvalidTokenError()

         # Revoke token to prevent reuse
-        AccountService.revoke_email_register_token(args["token"])
+        AccountService.revoke_email_register_token(args.token)

         email = register_data.get("email", "")
@@ -135,7 +148,7 @@ class EmailRegisterResetApi(Resource):
         if account:
             raise EmailAlreadyInUseError()
         else:
-            account = self._create_new_account(email, args["password_confirm"])
+            account = self._create_new_account(email, args.password_confirm)
             if not account:
                 raise AccountNotFoundError()
         token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
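Since `valid_password` runs per field, the passwords-match comparison has to stay in the handler. If one wanted the model itself to enforce it, Pydantic's `model_validator(mode="after")` is the usual tool; a sketch of that alternative (this is not what the commit does, and the strength check below is a stand-in for `libs.password.valid_password`):

```python
from pydantic import BaseModel, field_validator, model_validator


def valid_password(value: str) -> str:
    # Stand-in strength check; the real helper raises ValueError when weak.
    if len(value) < 8:
        raise ValueError("Password must be at least 8 characters.")
    return value


class ResetPayload(BaseModel):
    token: str
    new_password: str
    password_confirm: str

    @field_validator("new_password", "password_confirm")
    @classmethod
    def check_strength(cls, value: str) -> str:
        return valid_password(value)

    @model_validator(mode="after")
    def check_match(self) -> "ResetPayload":
        # Runs once per instance, after all fields validated, so it can
        # compare fields against each other.
        if self.new_password != self.password_confirm:
            raise ValueError("Passwords do not match.")
        return self
```

Keeping the comparison in the handler, as the commit does, preserves the existing `PasswordMismatchError` response instead of a generic validation error.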

View File

@@ -2,7 +2,8 @@ import base64
 import secrets

 from flask import request
-from flask_restx import Resource, fields, reqparse
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from sqlalchemy.orm import Session
@@ -18,26 +19,46 @@ from controllers.console.error import AccountNotFound, EmailSendIpLimitError
 from controllers.console.wraps import email_password_login_enabled, setup_required
 from events.tenant_event import tenant_was_created
 from extensions.ext_database import db
-from libs.helper import email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
 from libs.password import hash_password, valid_password
 from models import Account
 from services.account_service import AccountService, TenantService
 from services.feature_service import FeatureService

+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class ForgotPasswordSendPayload(BaseModel):
+    email: EmailStr = Field(...)
+    language: str | None = Field(default=None)
+
+
+class ForgotPasswordCheckPayload(BaseModel):
+    email: EmailStr = Field(...)
+    code: str = Field(...)
+    token: str = Field(...)
+
+
+class ForgotPasswordResetPayload(BaseModel):
+    token: str = Field(...)
+    new_password: str = Field(...)
+    password_confirm: str = Field(...)
+
+    @field_validator("new_password", "password_confirm")
+    @classmethod
+    def validate_password(cls, value: str) -> str:
+        return valid_password(value)
+
+
+for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
+    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


 @console_ns.route("/forgot-password")
 class ForgotPasswordSendEmailApi(Resource):
     @console_ns.doc("send_forgot_password_email")
     @console_ns.doc(description="Send password reset email")
-    @console_ns.expect(
-        console_ns.model(
-            "ForgotPasswordEmailRequest",
-            {
-                "email": fields.String(required=True, description="Email address"),
-                "language": fields.String(description="Language for email (zh-Hans/en-US)"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[ForgotPasswordSendPayload.__name__])
     @console_ns.response(
         200,
         "Email sent successfully",
@@ -54,28 +75,23 @@ class ForgotPasswordSendEmailApi(Resource):
     @setup_required
     @email_password_login_enabled
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=email, required=True, location="json")
-            .add_argument("language", type=str, required=False, location="json")
-        )
-        args = parser.parse_args()
+        args = ForgotPasswordSendPayload.model_validate(console_ns.payload)

         ip_address = extract_remote_ip(request)
         if AccountService.is_email_send_ip_limit(ip_address):
             raise EmailSendIpLimitError()

-        if args["language"] is not None and args["language"] == "zh-Hans":
+        if args.language is not None and args.language == "zh-Hans":
             language = "zh-Hans"
         else:
             language = "en-US"

         with Session(db.engine) as session:
-            account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+            account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
         token = AccountService.send_reset_password_email(
             account=account,
-            email=args["email"],
+            email=args.email,
             language=language,
             is_allow_register=FeatureService.get_system_features().is_allow_register,
         )
@@ -87,16 +103,7 @@ class ForgotPasswordSendEmailApi(Resource):
 class ForgotPasswordCheckApi(Resource):
     @console_ns.doc("check_forgot_password_code")
     @console_ns.doc(description="Verify password reset code")
-    @console_ns.expect(
-        console_ns.model(
-            "ForgotPasswordCheckRequest",
-            {
-                "email": fields.String(required=True, description="Email address"),
-                "code": fields.String(required=True, description="Verification code"),
-                "token": fields.String(required=True, description="Reset token"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[ForgotPasswordCheckPayload.__name__])
     @console_ns.response(
         200,
         "Code verified successfully",
@@ -113,40 +120,34 @@ class ForgotPasswordCheckApi(Resource):
     @setup_required
     @email_password_login_enabled
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=str, required=True, location="json")
-            .add_argument("code", type=str, required=True, location="json")
-            .add_argument("token", type=str, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)

-        user_email = args["email"]
+        user_email = args.email

-        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
+        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
         if is_forgot_password_error_rate_limit:
             raise EmailPasswordResetLimitError()

-        token_data = AccountService.get_reset_password_data(args["token"])
+        token_data = AccountService.get_reset_password_data(args.token)
         if token_data is None:
             raise InvalidTokenError()

         if user_email != token_data.get("email"):
             raise InvalidEmailError()

-        if args["code"] != token_data.get("code"):
-            AccountService.add_forgot_password_error_rate_limit(args["email"])
+        if args.code != token_data.get("code"):
+            AccountService.add_forgot_password_error_rate_limit(args.email)
             raise EmailCodeError()

         # Verified, revoke the first token
-        AccountService.revoke_reset_password_token(args["token"])
+        AccountService.revoke_reset_password_token(args.token)

         # Refresh token data by generating a new token
         _, new_token = AccountService.generate_reset_password_token(
-            user_email, code=args["code"], additional_data={"phase": "reset"}
+            user_email, code=args.code, additional_data={"phase": "reset"}
         )

-        AccountService.reset_forgot_password_error_rate_limit(args["email"])
+        AccountService.reset_forgot_password_error_rate_limit(args.email)

         return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
@@ -154,16 +155,7 @@ class ForgotPasswordCheckApi(Resource):
 class ForgotPasswordResetApi(Resource):
     @console_ns.doc("reset_password")
     @console_ns.doc(description="Reset password with verification token")
-    @console_ns.expect(
-        console_ns.model(
-            "ForgotPasswordResetRequest",
-            {
-                "token": fields.String(required=True, description="Verification token"),
-                "new_password": fields.String(required=True, description="New password"),
-                "password_confirm": fields.String(required=True, description="Password confirmation"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[ForgotPasswordResetPayload.__name__])
     @console_ns.response(
         200,
         "Password reset successfully",
@@ -173,20 +165,14 @@ class ForgotPasswordResetApi(Resource):
     @setup_required
     @email_password_login_enabled
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("token", type=str, required=True, nullable=False, location="json")
-            .add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
-            .add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        args = ForgotPasswordResetPayload.model_validate(console_ns.payload)

         # Validate passwords match
-        if args["new_password"] != args["password_confirm"]:
+        if args.new_password != args.password_confirm:
             raise PasswordMismatchError()

         # Validate token and get reset data
-        reset_data = AccountService.get_reset_password_data(args["token"])
+        reset_data = AccountService.get_reset_password_data(args.token)
         if not reset_data:
             raise InvalidTokenError()

         # Must use token in reset phase
@@ -194,11 +180,11 @@ class ForgotPasswordResetApi(Resource):
             raise InvalidTokenError()

         # Revoke token to prevent reuse
-        AccountService.revoke_reset_password_token(args["token"])
+        AccountService.revoke_reset_password_token(args.token)

         # Generate secure salt and hash password
         salt = secrets.token_bytes(16)
-        password_hashed = hash_password(args["new_password"], salt)
+        password_hashed = hash_password(args.new_password, salt)

         email = reset_data.get("email", "")

View File

@@ -1,6 +1,7 @@
 import flask_login
 from flask import make_response, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
 
 import services
 from configs import dify_config
@@ -21,9 +22,14 @@ from controllers.console.error import (
     NotAllowedCreateWorkspace,
     WorkspacesLimitExceeded,
 )
-from controllers.console.wraps import email_password_login_enabled, setup_required
+from controllers.console.wraps import (
+    decrypt_code_field,
+    decrypt_password_field,
+    email_password_login_enabled,
+    setup_required,
+)
 from events.tenant_event import tenant_was_created
-from libs.helper import email, extract_remote_ip
+from libs.helper import EmailStr, extract_remote_ip
 from libs.login import current_account_with_tenant
 from libs.token import (
     clear_access_token_from_cookie,
@@ -40,6 +46,36 @@ from services.errors.account import AccountRegisterError
 from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
 from services.feature_service import FeatureService
 
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class LoginPayload(BaseModel):
+    email: EmailStr = Field(..., description="Email address")
+    password: str = Field(..., description="Password")
+    remember_me: bool = Field(default=False, description="Remember me flag")
+    invite_token: str | None = Field(default=None, description="Invitation token")
+
+
+class EmailPayload(BaseModel):
+    email: EmailStr = Field(...)
+    language: str | None = Field(default=None)
+
+
+class EmailCodeLoginPayload(BaseModel):
+    email: EmailStr = Field(...)
+    code: str = Field(...)
+    token: str = Field(...)
+    language: str | None = Field(default=None)
+
+
+def reg(cls: type[BaseModel]):
+    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
+
+reg(LoginPayload)
+reg(EmailPayload)
+reg(EmailCodeLoginPayload)
+
 
 @console_ns.route("/login")
 class LoginApi(Resource):
@@ -47,41 +83,37 @@ class LoginApi(Resource):
     @setup_required
     @email_password_login_enabled
+    @console_ns.expect(console_ns.models[LoginPayload.__name__])
+    @decrypt_password_field
     def post(self):
         """Authenticate user and login."""
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=email, required=True, location="json")
-            .add_argument("password", type=str, required=True, location="json")
-            .add_argument("remember_me", type=bool, required=False, default=False, location="json")
-            .add_argument("invite_token", type=str, required=False, default=None, location="json")
-        )
-        args = parser.parse_args()
+        args = LoginPayload.model_validate(console_ns.payload)
 
-        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
+        if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
             raise AccountInFreezeError()
 
-        is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
+        is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
         if is_login_error_rate_limit:
             raise EmailPasswordLoginLimitError()
 
-        invitation = args["invite_token"]
+        # TODO: why invitation is re-assigned with different type?
+        invitation = args.invite_token  # type: ignore
         if invitation:
-            invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
+            invitation = RegisterService.get_invitation_if_token_valid(None, args.email, invitation)  # type: ignore
 
         try:
             if invitation:
-                data = invitation.get("data", {})
+                data = invitation.get("data", {})  # type: ignore
                 invitee_email = data.get("email") if data else None
-                if invitee_email != args["email"]:
+                if invitee_email != args.email:
                     raise InvalidEmailError()
-                account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
+                account = AccountService.authenticate(args.email, args.password, args.invite_token)
             else:
-                account = AccountService.authenticate(args["email"], args["password"])
+                account = AccountService.authenticate(args.email, args.password)
         except services.errors.account.AccountLoginError:
             raise AccountBannedError()
         except services.errors.account.AccountPasswordError:
-            AccountService.add_login_error_rate_limit(args["email"])
+            AccountService.add_login_error_rate_limit(args.email)
             raise AuthenticationFailedError()
 
         # SELF_HOSTED only have one workspace
         tenants = TenantService.get_join_tenants(account)
@@ -97,7 +129,7 @@ class LoginApi(Resource):
         }
 
         token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
-        AccountService.reset_login_error_rate_limit(args["email"])
+        AccountService.reset_login_error_rate_limit(args.email)
 
         # Create response with cookies instead of returning tokens in body
         response = make_response({"result": "success"})
@@ -134,25 +166,21 @@ class LogoutApi(Resource):
 class ResetPasswordSendEmailApi(Resource):
     @setup_required
     @email_password_login_enabled
+    @console_ns.expect(console_ns.models[EmailPayload.__name__])
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=email, required=True, location="json")
-            .add_argument("language", type=str, required=False, location="json")
-        )
-        args = parser.parse_args()
+        args = EmailPayload.model_validate(console_ns.payload)
 
-        if args["language"] is not None and args["language"] == "zh-Hans":
+        if args.language is not None and args.language == "zh-Hans":
             language = "zh-Hans"
         else:
             language = "en-US"
 
         try:
-            account = AccountService.get_user_through_email(args["email"])
+            account = AccountService.get_user_through_email(args.email)
         except AccountRegisterError:
             raise AccountInFreezeError()
 
         token = AccountService.send_reset_password_email(
-            email=args["email"],
+            email=args.email,
             account=account,
             language=language,
             is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -164,30 +192,26 @@ class ResetPasswordSendEmailApi(Resource):
 @console_ns.route("/email-code-login")
 class EmailCodeLoginSendEmailApi(Resource):
     @setup_required
+    @console_ns.expect(console_ns.models[EmailPayload.__name__])
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=email, required=True, location="json")
-            .add_argument("language", type=str, required=False, location="json")
-        )
-        args = parser.parse_args()
+        args = EmailPayload.model_validate(console_ns.payload)
 
         ip_address = extract_remote_ip(request)
         if AccountService.is_email_send_ip_limit(ip_address):
             raise EmailSendIpLimitError()
 
-        if args["language"] is not None and args["language"] == "zh-Hans":
+        if args.language is not None and args.language == "zh-Hans":
            language = "zh-Hans"
         else:
             language = "en-US"
 
         try:
-            account = AccountService.get_user_through_email(args["email"])
+            account = AccountService.get_user_through_email(args.email)
         except AccountRegisterError:
             raise AccountInFreezeError()
 
         if account is None:
             if FeatureService.get_system_features().is_allow_register:
-                token = AccountService.send_email_code_login_email(email=args["email"], language=language)
+                token = AccountService.send_email_code_login_email(email=args.email, language=language)
             else:
                 raise AccountNotFound()
         else:
@@ -199,30 +223,25 @@ class EmailCodeLoginSendEmailApi(Resource):
 @console_ns.route("/email-code-login/validity")
 class EmailCodeLoginApi(Resource):
     @setup_required
+    @console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
+    @decrypt_code_field
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("email", type=str, required=True, location="json")
-            .add_argument("code", type=str, required=True, location="json")
-            .add_argument("token", type=str, required=True, location="json")
-            .add_argument("language", type=str, required=False, location="json")
-        )
-        args = parser.parse_args()
+        args = EmailCodeLoginPayload.model_validate(console_ns.payload)
 
-        user_email = args["email"]
-        language = args["language"]
+        user_email = args.email
+        language = args.language
 
-        token_data = AccountService.get_email_code_login_data(args["token"])
+        token_data = AccountService.get_email_code_login_data(args.token)
         if token_data is None:
             raise InvalidTokenError()
 
-        if token_data["email"] != args["email"]:
+        if token_data["email"] != args.email:
             raise InvalidEmailError()
 
-        if token_data["code"] != args["code"]:
+        if token_data["code"] != args.code:
             raise EmailCodeError()
 
-        AccountService.revoke_email_code_login_token(args["token"])
+        AccountService.revoke_email_code_login_token(args.token)
         try:
             account = AccountService.get_user_through_email(user_email)
         except AccountRegisterError:
@@ -255,7 +274,7 @@ class EmailCodeLoginApi(Resource):
         except WorkspacesLimitExceededError:
             raise WorkspacesLimitExceeded()
         token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
-        AccountService.reset_login_error_rate_limit(args["email"])
+        AccountService.reset_login_error_rate_limit(args.email)
 
         # Create response with cookies instead of returning tokens in body
         response = make_response({"result": "success"})
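
Editor's note: the `reg(...)` helper added above is what keeps the Swagger docs working after dropping `reqparse`: each Pydantic model's JSON Schema is registered on the namespace so `@console_ns.expect(console_ns.models[...])` can reference it. The idiom in isolation, with an illustrative model name:

```python
# Sketch of the schema-registration idiom; PingPayload is illustrative.
from flask_restx import Namespace
from pydantic import BaseModel

# flask-restx emits Swagger 2.0, which resolves refs under #/definitions/,
# hence the custom ref_template instead of Pydantic's #/$defs/ default.
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

ns = Namespace("auth")


class PingPayload(BaseModel):
    message: str


def reg(cls: type[BaseModel]) -> None:
    ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))


reg(PingPayload)
assert "PingPayload" in ns.models  # usable as @ns.expect(ns.models["PingPayload"])
```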

View File

@@ -3,7 +3,8 @@ from functools import wraps
 from typing import Concatenate, ParamSpec, TypeVar
 
 from flask import jsonify, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
 from werkzeug.exceptions import BadRequest, NotFound
 
 from controllers.console.wraps import account_initialization_required, setup_required
@@ -20,15 +21,34 @@ R = TypeVar("R")
 T = TypeVar("T")
 
 
+class OAuthClientPayload(BaseModel):
+    client_id: str
+
+
+class OAuthProviderRequest(BaseModel):
+    client_id: str
+    redirect_uri: str
+
+
+class OAuthTokenRequest(BaseModel):
+    client_id: str
+    grant_type: str
+    code: str | None = None
+    client_secret: str | None = None
+    redirect_uri: str | None = None
+    refresh_token: str | None = None
+
+
 def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
     @wraps(view)
     def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
-        parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
-        parsed_args = parser.parse_args()
-        client_id = parsed_args.get("client_id")
-        if not client_id:
+        json_data = request.get_json()
+        if json_data is None:
             raise BadRequest("client_id is required")
+        payload = OAuthClientPayload.model_validate(json_data)
+        client_id = payload.client_id
 
         oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
         if not oauth_provider_app:
             raise NotFound("client_id is invalid")
@@ -89,9 +109,8 @@ class OAuthServerAppApi(Resource):
     @setup_required
     @oauth_server_client_id_required
     def post(self, oauth_provider_app: OAuthProviderApp):
-        parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
-        parsed_args = parser.parse_args()
-        redirect_uri = parsed_args.get("redirect_uri")
+        payload = OAuthProviderRequest.model_validate(request.get_json())
+        redirect_uri = payload.redirect_uri
 
         # check if redirect_uri is valid
         if redirect_uri not in oauth_provider_app.redirect_uris:
@@ -130,33 +149,25 @@ class OAuthServerUserTokenApi(Resource):
     @setup_required
     @oauth_server_client_id_required
     def post(self, oauth_provider_app: OAuthProviderApp):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("grant_type", type=str, required=True, location="json")
-            .add_argument("code", type=str, required=False, location="json")
-            .add_argument("client_secret", type=str, required=False, location="json")
-            .add_argument("redirect_uri", type=str, required=False, location="json")
-            .add_argument("refresh_token", type=str, required=False, location="json")
-        )
-        parsed_args = parser.parse_args()
+        payload = OAuthTokenRequest.model_validate(request.get_json())
 
         try:
-            grant_type = OAuthGrantType(parsed_args["grant_type"])
+            grant_type = OAuthGrantType(payload.grant_type)
         except ValueError:
             raise BadRequest("invalid grant_type")
 
         if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
-            if not parsed_args["code"]:
+            if not payload.code:
                 raise BadRequest("code is required")
 
-            if parsed_args["client_secret"] != oauth_provider_app.client_secret:
+            if payload.client_secret != oauth_provider_app.client_secret:
                 raise BadRequest("client_secret is invalid")
 
-            if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
+            if payload.redirect_uri not in oauth_provider_app.redirect_uris:
                 raise BadRequest("redirect_uri is invalid")
 
             access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
-                grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
+                grant_type, code=payload.code, client_id=oauth_provider_app.client_id
             )
             return jsonable_encoder(
                 {
@@ -167,11 +178,11 @@ class OAuthServerUserTokenApi(Resource):
                 }
             )
         elif grant_type == OAuthGrantType.REFRESH_TOKEN:
-            if not parsed_args["refresh_token"]:
+            if not payload.refresh_token:
                 raise BadRequest("refresh_token is required")
 
             access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
-                grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
+                grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
             )
             return jsonable_encoder(
                 {
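
Editor's note: `OAuthTokenRequest` leaves `code` and `refresh_token` optional at the schema level, so grant-type-specific requirements stay as explicit handler checks. A standalone sketch of that shape; the enum string values here are assumptions for illustration:

```python
# Sketch of grant-type-dependent validation over optional fields.
from enum import Enum

from pydantic import BaseModel


class OAuthGrantType(str, Enum):
    AUTHORIZATION_CODE = "authorization_code"
    REFRESH_TOKEN = "refresh_token"


class OAuthTokenRequest(BaseModel):
    client_id: str
    grant_type: str
    code: str | None = None
    refresh_token: str | None = None


def validate_for_grant(payload: OAuthTokenRequest) -> OAuthGrantType:
    grant_type = OAuthGrantType(payload.grant_type)  # ValueError on unknown value
    if grant_type is OAuthGrantType.AUTHORIZATION_CODE and not payload.code:
        raise ValueError("code is required")
    if grant_type is OAuthGrantType.REFRESH_TOKEN and not payload.refresh_token:
        raise ValueError("refresh_token is required")
    return grant_type


assert validate_for_grant(
    OAuthTokenRequest(client_id="c", grant_type="refresh_token", refresh_token="r")
) is OAuthGrantType.REFRESH_TOKEN
```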

View File

@@ -1,6 +1,9 @@
 import base64
+from typing import Literal
 
-from flask_restx import Resource, fields, reqparse
+from flask import request
+from flask_restx import Resource, fields
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import BadRequest
 
 from controllers.console import console_ns
@@ -9,6 +12,21 @@ from enums.cloud_plan import CloudPlan
 from libs.login import current_account_with_tenant, login_required
 from services.billing_service import BillingService
 
+DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
+
+
+class SubscriptionQuery(BaseModel):
+    plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
+    interval: Literal["month", "year"] = Field(..., description="Billing interval")
+
+
+class PartnerTenantsPayload(BaseModel):
+    click_id: str = Field(..., description="Click Id from partner referral link")
+
+
+for model in (SubscriptionQuery, PartnerTenantsPayload):
+    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+
 
 @console_ns.route("/billing/subscription")
 class Subscription(Resource):
@@ -18,20 +36,9 @@ class Subscription(Resource):
     @only_edition_cloud
     def get(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "plan",
-                type=str,
-                required=True,
-                location="args",
-                choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
-            )
-            .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
-        )
-        args = parser.parse_args()
+        args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
         BillingService.is_tenant_owner_or_admin(current_user)
-        return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
+        return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
 
 
 @console_ns.route("/billing/invoices")
@@ -65,11 +72,10 @@ class PartnerTenants(Resource):
     @only_edition_cloud
     def put(self, partner_key: str):
         current_user, _ = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
-        args = parser.parse_args()
         try:
-            click_id = args["click_id"]
+            args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
+            click_id = args.click_id
             decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
         except Exception:
             raise BadRequest("Invalid partner_key")

View File

@@ -1,5 +1,6 @@
 from flask import request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel, Field
 
 from libs.helper import extract_remote_ip
 from libs.login import current_account_with_tenant, login_required
@@ -9,16 +10,28 @@ from .. import console_ns
 from ..wraps import account_initialization_required, only_edition_cloud, setup_required
 
 
+class ComplianceDownloadQuery(BaseModel):
+    doc_name: str = Field(..., description="Compliance document name")
+
+
+console_ns.schema_model(
+    ComplianceDownloadQuery.__name__,
+    ComplianceDownloadQuery.model_json_schema(ref_template="#/definitions/{model}"),
+)
+
+
 @console_ns.route("/compliance/download")
 class ComplianceApi(Resource):
+    @console_ns.expect(console_ns.models[ComplianceDownloadQuery.__name__])
+    @console_ns.doc("download_compliance_document")
+    @console_ns.doc(description="Get compliance document download link")
     @setup_required
     @login_required
     @account_initialization_required
     @only_edition_cloud
     def get(self):
         current_user, current_tenant_id = current_account_with_tenant()
-        parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
-        args = parser.parse_args()
+        args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
 
         ip_address = extract_remote_ip(request)
         device_info = request.headers.get("User-Agent", "Unknown device")
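
Editor's note: one behavioral difference runs through all of these migrations. `reqparse` aborts with an HTTP 400 on its own, while `model_validate` raises `pydantic.ValidationError`, which the app presumably maps to a 400 in a shared error handler. A sketch of what the failure now looks like:

```python
# Sketch: the failure mode after migration -- a ValidationError instead
# of an immediate reqparse abort(400).
from pydantic import BaseModel, Field, ValidationError


class ComplianceDownloadQuery(BaseModel):
    doc_name: str = Field(..., description="Compliance document name")


try:
    ComplianceDownloadQuery.model_validate({})  # doc_name missing
except ValidationError as exc:
    err = exc.errors()[0]
    assert err["loc"] == ("doc_name",) and err["type"] == "missing"
```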

View File

@@ -1,15 +1,15 @@
 import json
 from collections.abc import Generator
-from typing import cast
+from typing import Any, cast
 
 from flask import request
-from flask_restx import Resource, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel, Field
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 
-from controllers.console import console_ns
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.common.schema import register_schema_model
 from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.indexing_runner import IndexingRunner
@@ -25,6 +25,19 @@ from services.dataset_service import DatasetService, DocumentService
 from services.datasource_provider_service import DatasourceProviderService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 
+from .. import console_ns
+from ..wraps import account_initialization_required, setup_required
+
+
+class NotionEstimatePayload(BaseModel):
+    notion_info_list: list[dict[str, Any]]
+    process_rule: dict[str, Any]
+    doc_form: str = Field(default="text_model")
+    doc_language: str = Field(default="English")
+
+
+register_schema_model(console_ns, NotionEstimatePayload)
+
 
 @console_ns.route(
     "/data-source/integrates",
@@ -127,6 +140,18 @@ class DataSourceNotionListApi(Resource):
         credential_id = request.args.get("credential_id", default=None, type=str)
         if not credential_id:
             raise ValueError("Credential id is required.")
+        # Get datasource_parameters from query string (optional, for GitHub and other datasources)
+        datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
+        datasource_parameters = {}
+        if datasource_parameters_str:
+            try:
+                datasource_parameters = json.loads(datasource_parameters_str)
+                if not isinstance(datasource_parameters, dict):
+                    raise ValueError("datasource_parameters must be a JSON object.")
+            except json.JSONDecodeError:
+                raise ValueError("Invalid datasource_parameters JSON format.")
+
         datasource_provider_service = DatasourceProviderService()
         credential = datasource_provider_service.get_datasource_credentials(
             tenant_id=current_tenant_id,
@@ -174,7 +199,7 @@ class DataSourceNotionListApi(Resource):
                 online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
                     datasource_runtime.get_online_document_pages(
                         user_id=current_user.id,
-                        datasource_parameters={},
+                        datasource_parameters=datasource_parameters,
                         provider_type=datasource_runtime.datasource_provider_type(),
                     )
                 )
@@ -205,14 +230,14 @@ class DataSourceNotionListApi(Resource):
 @console_ns.route(
-    "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
+    "/notion/pages/<uuid:page_id>/<string:page_type>/preview",
     "/datasets/notion-indexing-estimate",
 )
 class DataSourceNotionApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    def get(self, workspace_id, page_id, page_type):
+    def get(self, page_id, page_type):
         _, current_tenant_id = current_account_with_tenant()
 
         credential_id = request.args.get("credential_id", default=None, type=str)
@@ -226,11 +251,10 @@ class DataSourceNotionApi(Resource):
             plugin_id="langgenius/notion_datasource",
         )
 
-        workspace_id = str(workspace_id)
         page_id = str(page_id)
         extractor = NotionExtractor(
-            notion_workspace_id=workspace_id,
+            notion_workspace_id="",
             notion_obj_id=page_id,
             notion_page_type=page_type,
             notion_access_token=credential.get("integration_secret"),
@@ -243,20 +267,15 @@ class DataSourceNotionApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
     def post(self):
         _, current_tenant_id = current_account_with_tenant()
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
-            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump()
         # validate args
         DocumentService.estimate_args_validate(args)
-        notion_info_list = args["notion_info_list"]
+        notion_info_list = payload.notion_info_list
         extract_settings = []
         for notion_info in notion_info_list:
             workspace_id = notion_info["workspace_id"]
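
Editor's note: the new optional `datasource_parameters` query argument carries a JSON object inside a query string, parsed and shape-checked inline in the hunk above. The same logic extracted as a small helper for clarity; the helper name is illustrative, not part of the codebase:

```python
# Sketch of the inline datasource_parameters parsing as a helper.
import json


def parse_json_object_param(raw: str | None) -> dict:
    if not raw:
        return {}
    try:
        value = json.loads(raw)
    except json.JSONDecodeError:
        raise ValueError("Invalid datasource_parameters JSON format.")
    if not isinstance(value, dict):
        raise ValueError("datasource_parameters must be a JSON object.")
    return value


assert parse_json_object_param(None) == {}
assert parse_json_object_param('{"repo": "octocat/hello-world"}') == {"repo": "octocat/hello-world"}
```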

View File

@@ -1,12 +1,14 @@
 from typing import Any, cast
 
 from flask import request
-from flask_restx import Resource, fields, marshal, marshal_with, reqparse
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
 from configs import dify_config
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.apikey import (
     api_key_item_model,
@@ -48,7 +50,6 @@ from fields.dataset_fields import (
 )
 from fields.document_fields import document_status_fields
 from libs.login import current_account_with_tenant, login_required
-from libs.validators import validate_description_length
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 from models.dataset import DatasetPermissionEnum
 from models.provider_ids import ModelProviderID
@@ -107,10 +108,75 @@ related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_mode
 related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
 
 
-def _validate_name(name: str) -> str:
-    if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError("Name must be between 1 to 40 characters.")
-    return name
+def _validate_indexing_technique(value: str | None) -> str | None:
+    if value is None:
+        return value
+    if value not in Dataset.INDEXING_TECHNIQUE_LIST:
+        raise ValueError("Invalid indexing technique.")
+    return value
+
+
+class DatasetCreatePayload(BaseModel):
+    name: str = Field(..., min_length=1, max_length=40)
+    description: str = Field("", max_length=400)
+    indexing_technique: str | None = None
+    permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
+    provider: str = "vendor"
+    external_knowledge_api_id: str | None = None
+    external_knowledge_id: str | None = None
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str | None) -> str | None:
+        return _validate_indexing_technique(value)
+
+    @field_validator("provider")
+    @classmethod
+    def validate_provider(cls, value: str) -> str:
+        if value not in Dataset.PROVIDER_LIST:
+            raise ValueError("Invalid provider.")
+        return value
+
+
+class DatasetUpdatePayload(BaseModel):
+    name: str | None = Field(None, min_length=1, max_length=40)
+    description: str | None = Field(None, max_length=400)
+    permission: DatasetPermissionEnum | None = None
+    indexing_technique: str | None = None
+    embedding_model: str | None = None
+    embedding_model_provider: str | None = None
+    retrieval_model: dict[str, Any] | None = None
+    partial_member_list: list[dict[str, str]] | None = None
+    external_retrieval_model: dict[str, Any] | None = None
+    external_knowledge_id: str | None = None
+    external_knowledge_api_id: str | None = None
+    icon_info: dict[str, Any] | None = None
+    is_multimodal: bool | None = False
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str | None) -> str | None:
+        return _validate_indexing_technique(value)
+
+
+class IndexingEstimatePayload(BaseModel):
+    info_list: dict[str, Any]
+    process_rule: dict[str, Any]
+    indexing_technique: str
+    doc_form: str = "text_model"
+    dataset_id: str | None = None
+    doc_language: str = "English"
+
+    @field_validator("indexing_technique")
+    @classmethod
+    def validate_indexing(cls, value: str) -> str:
+        result = _validate_indexing_technique(value)
+        if result is None:
+            raise ValueError("indexing_technique is required.")
+        return result
+
+
+register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
 
 
 def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@@ -157,6 +223,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
             VectorType.COUCHBASE,
             VectorType.OPENGAUSS,
             VectorType.OCEANBASE,
+            VectorType.SEEKDB,
             VectorType.TABLESTORE,
             VectorType.HUAWEI_CLOUD,
             VectorType.TENCENT,
@@ -164,6 +231,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
             VectorType.CLICKZETTA,
             VectorType.BAIDU,
             VectorType.ALIBABACLOUD_MYSQL,
+            VectorType.IRIS,
         }
         semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
@@ -255,20 +323,7 @@ class DatasetListApi(Resource):
     @console_ns.doc("create_dataset")
     @console_ns.doc(description="Create a new dataset")
-    @console_ns.expect(
-        console_ns.model(
-            "CreateDatasetRequest",
-            {
-                "name": fields.String(required=True, description="Dataset name (1-40 characters)"),
-                "description": fields.String(description="Dataset description (max 400 characters)"),
-                "indexing_technique": fields.String(description="Indexing technique"),
-                "permission": fields.String(description="Dataset permission"),
-                "provider": fields.String(description="Provider"),
-                "external_knowledge_api_id": fields.String(description="External knowledge API ID"),
-                "external_knowledge_id": fields.String(description="External knowledge ID"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
     @console_ns.response(201, "Dataset created successfully")
     @console_ns.response(400, "Invalid request parameters")
     @setup_required
@@ -276,52 +331,7 @@ class DatasetListApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                required=True,
-                help="type is required. Name must be between 1 to 40 characters.",
-                type=_validate_name,
-            )
-            .add_argument(
-                "description",
-                type=validate_description_length,
-                nullable=True,
-                required=False,
-                default="",
-            )
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                location="json",
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                help="Invalid indexing technique.",
-            )
-            .add_argument(
-                "external_knowledge_api_id",
-                type=str,
-                nullable=True,
-                required=False,
-            )
-            .add_argument(
-                "provider",
-                type=str,
-                nullable=True,
-                choices=Dataset.PROVIDER_LIST,
-                required=False,
-                default="vendor",
-            )
-            .add_argument(
-                "external_knowledge_id",
-                type=str,
-                nullable=True,
-                required=False,
-            )
-        )
-        args = parser.parse_args()
+        payload = DatasetCreatePayload.model_validate(console_ns.payload or {})
         current_user, current_tenant_id = current_account_with_tenant()
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -331,14 +341,14 @@ class DatasetListApi(Resource):
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_tenant_id,
-                name=args["name"],
-                description=args["description"],
-                indexing_technique=args["indexing_technique"],
+                name=payload.name,
+                description=payload.description,
+                indexing_technique=payload.indexing_technique,
                 account=current_user,
-                permission=DatasetPermissionEnum.ONLY_ME,
-                provider=args["provider"],
-                external_knowledge_api_id=args["external_knowledge_api_id"],
-                external_knowledge_id=args["external_knowledge_id"],
+                permission=payload.permission or DatasetPermissionEnum.ONLY_ME,
+                provider=payload.provider,
+                external_knowledge_api_id=payload.external_knowledge_api_id,
+                external_knowledge_id=payload.external_knowledge_id,
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -399,18 +409,7 @@ class DatasetApi(Resource):
     @console_ns.doc("update_dataset")
     @console_ns.doc(description="Update dataset details")
-    @console_ns.expect(
-        console_ns.model(
-            "UpdateDatasetRequest",
-            {
-                "name": fields.String(description="Dataset name"),
-                "description": fields.String(description="Dataset description"),
-                "permission": fields.String(description="Dataset permission"),
-                "indexing_technique": fields.String(description="Indexing technique"),
-                "external_retrieval_model": fields.Raw(description="External retrieval model settings"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
     @console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
     @console_ns.response(404, "Dataset not found")
     @console_ns.response(403, "Permission denied")
@@ -424,93 +423,25 @@ class DatasetApi(Resource):
         if dataset is None:
             raise NotFound("Dataset not found.")
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "name",
-                nullable=False,
-                help="type is required. Name must be between 1 to 40 characters.",
-                type=_validate_name,
-            )
-            .add_argument("description", location="json", store_missing=False, type=validate_description_length)
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                location="json",
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                help="Invalid indexing technique.",
-            )
-            .add_argument(
-                "permission",
-                type=str,
-                location="json",
-                choices=(
-                    DatasetPermissionEnum.ONLY_ME,
-                    DatasetPermissionEnum.ALL_TEAM,
-                    DatasetPermissionEnum.PARTIAL_TEAM,
-                ),
-                help="Invalid permission.",
-            )
-            .add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
-            .add_argument(
-                "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
-            )
-            .add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
-            .add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
-            .add_argument(
-                "external_retrieval_model",
-                type=dict,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external retrieval model.",
-            )
-            .add_argument(
-                "external_knowledge_id",
-                type=str,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external knowledge id.",
-            )
-            .add_argument(
-                "external_knowledge_api_id",
-                type=str,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid external knowledge api id.",
-            )
-            .add_argument(
-                "icon_info",
-                type=dict,
-                required=False,
-                nullable=True,
-                location="json",
-                help="Invalid icon info.",
-            )
-        )
-        args = parser.parse_args()
-        data = request.get_json()
+        payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
         current_user, current_tenant_id = current_account_with_tenant()
 
         # check embedding model setting
         if (
-            data.get("indexing_technique") == "high_quality"
-            and data.get("embedding_model_provider") is not None
-            and data.get("embedding_model") is not None
+            payload.indexing_technique == "high_quality"
+            and payload.embedding_model_provider is not None
+            and payload.embedding_model is not None
         ):
-            DatasetService.check_embedding_model_setting(
-                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
+            is_multimodal = DatasetService.check_is_multimodal_model(
+                dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
             )
+            payload.is_multimodal = is_multimodal
+        payload_data = payload.model_dump(exclude_unset=True)
 
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         DatasetPermissionService.check_permission(
-            current_user, dataset, data.get("permission"), data.get("partial_member_list")
+            current_user, dataset, payload.permission, payload.partial_member_list
        )
 
-        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
+        dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
 
         if dataset is None:
             raise NotFound("Dataset not found.")
@@ -518,15 +449,10 @@ class DatasetApi(Resource):
         result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
         tenant_id = current_tenant_id
 
-        if data.get("partial_member_list") and data.get("permission") == "partial_members":
-            DatasetPermissionService.update_partial_member_list(
-                tenant_id, dataset_id_str, data.get("partial_member_list")
-            )
+        if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
+            DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
         # clear partial member list when permission is only_me or all_team_members
-        elif (
-            data.get("permission") == DatasetPermissionEnum.ONLY_ME
-            or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
-        ):
+        elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
             DatasetPermissionService.clear_partial_member_list(dataset_id_str)
 
         partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
@@ -615,24 +541,10 @@ class DatasetIndexingEstimateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @console_ns.expect(console_ns.models[IndexingEstimatePayload.__name__])
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("info_list", type=dict, required=True, nullable=True, location="json")
-            .add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
-            .add_argument(
-                "indexing_technique",
-                type=str,
-                required=True,
-                choices=Dataset.INDEXING_TECHNIQUE_LIST,
-                nullable=True,
-                location="json",
-            )
-            .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-            .add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
-            .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = IndexingEstimatePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump()
         _, current_tenant_id = current_account_with_tenant()
         # validate args
        DocumentService.estimate_args_validate(args)
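
Editor's note: `DatasetUpdatePayload` plus `model_dump(exclude_unset=True)` reproduces what `store_missing=False` did under reqparse: only fields the client actually sent reach `DatasetService.update_dataset`. A sketch of that semantics:

```python
# Sketch: exclude_unset=True as the replacement for reqparse's
# store_missing=False -- unset fields are omitted, not emitted as None.
from pydantic import BaseModel


class DatasetUpdatePayload(BaseModel):
    name: str | None = None
    description: str | None = None
    indexing_technique: str | None = None


payload = DatasetUpdatePayload.model_validate({"name": "docs"})
assert payload.model_dump(exclude_unset=True) == {"name": "docs"}
assert payload.model_dump()["description"] is None  # a full dump keeps defaults
```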

View File

@ -6,31 +6,14 @@ from typing import Literal, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from sqlalchemy import asc, desc, select from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
from core.errors.error import ( from core.errors.error import (
LLMBadRequestError, LLMBadRequestError,
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
@ -55,10 +38,30 @@ from fields.document_fields import (
) )
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from ..datasets.error import (
ArchivedDocumentImmutableError,
DocumentAlreadyFinishedError,
DocumentIndexingError,
IndexingEstimateError,
InvalidActionError,
InvalidMetadataError,
)
from ..wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -93,6 +96,24 @@ dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(docume
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy) dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentRetryPayload(BaseModel):
document_ids: list[str]
class DocumentRenamePayload(BaseModel):
name: str
register_schema_models(
console_ns,
KnowledgeConfig,
ProcessRule,
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
)
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@ -201,8 +222,9 @@ class DatasetDocumentListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id: str): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
@ -310,6 +332,7 @@ class DatasetDocumentListApi(Resource):
@marshal_with(dataset_and_document_model) @marshal_with(dataset_and_document_model)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -328,23 +351,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = ( knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
)
.add_argument("data_source", type=dict, required=False, location="json")
.add_argument("process_rule", type=dict, required=False, location="json")
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
.add_argument("original_document_id", type=str, required=False, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique: if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -390,17 +397,7 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
@console_ns.doc("init_dataset") @console_ns.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents") @console_ns.doc(description="Initialize dataset with documents")
@console_ns.expect( @console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
console_ns.model(
"DatasetInitRequest",
{
"upload_file_id": fields.String(required=True, description="Upload file ID"),
"indexing_technique": fields.String(description="Indexing technique"),
"process_rule": fields.Raw(description="Processing rules"),
"data_source": fields.Raw(description="Data source configuration"),
},
)
)
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) @console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
@console_ns.response(400, "Invalid request parameters") @console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@ -415,27 +412,7 @@ class DatasetInitApi(Resource):
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
parser = ( knowledge_config = KnowledgeConfig.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"indexing_technique",
type=str,
choices=Dataset.INDEXING_TECHNIQUE_LIST,
required=True,
nullable=False,
location="json",
)
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.") raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -443,10 +420,14 @@ class DatasetInitApi(Resource):
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider=args["embedding_model_provider"], provider=knowledge_config.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=args["embedding_model"], model=knowledge_config.embedding_model,
) )
is_multimodal = DatasetService.check_is_multimodal_model(
current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
)
knowledge_config.is_multimodal = is_multimodal
except InvokeAuthorizationError: except InvokeAuthorizationError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
@ -591,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
datasource_type=DatasourceType.NOTION, datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate( notion_info=NotionInfo.model_validate(
{ {
"credential_id": data_source_info["credential_id"], "credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"], "notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"], "notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"], "notion_page_type": data_source_info["type"],
@ -1076,19 +1057,16 @@ class DocumentRetryApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
def post(self, dataset_id): def post(self, dataset_id):
"""retry document.""" """retry document."""
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
parser = reqparse.RequestParser().add_argument(
"document_ids", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
retry_documents = [] retry_documents = []
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
for document_id in args["document_ids"]: for document_id in payload.document_ids:
try: try:
document_id = str(document_id) document_id = str(document_id)
@ -1121,6 +1099,7 @@ class DocumentRenameApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(document_fields) @marshal_with(document_fields)
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
@ -1130,11 +1109,10 @@ class DocumentRenameApi(DocumentResource):
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_operator_permission(current_user, dataset) DatasetService.check_dataset_operator_permission(current_user, dataset)
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json") payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
args = parser.parse_args()
try: try:
document = DocumentService.rename_document(dataset_id, document_id, args["name"]) document = DocumentService.rename_document(dataset_id, document_id, payload.name)
except services.errors.document.DocumentIndexingError: except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.") raise DocumentIndexingError("Cannot delete document during indexing.")

View File

@@ -1,11 +1,13 @@
import uuid import uuid
from flask import request from flask import request
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import ( from controllers.console.datasets.error import (
@@ -36,6 +38,58 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
hit_count_gte: int | None = None
enabled: str = Field(default="all")
keyword: str | None = None
page: int = Field(default=1, ge=1)
class SegmentCreatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
attachment_ids: list[str] | None = None
class SegmentUpdatePayload(BaseModel):
content: str
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
class BatchImportPayload(BaseModel):
upload_file_id: str
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]
register_schema_models(
console_ns,
SegmentListQuery,
SegmentCreatePayload,
SegmentUpdatePayload,
BatchImportPayload,
ChildChunkCreatePayload,
ChildChunkUpdatePayload,
ChildChunkBatchUpdatePayload,
)
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
class DatasetDocumentSegmentListApi(Resource): class DatasetDocumentSegmentListApi(Resource):
@setup_required @setup_required
@@ -60,23 +114,18 @@ class DatasetDocumentSegmentListApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = ( args = SegmentListQuery.model_validate(
reqparse.RequestParser() {
.add_argument("limit", type=int, default=20, location="args") **request.args.to_dict(),
.add_argument("status", type=str, action="append", default=[], location="args") "status": request.args.getlist("status"),
.add_argument("hit_count_gte", type=int, default=None, location="args") }
.add_argument("enabled", type=str, default="all", location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
) )
args = parser.parse_args() page = args.page
limit = min(args.limit, 100)
page = args["page"] status_list = args.status
limit = min(args["limit"], 100) hit_count_gte = args.hit_count_gte
status_list = args["status"] keyword = args.keyword
hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"]
query = ( query = (
select(DocumentSegment) select(DocumentSegment)
@@ -96,10 +145,10 @@ class DatasetDocumentSegmentListApi(Resource):
if keyword: if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all": if args.enabled.lower() != "all":
if args["enabled"].lower() == "true": if args.enabled.lower() == "true":
query = query.where(DocumentSegment.enabled == True) query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false": elif args.enabled.lower() == "false":
query = query.where(DocumentSegment.enabled == False) query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -210,6 +259,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@@ -246,15 +296,10 @@ class DatasetDocumentSegmentAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = ( payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() payload_dict = payload.model_dump(exclude_none=True)
.add_argument("content", type=str, required=True, nullable=False, location="json") SegmentService.segment_create_args_validate(payload_dict, document)
.add_argument("answer", type=str, required=False, nullable=True, location="json") segment = SegmentService.create_segment(payload_dict, document, dataset)
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@@ -265,6 +310,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@@ -313,18 +359,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = ( payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() payload_dict = payload.model_dump(exclude_none=True)
.add_argument("content", type=str, required=True, nullable=False, location="json") SegmentService.segment_create_args_validate(payload_dict, document)
.add_argument("answer", type=str, required=False, nullable=True, location="json") segment = SegmentService.update_segment(
.add_argument("keywords", type=list, required=False, nullable=True, location="json") SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
) )
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required @setup_required
@@ -377,6 +417,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@@ -391,11 +432,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser().add_argument( payload = BatchImportPayload.model_validate(console_ns.payload or {})
"upload_file_id", type=str, required=True, nullable=False, location="json" upload_file_id = payload.upload_file_id
)
args = parser.parse_args()
upload_file_id = args["upload_file_id"]
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file: if not upload_file:
@@ -446,6 +484,7 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id, document_id, segment_id): def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@@ -491,13 +530,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try: try:
content = args["content"] payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset) child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@@ -529,18 +564,17 @@ class ChildChunkAddApi(Resource):
) )
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
parser = ( args = SegmentListQuery.model_validate(
reqparse.RequestParser() {
.add_argument("limit", type=int, default=20, location="args") "limit": request.args.get("limit", default=20, type=int),
.add_argument("keyword", type=str, default=None, location="args") "keyword": request.args.get("keyword"),
.add_argument("page", type=int, default=1, location="args") "page": request.args.get("page", default=1, type=int),
}
) )
args = parser.parse_args() page = args.page
limit = min(args.limit, 100)
page = args["page"] keyword = args.keyword
limit = min(args["limit"], 100)
keyword = args["keyword"]
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return { return {
@@ -588,14 +622,9 @@ class ChildChunkAddApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser().add_argument( payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {})
"chunks", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try: try:
chunks_data = args["chunks"] child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200 return {"data": marshal(child_chunks, child_chunk_fields)}, 200
@@ -665,6 +694,7 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id, document_id, segment_id, child_chunk_id): def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@@ -711,13 +741,9 @@ class ChildChunkUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
args = parser.parse_args()
try: try:
content = args["content"] payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {})
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset) child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e: except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e)) raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200 return {"data": marshal(child_chunk, child_chunk_fields)}, 200

View File

@@ -1,8 +1,10 @@
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, reqparse from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@@ -71,10 +73,38 @@ except KeyError:
dataset_detail_model = _build_dataset_detail_model() dataset_detail_model = _build_dataset_detail_model()
def _validate_name(name: str) -> str: class ExternalKnowledgeApiPayload(BaseModel):
if not name or len(name) < 1 or len(name) > 100: name: str = Field(..., min_length=1, max_length=40)
raise ValueError("Name must be between 1 to 100 characters.") settings: dict[str, object]
return name
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
name: str = Field(..., min_length=1, max_length=40)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None
class ExternalHitTestingPayload(BaseModel):
query: str
external_retrieval_model: dict[str, object] | None = None
metadata_filtering_conditions: dict[str, object] | None = None
class BedrockRetrievalPayload(BaseModel):
retrieval_setting: dict[str, object]
query: str
knowledge_id: str
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
)
@console_ns.route("/datasets/external-knowledge-api") @console_ns.route("/datasets/external-knowledge-api")
@@ -113,28 +143,12 @@ class ExternalApiTemplateListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def post(self): def post(self):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = ( payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"]) ExternalDatasetService.validate_api_list(payload.settings)
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@@ -142,7 +156,7 @@ class ExternalApiTemplateListApi(Resource):
try: try:
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api( external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
tenant_id=current_tenant_id, user_id=current_user.id, args=args tenant_id=current_tenant_id, user_id=current_user.id, args=payload.model_dump()
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@@ -171,35 +185,19 @@ class ExternalApiTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id): def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id) external_knowledge_api_id = str(external_knowledge_api_id)
parser = ( payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() ExternalDatasetService.validate_api_list(payload.settings)
.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
user_id=current_user.id, user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id, external_knowledge_api_id=external_knowledge_api_id,
args=args, args=payload.model_dump(),
) )
return external_knowledge_api.to_dict(), 200 return external_knowledge_api.to_dict(), 200
@@ -240,17 +238,7 @@ class ExternalApiUseCheckApi(Resource):
class ExternalDatasetCreateApi(Resource): class ExternalDatasetCreateApi(Resource):
@console_ns.doc("create_external_dataset") @console_ns.doc("create_external_dataset")
@console_ns.doc(description="Create external knowledge dataset") @console_ns.doc(description="Create external knowledge dataset")
@console_ns.expect( @console_ns.expect(console_ns.models[ExternalDatasetCreatePayload.__name__])
console_ns.model(
"CreateExternalDatasetRequest",
{
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
"external_knowledge_id": fields.String(required=True, description="External knowledge ID"),
"name": fields.String(required=True, description="Dataset name"),
"description": fields.String(description="Dataset description"),
},
)
)
@console_ns.response(201, "External dataset created successfully", dataset_detail_model) @console_ns.response(201, "External dataset created successfully", dataset_detail_model)
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@console_ns.response(403, "Permission denied") @console_ns.response(403, "Permission denied")
@@ -261,22 +249,8 @@ class ExternalDatasetCreateApi(Resource):
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
parser = ( payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() args = payload.model_dump(exclude_none=True)
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
.add_argument("description", type=str, required=False, nullable=True, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@@ -299,16 +273,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.doc("test_external_knowledge_retrieval") @console_ns.doc("test_external_knowledge_retrieval")
@console_ns.doc(description="Test external knowledge retrieval for dataset") @console_ns.doc(description="Test external knowledge retrieval for dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[ExternalHitTestingPayload.__name__])
console_ns.model(
"ExternalHitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"external_retrieval_model": fields.Raw(description="External retrieval model configuration"),
},
)
)
@console_ns.response(200, "External hit testing completed successfully") @console_ns.response(200, "External hit testing completed successfully")
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@@ -327,23 +292,16 @@ class ExternalKnowledgeHitTestingApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = ( payload = ExternalHitTestingPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser() HitTestingService.hit_testing_args_check(payload.model_dump())
.add_argument("query", type=str, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
)
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try: try:
response = HitTestingService.external_retrieve( response = HitTestingService.external_retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=payload.query,
account=current_user, account=current_user,
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=payload.external_retrieval_model,
metadata_filtering_conditions=args["metadata_filtering_conditions"], metadata_filtering_conditions=payload.metadata_filtering_conditions,
) )
return response return response
@@ -356,33 +314,13 @@ class BedrockRetrievalApi(Resource):
# this api is only for internal testing # this api is only for internal testing
@console_ns.doc("bedrock_retrieval_test") @console_ns.doc("bedrock_retrieval_test")
@console_ns.doc(description="Bedrock retrieval test (internal use only)") @console_ns.doc(description="Bedrock retrieval test (internal use only)")
@console_ns.expect( @console_ns.expect(console_ns.models[BedrockRetrievalPayload.__name__])
console_ns.model(
"BedrockRetrievalTestRequest",
{
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
"query": fields.String(required=True, description="Query text"),
"knowledge_id": fields.String(required=True, description="Knowledge ID"),
},
)
)
@console_ns.response(200, "Bedrock retrieval test completed") @console_ns.response(200, "Bedrock retrieval test completed")
def post(self): def post(self):
parser = ( payload = BedrockRetrievalPayload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
.add_argument("knowledge_id", nullable=False, required=True, type=str)
)
args = parser.parse_args()
# Call the knowledge retrieval service # Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval( result = ExternalDatasetTestService.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"] payload.retrieval_setting, payload.query, payload.knowledge_id
) )
return result, 200 return result, 200
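The deleted `_validate_name` helper accepted 1 to 100 characters; the Pydantic replacements declare `min_length=1, max_length=40` instead, which both tightens the limit and moves the error message into Pydantic's structured output. A sketch of how a violation surfaces, keeping only the two constrained fields from the model above:

```python
from pydantic import BaseModel, Field, ValidationError


class ExternalDatasetCreatePayload(BaseModel):
    name: str = Field(..., min_length=1, max_length=40)
    description: str | None = Field(None, max_length=400)


try:
    ExternalDatasetCreatePayload(name="x" * 41)
except ValidationError as e:
    print(e.errors()[0]["msg"])  # String should have at most 40 characters
```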

View File

@@ -1,13 +1,17 @@
from flask_restx import Resource, fields from flask_restx import Resource
from controllers.console import console_ns from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from libs.login import login_required
from controllers.console.wraps import (
from .. import console_ns
from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from ..wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_rate_limit_check, cloud_edition_billing_rate_limit_check,
setup_required, setup_required,
) )
from libs.login import login_required
register_schema_model(console_ns, HitTestingPayload)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing") @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
@@ -15,17 +19,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval") @console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval") @console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect( @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
console_ns.model(
"HitTestingRequest",
{
"query": fields.String(required=True, description="Query text for testing"),
"retrieval_model": fields.Raw(description="Retrieval model configuration"),
"top_k": fields.Integer(description="Number of top results to return"),
"score_threshold": fields.Float(description="Score threshold for filtering results"),
},
)
)
@console_ns.response(200, "Hit testing completed successfully") @console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found") @console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @console_ns.response(400, "Invalid parameters")
@@ -37,7 +31,8 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str) dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args() payload = HitTestingPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
self.hit_testing_args_check(args) self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args) return self.perform_hit_testing(dataset, args)
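`model_dump(exclude_none=True)` matters here because `hit_testing_args_check` and `perform_hit_testing` still consume plain dicts: optional fields the client never sent are dropped entirely rather than passed as explicit `None`. A sketch using a subset of the `HitTestingPayload` fields defined in the shared base module below:

```python
from typing import Any

from pydantic import BaseModel, Field


class HitTestingPayload(BaseModel):
    query: str = Field(max_length=250)
    retrieval_model: dict[str, Any] | None = None
    attachment_ids: list[str] | None = None


payload = HitTestingPayload.model_validate({"query": "refund policy"})
print(payload.model_dump())
# {'query': 'refund policy', 'retrieval_model': None, 'attachment_ids': None}
print(payload.model_dump(exclude_none=True))
# {'query': 'refund policy'}
```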

View File

@@ -1,6 +1,8 @@
import logging import logging
from typing import Any
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
@@ -27,6 +29,13 @@ from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class HitTestingPayload(BaseModel):
query: str = Field(max_length=250)
retrieval_model: dict[str, Any] | None = None
external_retrieval_model: dict[str, Any] | None = None
attachment_ids: list[str] | None = None
class DatasetsHitTestingBase: class DatasetsHitTestingBase:
@staticmethod @staticmethod
def get_and_validate_dataset(dataset_id: str): def get_and_validate_dataset(dataset_id: str):
@@ -43,14 +52,15 @@ class DatasetsHitTestingBase:
return dataset return dataset
@staticmethod @staticmethod
def hit_testing_args_check(args): def hit_testing_args_check(args: dict[str, Any]):
HitTestingService.hit_testing_args_check(args) HitTestingService.hit_testing_args_check(args)
@staticmethod @staticmethod
def parse_args(): def parse_args():
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
.add_argument("query", type=str, location="json") .add_argument("query", type=str, required=False, location="json")
.add_argument("attachment_ids", type=list, required=False, location="json")
.add_argument("retrieval_model", type=dict, required=False, location="json") .add_argument("retrieval_model", type=dict, required=False, location="json")
.add_argument("external_retrieval_model", type=dict, required=False, location="json") .add_argument("external_retrieval_model", type=dict, required=False, location="json")
) )
@@ -62,10 +72,11 @@ class DatasetsHitTestingBase:
try: try:
response = HitTestingService.retrieve( response = HitTestingService.retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args.get("query"),
account=current_user, account=current_user,
retrieval_model=args["retrieval_model"], retrieval_model=args.get("retrieval_model"),
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args.get("external_retrieval_model"),
attachment_ids=args.get("attachment_ids"),
limit=10, limit=10,
) )
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}

View File

@@ -1,8 +1,10 @@
from typing import Literal from typing import Literal
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
@@ -15,6 +17,14 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(console_ns, MetadataArgs, MetadataOperationData)
register_schema_model(console_ns, MetadataUpdatePayload)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata") @console_ns.route("/datasets/<uuid:dataset_id>/metadata")
class DatasetMetadataCreateApi(Resource): class DatasetMetadataCreateApi(Resource):
@setup_required @setup_required
@@ -22,15 +32,10 @@ class DatasetMetadataCreateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = ( metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument("type", type=str, required=True, nullable=False, location="json")
.add_argument("name", type=str, required=True, nullable=False, location="json")
)
args = parser.parse_args()
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@@ -60,11 +65,11 @@ class DatasetMetadataApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@marshal_with(dataset_metadata_fields) @marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id): def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json") payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
args = parser.parse_args() name = payload.name
name = args["name"]
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@@ -131,6 +136,7 @@ class DocumentMetadataEditApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
def post(self, dataset_id): def post(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
@@ -139,11 +145,7 @@ class DocumentMetadataEditApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
parser = reqparse.RequestParser().add_argument( metadata_args = MetadataOperationData.model_validate(console_ns.payload or {})
"operation_data", type=list, required=True, nullable=False, location="json"
)
args = parser.parse_args()
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args) MetadataService.update_documents_metadata(dataset, metadata_args)
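Previously reqparse could only assert that `operation_data` was a list before `MetadataOperationData.model_validate(args)` ran; validating `console_ns.payload` directly checks the nested elements in the same step. A sketch with a hypothetical element shape, since `MetadataOperationData` lives in the knowledge entities module and its fields are not shown in this diff:

```python
from pydantic import BaseModel, ValidationError


class OperationItem(BaseModel):
    document_id: str
    metadata_list: list[dict[str, str]]


class MetadataOperationData(BaseModel):
    operation_data: list[OperationItem]  # hypothetical mirror of the real entity


try:
    MetadataOperationData.model_validate({"operation_data": [{"document_id": 123}]})
except ValidationError as e:
    print(len(e.errors()))  # 2: document_id is not a string, metadata_list is missing
```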

View File

@@ -1,20 +1,63 @@
from typing import Any
from flask import make_response, redirect, request from flask import make_response, redirect, request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
class DatasourceCredentialPayload(BaseModel):
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any]
class DatasourceCredentialDeletePayload(BaseModel):
credential_id: str
class DatasourceCredentialUpdatePayload(BaseModel):
credential_id: str
name: str | None = Field(default=None, max_length=100)
credentials: dict[str, Any] | None = None
class DatasourceCustomClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enable_oauth_custom_client: bool | None = None
class DatasourceDefaultPayload(BaseModel):
id: str
class DatasourceUpdateNamePayload(BaseModel):
credential_id: str
name: str = Field(max_length=100)
register_schema_models(
console_ns,
DatasourceCredentialPayload,
DatasourceCredentialDeletePayload,
DatasourceCredentialUpdatePayload,
DatasourceCustomClientPayload,
DatasourceDefaultPayload,
DatasourceUpdateNamePayload,
)
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url") @console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
class DatasourcePluginOAuthAuthorizationUrl(Resource): class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_datasource = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>") @console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource): class DatasourceAuth(Resource):
@console_ns.expect(parser_datasource) @console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_datasource.parse_args() payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
datasource_provider_service.add_datasource_api_key_provider( datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider_id=datasource_provider_id, provider_id=datasource_provider_id,
credentials=args["credentials"], credentials=payload.credentials,
name=args["name"], name=payload.name,
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
@@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
return {"result": datasources}, 200 return {"result": datasources}, 200
parser_datasource_delete = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource): class DatasourceAuthDeleteApi(Resource):
@console_ns.expect(parser_datasource_delete) @console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
plugin_id = datasource_provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name provider_name = datasource_provider_id.provider_name
args = parser_datasource_delete.parse_args() payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=payload.credential_id,
provider=provider_name, provider=provider_name,
plugin_id=plugin_id, plugin_id=plugin_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_datasource_update = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource): class DatasourceAuthUpdateApi(Resource):
@console_ns.expect(parser_datasource_update) @console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
args = parser_datasource_update.parse_args() payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
auth_id=args["credential_id"], auth_id=payload.credential_id,
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
credentials=args.get("credentials", {}), credentials=payload.credentials or {},
name=args.get("name", None), name=payload.name,
) )
return {"result": "success"}, 201 return {"result": "success"}, 201
@@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
return {"result": jsonable_encoder(datasources)}, 200 return {"result": jsonable_encoder(datasources)}, 200
parser_datasource_custom = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource): class DatasourceAuthOauthCustomClient(Resource):
@console_ns.expect(parser_datasource_custom) @console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_datasource_custom.parse_args() payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params( datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}), client_params=payload.client_params or {},
enabled=args.get("enable_oauth_custom_client", False), enabled=payload.enable_oauth_custom_client or False,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@console_ns.expect(parser_default) @console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_default.parse_args() payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider( datasource_provider_service.set_default_datasource_provider(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
credential_id=args["id"], credential_id=payload.id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
parser_update_name = (
reqparse.RequestParser()
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource): class DatasourceUpdateProviderNameApi(Resource):
@console_ns.expect(parser_update_name) @console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
def post(self, provider_id: str): def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
args = parser_update_name.parse_args() payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name( datasource_provider_service.update_datasource_provider_name(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
name=args["name"], name=payload.name,
credential_id=args["credential_id"], credential_id=payload.credential_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@@ -26,7 +26,7 @@ console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=D
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource): class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__], validate=True) @console_ns.expect(console_ns.models[Parser.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
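Dropping `validate=True` leaves `@console_ns.expect` as documentation only; enforcement moves to the handler's own `model_validate` call, consistent with the rest of this commit. A sketch of that handler-side check (route, decorators, and the real `Parser` fields are omitted or assumed):

```python
from typing import Any

from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest


class Parser(BaseModel):
    datasource_info: dict[str, Any]  # stand-in field; real schema not shown


def preview(payload: dict | None) -> Parser:
    try:
        # The handler raises its own 400 rather than relying on
        # flask-restx's @expect(..., validate=True) schema check.
        return Parser.model_validate(payload or {})
    except ValidationError as e:
        raise BadRequest(str(e))
```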

View File

@@ -1,9 +1,11 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, reqparse from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description: str) -> str:
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
@console_ns.route("/rag/pipeline/templates") @console_ns.route("/rag/pipeline/templates")
class PipelineTemplateListApi(Resource): class PipelineTemplateListApi(Resource):
@setup_required @setup_required
@@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
return pipeline_template, 200 return pipeline_template, 200
class Payload(BaseModel):
name: str = Field(..., min_length=1, max_length=40)
description: str = Field(default="", max_length=400)
icon_info: dict[str, object] | None = None
register_schema_models(console_ns, Payload)
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>") @console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
class CustomizedPipelineTemplateApi(Resource): class CustomizedPipelineTemplateApi(Resource):
@setup_required @setup_required
@@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def patch(self, template_id: str): def patch(self, template_id: str):
parser = ( payload = Payload.model_validate(console_ns.payload or {})
reqparse.RequestParser() pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200 return 200
@@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish") @console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
class PublishCustomizedPipelineTemplateApi(Resource): class PublishCustomizedPipelineTemplateApi(Resource):
@console_ns.expect(console_ns.models[Payload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@knowledge_pipeline_publish_enabled @knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str): def post(self, pipeline_id: str):
parser = ( payload = Payload.model_validate(console_ns.payload or {})
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
.add_argument(
"description",
type=_validate_description_length,
nullable=True,
required=False,
default="",
)
.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args) rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
return {"result": "success"} return {"result": "success"}

View File

@@ -1,8 +1,10 @@
from flask_restx import Resource, marshal, reqparse from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
from controllers.common.schema import register_schema_model
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineDatasetImportPayload(BaseModel):
yaml_content: str
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
@console_ns.route("/rag/pipeline/dataset") @console_ns.route("/rag/pipeline/dataset")
class CreateRagPipelineDatasetApi(Resource): class CreateRagPipelineDatasetApi(Resource):
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser().add_argument( payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
), ),
permission=DatasetPermissionEnum.ONLY_ME, permission=DatasetPermissionEnum.ONLY_ME,
partial_member_list=None, partial_member_list=None,
yaml_content=args["yaml_content"], yaml_content=payload.yaml_content,
) )
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:
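`register_schema_model` and the `console_ns.models[...]` lookup used by `@console_ns.expect` come from `controllers.common.schema`, whose implementation is not part of this diff. A hypothetical re-implementation, for illustration only, of what such a helper might do with flask-restx's `schema_model`:

```python
from flask_restx import Namespace
from pydantic import BaseModel


def register_schema_model_sketch(ns: Namespace, model_cls: type[BaseModel]) -> None:
    # Illustrative guess at the helper's job: publish the Pydantic JSON schema
    # under the class name so routes can reference ns.models[name] for Swagger.
    # Dify's actual helper may differ.
    ns.schema_model(model_cls.__name__, model_cls.model_json_schema())
```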

View File

@@ -1,11 +1,13 @@
 import logging
-from typing import NoReturn
+from typing import Any, NoReturn
 
-from flask import Response
-from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
+from flask import Response, request
+from flask_restx import Resource, fields, marshal, marshal_with
+from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import (
     DraftWorkflowNotExist,
@@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
 
 
 def _create_pagination_parser():
-    parser = (
-        reqparse.RequestParser()
-        .add_argument(
-            "page",
-            type=inputs.int_range(1, 100_000),
-            required=False,
-            default=1,
-            location="args",
-            help="the page of data requested",
-        )
-        .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
-    )
-    return parser
+    class PaginationQuery(BaseModel):
+        page: int = Field(default=1, ge=1, le=100_000)
+        limit: int = Field(default=20, ge=1, le=100)
+
+    register_schema_models(console_ns, PaginationQuery)
+
+    return PaginationQuery
+
+
+class WorkflowDraftVariablePatchPayload(BaseModel):
+    name: str | None = None
+    value: Any | None = None
+
+
+register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
 
 
 def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
@@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
         """
         Get draft workflow
         """
-        parser = _create_pagination_parser()
-        args = parser.parse_args()
+        pagination = _create_pagination_parser()
+        query = pagination.model_validate(request.args.to_dict())
 
         # fetch draft workflow by app_model
         rag_pipeline_service = RagPipelineService()
@@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
         )
 
         workflow_vars = draft_var_srv.list_variables_without_values(
             app_id=pipeline.id,
-            page=args.page,
-            limit=args.limit,
+            page=query.page,
+            limit=query.limit,
         )
         return workflow_vars
@@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
     @_api_prerequisite
     @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+    @console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
     def patch(self, pipeline: Pipeline, variable_id: str):
         # Request payload for file types:
         #
@@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
         #     "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
         #   }
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
-            .add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
-        )
         draft_var_srv = WorkflowDraftVariableService(
             session=db.session(),
         )
-        args = parser.parse_args(strict=True)
+        payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
 
         variable = draft_var_srv.get_variable(variable_id=variable_id)
         if variable is None:
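Query strings arrive as text, so the `PaginationQuery` pattern relies on Pydantic's lax coercion to do the job `inputs.int_range` used to do. A runnable sketch, with the model copied from the hunk above:

```python
from pydantic import BaseModel, Field, ValidationError


class PaginationQuery(BaseModel):
    page: int = Field(default=1, ge=1, le=100_000)
    limit: int = Field(default=20, ge=1, le=100)


# request.args.to_dict() yields strings; Pydantic coerces "2" -> 2 and
# enforces the ge/le bounds declaratively.
print(PaginationQuery.model_validate({"page": "2", "limit": "50"}))
try:
    PaginationQuery.model_validate({"limit": "500"})  # over le=100
except ValidationError as exc:
    print(exc.errors()[0]["msg"])
```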

View File

@@ -1,6 +1,9 @@
-from flask_restx import Resource, marshal_with, reqparse  # type: ignore
+from flask import request
+from flask_restx import Resource, marshal_with  # type: ignore
+from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.datasets.wraps import get_rag_pipeline
 from controllers.console.wraps import (
@@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
 from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
 
 
+class RagPipelineImportPayload(BaseModel):
+    mode: str
+    yaml_content: str | None = None
+    yaml_url: str | None = None
+    name: str | None = None
+    description: str | None = None
+    icon_type: str | None = None
+    icon: str | None = None
+    icon_background: str | None = None
+    pipeline_id: str | None = None
+
+
+class IncludeSecretQuery(BaseModel):
+    include_secret: str = Field(default="false")
+
+
+register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
+
+
 @console_ns.route("/rag/pipelines/imports")
 class RagPipelineImportApi(Resource):
     @setup_required
@@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
     @account_initialization_required
     @edit_permission_required
     @marshal_with(pipeline_import_fields)
+    @console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
     def post(self):
         # Check user role first
         current_user, _ = current_account_with_tenant()
-
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("mode", type=str, required=True, location="json")
-            .add_argument("yaml_content", type=str, location="json")
-            .add_argument("yaml_url", type=str, location="json")
-            .add_argument("name", type=str, location="json")
-            .add_argument("description", type=str, location="json")
-            .add_argument("icon_type", type=str, location="json")
-            .add_argument("icon", type=str, location="json")
-            .add_argument("icon_background", type=str, location="json")
-            .add_argument("pipeline_id", type=str, location="json")
-        )
-        args = parser.parse_args()
+        payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
 
         # Create service with session
         with Session(db.engine) as session:
@@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
             account = current_user
             result = import_service.import_rag_pipeline(
                 account=account,
-                import_mode=args["mode"],
-                yaml_content=args.get("yaml_content"),
-                yaml_url=args.get("yaml_url"),
-                pipeline_id=args.get("pipeline_id"),
-                dataset_name=args.get("name"),
+                import_mode=payload.mode,
+                yaml_content=payload.yaml_content,
+                yaml_url=payload.yaml_url,
+                pipeline_id=payload.pipeline_id,
+                dataset_name=payload.name,
             )
             session.commit()
@@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
     @edit_permission_required
     def get(self, pipeline: Pipeline):
         # Add include_secret params
-        parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
-        args = parser.parse_args()
+        query = IncludeSecretQuery.model_validate(request.args.to_dict())
 
         with Session(db.engine) as session:
             export_service = RagPipelineDslService(session)
             result = export_service.export_rag_pipeline_dsl(
-                pipeline=pipeline, include_secret=args["include_secret"] == "true"
+                pipeline=pipeline, include_secret=query.include_secret == "true"
             )
         return {"data": result}, 200

View File

@@ -1,14 +1,16 @@
 import json
 import logging
-from typing import cast
+from typing import Any, Literal, cast
+from uuid import UUID
 
 from flask import abort, request
-from flask_restx import Resource, inputs, marshal_with, reqparse  # type: ignore  # type: ignore
-from flask_restx.inputs import int_range  # type: ignore
+from flask_restx import Resource, marshal_with, reqparse  # type: ignore
+from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
 import services
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import (
     ConversationCompletedError,
@@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
     workflow_run_pagination_fields,
 )
 from libs import helper
-from libs.helper import TimestampField, uuid_value
+from libs.helper import TimestampField
 from libs.login import current_account_with_tenant, current_user, login_required
 from models import Account
 from models.dataset import Pipeline
@@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
 logger = logging.getLogger(__name__)
 
 
+class DraftWorkflowSyncPayload(BaseModel):
+    graph: dict[str, Any]
+    hash: str | None = None
+    environment_variables: list[dict[str, Any]] | None = None
+    conversation_variables: list[dict[str, Any]] | None = None
+    rag_pipeline_variables: list[dict[str, Any]] | None = None
+    features: dict[str, Any] | None = None
+
+
+class NodeRunPayload(BaseModel):
+    inputs: dict[str, Any] | None = None
+
+
+class NodeRunRequiredPayload(BaseModel):
+    inputs: dict[str, Any]
+
+
+class DatasourceNodeRunPayload(BaseModel):
+    inputs: dict[str, Any]
+    datasource_type: str
+    credential_id: str | None = None
+
+
+class DraftWorkflowRunPayload(BaseModel):
+    inputs: dict[str, Any]
+    datasource_type: str
+    datasource_info_list: list[dict[str, Any]]
+    start_node_id: str
+
+
+class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
+    is_preview: bool = False
+    response_mode: Literal["streaming", "blocking"] = "streaming"
+    original_document_id: str | None = None
+
+
+class DefaultBlockConfigQuery(BaseModel):
+    q: str | None = None
+
+
+class WorkflowListQuery(BaseModel):
+    page: int = Field(default=1, ge=1, le=99999)
+    limit: int = Field(default=10, ge=1, le=100)
+    user_id: str | None = None
+    named_only: bool = False
+
+
+class WorkflowUpdatePayload(BaseModel):
+    marked_name: str | None = Field(default=None, max_length=20)
+    marked_comment: str | None = Field(default=None, max_length=100)
+
+
+class NodeIdQuery(BaseModel):
+    node_id: str
+
+
+class WorkflowRunQuery(BaseModel):
+    last_id: UUID | None = None
+    limit: int = Field(default=20, ge=1, le=100)
+
+
+class DatasourceVariablesPayload(BaseModel):
+    datasource_type: str
+    datasource_info: dict[str, Any]
+    start_node_id: str
+    start_node_title: str
+
+
+register_schema_models(
+    console_ns,
+    DraftWorkflowSyncPayload,
+    NodeRunPayload,
+    NodeRunRequiredPayload,
+    DatasourceNodeRunPayload,
+    DraftWorkflowRunPayload,
+    PublishedWorkflowRunPayload,
+    DefaultBlockConfigQuery,
+    WorkflowListQuery,
+    WorkflowUpdatePayload,
+    NodeIdQuery,
+    WorkflowRunQuery,
+    DatasourceVariablesPayload,
+)
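`PublishedWorkflowRunPayload` picks up the shared fields by inheriting from `DraftWorkflowRunPayload`, where the two retired `RequestParser` chains had to repeat every `.add_argument(...)`. A condensed sketch — the instantiation values are made-up example data:

```python
from typing import Any, Literal

from pydantic import BaseModel


class DraftWorkflowRunPayload(BaseModel):
    inputs: dict[str, Any]
    datasource_type: str
    datasource_info_list: list[dict[str, Any]]
    start_node_id: str


class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
    # Only the published-run extras live here; the base fields are inherited.
    is_preview: bool = False
    response_mode: Literal["streaming", "blocking"] = "streaming"


p = PublishedWorkflowRunPayload.model_validate(
    {"inputs": {}, "datasource_type": "example", "datasource_info_list": [], "start_node_id": "n1"}
)
print(p.response_mode)  # streaming (default)
```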
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
class DraftRagPipelineApi(Resource): class DraftRagPipelineApi(Resource):
@setup_required @setup_required
@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type: if "application/json" in content_type:
parser = ( payload_dict = console_ns.payload or {}
reqparse.RequestParser()
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=False, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
data = json.loads(request.data.decode("utf-8")) data = json.loads(request.data.decode("utf-8"))
@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
if not isinstance(data.get("graph"), dict): if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict") raise ValueError("graph is not a dict")
args = { payload_dict = {
"graph": data.get("graph"), "graph": data.get("graph"),
"features": data.get("features"), "features": data.get("features"),
"hash": data.get("hash"), "hash": data.get("hash"),
@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
else: else:
abort(415) abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
try: try:
environment_variables_list = args.get("environment_variables") or [] environment_variables_list = payload.environment_variables or []
environment_variables = [ environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
] ]
conversation_variables_list = args.get("conversation_variables") or [] conversation_variables_list = payload.conversation_variables or []
conversation_variables = [ conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
] ]
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.sync_draft_workflow( workflow = rag_pipeline_service.sync_draft_workflow(
pipeline=pipeline, pipeline=pipeline,
graph=args["graph"], graph=payload.graph,
unique_hash=args.get("hash"), unique_hash=payload.hash,
account=current_user, account=current_user,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
rag_pipeline_variables=args.get("rag_pipeline_variables") or [], rag_pipeline_variables=payload.rag_pipeline_variables or [],
) )
except WorkflowHashNotEqualError: except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync() raise DraftWorkflowNotSync()
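Both content-type branches now funnel into one `model_validate` call, so JSON and `text/plain` bodies share a single set of validation rules. A framework-free sketch of the control flow — the function name and signature are invented for illustration:

```python
import json
from typing import Any

from pydantic import BaseModel


class DraftWorkflowSyncPayload(BaseModel):
    graph: dict[str, Any]
    hash: str | None = None


def parse_sync_request(
    content_type: str, body: bytes, json_payload: dict[str, Any] | None
) -> DraftWorkflowSyncPayload:
    # Each branch only assembles a plain dict; validation happens once at the end.
    if "application/json" in content_type:
        payload_dict = json_payload or {}
    elif "text/plain" in content_type:
        payload_dict = json.loads(body.decode("utf-8"))
    else:
        raise ValueError("unsupported content type")
    return DraftWorkflowSyncPayload.model_validate(payload_dict)


print(parse_sync_request("text/plain", b'{"graph": {}}', None).graph)  # {}
```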
@@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
         }
 
 
-parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
 class RagPipelineDraftRunIterationNodeApi(Resource):
-    @console_ns.expect(parser_run)
+    @console_ns.expect(console_ns.models[NodeRunPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
 
-        args = parser_run.parse_args()
+        payload = NodeRunPayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
 
         try:
             response = PipelineGenerateService.generate_single_iteration(
@@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
 class RagPipelineDraftRunLoopNodeApi(Resource):
-    @console_ns.expect(parser_run)
+    @console_ns.expect(console_ns.models[NodeRunPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
 
-        args = parser_run.parse_args()
+        payload = NodeRunPayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
 
         try:
             response = PipelineGenerateService.generate_single_loop(
@@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
             raise InternalServerError()
 
 
-parser_draft_run = (
-    reqparse.RequestParser()
-    .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-    .add_argument("datasource_type", type=str, required=True, location="json")
-    .add_argument("datasource_info_list", type=list, required=True, location="json")
-    .add_argument("start_node_id", type=str, required=True, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
 class DraftRagPipelineRunApi(Resource):
-    @console_ns.expect(parser_draft_run)
+    @console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
 
-        args = parser_draft_run.parse_args()
+        payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump()
 
         try:
             response = PipelineGenerateService.generate(
@@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
             raise InvokeRateLimitHttpError(ex.description)
 
 
-parser_published_run = (
-    reqparse.RequestParser()
-    .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-    .add_argument("datasource_type", type=str, required=True, location="json")
-    .add_argument("datasource_info_list", type=list, required=True, location="json")
-    .add_argument("start_node_id", type=str, required=True, location="json")
-    .add_argument("is_preview", type=bool, required=True, location="json", default=False)
-    .add_argument("response_mode", type=str, required=True, location="json", default="streaming")
-    .add_argument("original_document_id", type=str, required=False, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
 class PublishedRagPipelineRunApi(Resource):
-    @console_ns.expect(parser_published_run)
+    @console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
 
-        args = parser_published_run.parse_args()
+        payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
 
-        streaming = args["response_mode"] == "streaming"
+        streaming = payload.response_mode == "streaming"
 
         try:
             response = PipelineGenerateService.generate(
                 pipeline=pipeline,
                 user=current_user,
                 args=args,
-                invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
+                invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
                 streaming=streaming,
             )
@@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
     #
     #     return result
     #
-parser_rag_run = (
-    reqparse.RequestParser()
-    .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
-    .add_argument("datasource_type", type=str, required=True, location="json")
-    .add_argument("credential_id", type=str, required=False, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
 class RagPipelinePublishedDatasourceNodeRunApi(Resource):
-    @console_ns.expect(parser_rag_run)
+    @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
 
-        args = parser_rag_run.parse_args()
-
-        inputs = args.get("inputs")
-        if inputs is None:
-            raise ValueError("missing inputs")
-        datasource_type = args.get("datasource_type")
-        if datasource_type is None:
-            raise ValueError("missing datasource_type")
+        payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
 
         rag_pipeline_service = RagPipelineService()
         return helper.compact_generate_response(
@@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
                 rag_pipeline_service.run_datasource_workflow_node(
                     pipeline=pipeline,
                     node_id=node_id,
-                    user_inputs=inputs,
+                    user_inputs=payload.inputs,
                     account=current_user,
-                    datasource_type=datasource_type,
+                    datasource_type=payload.datasource_type,
                     is_published=False,
-                    credential_id=args.get("credential_id"),
+                    credential_id=payload.credential_id,
                 )
             )
         )
@@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
 class RagPipelineDraftDatasourceNodeRunApi(Resource):
-    @console_ns.expect(parser_rag_run)
+    @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
     @setup_required
     @login_required
     @edit_permission_required
@@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
 
-        args = parser_rag_run.parse_args()
-
-        inputs = args.get("inputs")
-        if inputs is None:
-            raise ValueError("missing inputs")
-        datasource_type = args.get("datasource_type")
-        if datasource_type is None:
-            raise ValueError("missing datasource_type")
+        payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
 
         rag_pipeline_service = RagPipelineService()
         return helper.compact_generate_response(
@@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
                 rag_pipeline_service.run_datasource_workflow_node(
                     pipeline=pipeline,
                     node_id=node_id,
-                    user_inputs=inputs,
+                    user_inputs=payload.inputs,
                     account=current_user,
-                    datasource_type=datasource_type,
+                    datasource_type=payload.datasource_type,
                     is_published=False,
-                    credential_id=args.get("credential_id"),
+                    credential_id=payload.credential_id,
                 )
             )
         )
 
 
-parser_run_api = reqparse.RequestParser().add_argument(
-    "inputs", type=dict, required=True, nullable=False, location="json"
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
 class RagPipelineDraftNodeRunApi(Resource):
-    @console_ns.expect(parser_run_api)
+    @console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
     @setup_required
     @login_required
     @edit_permission_required
@@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
         # The role of the current user in the ta table must be admin, owner, or editor
         current_user, _ = current_account_with_tenant()
 
-        args = parser_run_api.parse_args()
-
-        inputs = args.get("inputs")
-        if inputs == None:
-            raise ValueError("missing inputs")
+        payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
+        inputs = payload.inputs
 
         rag_pipeline_service = RagPipelineService()
         workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
@@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
         return rag_pipeline_service.get_default_block_configs()
 
 
-parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
 class DefaultRagPipelineBlockConfigApi(Resource):
-    @console_ns.expect(parser_default)
     @setup_required
     @login_required
     @account_initialization_required
@@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
         """
         Get default block config
         """
-        args = parser_default.parse_args()
-        q = args.get("q")
+        query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
 
         filters = None
-        if q:
+        if query.q:
             try:
-                filters = json.loads(args.get("q", ""))
+                filters = json.loads(query.q)
             except json.JSONDecodeError:
                 raise ValueError("Invalid filters")
 
@@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
         return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
 
 
-parser_wf = (
-    reqparse.RequestParser()
-    .add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
-    .add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
-    .add_argument("user_id", type=str, required=False, location="args")
-    .add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
 class PublishedAllRagPipelineApi(Resource):
-    @console_ns.expect(parser_wf)
     @setup_required
     @login_required
     @account_initialization_required
@@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
         """
         current_user, _ = current_account_with_tenant()
 
-        args = parser_wf.parse_args()
-        page = args["page"]
-        limit = args["limit"]
-        user_id = args.get("user_id")
-        named_only = args.get("named_only", False)
+        query = WorkflowListQuery.model_validate(request.args.to_dict())
+
+        page = query.page
+        limit = query.limit
+        user_id = query.user_id
+        named_only = query.named_only
 
         if user_id:
             if user_id != current_user.id:
                 raise Forbidden()
-            user_id = cast(str, user_id)
 
         rag_pipeline_service = RagPipelineService()
         with Session(db.engine) as session:
@@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
         }
 
 
-parser_wf_id = (
-    reqparse.RequestParser()
-    .add_argument("marked_name", type=str, required=False, location="json")
-    .add_argument("marked_comment", type=str, required=False, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
 class RagPipelineByIdApi(Resource):
-    @console_ns.expect(parser_wf_id)
     @setup_required
     @login_required
     @account_initialization_required
@@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
         # Check permission
         current_user, _ = current_account_with_tenant()
 
-        args = parser_wf_id.parse_args()
-
-        # Validate name and comment length
-        if args.marked_name and len(args.marked_name) > 20:
-            raise ValueError("Marked name cannot exceed 20 characters")
-        if args.marked_comment and len(args.marked_comment) > 100:
-            raise ValueError("Marked comment cannot exceed 100 characters")
-
-        # Prepare update data
-        update_data = {}
-        if args.get("marked_name") is not None:
-            update_data["marked_name"] = args["marked_name"]
-        if args.get("marked_comment") is not None:
-            update_data["marked_comment"] = args["marked_comment"]
+        payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
+        update_data = payload.model_dump(exclude_unset=True)
 
         if not update_data:
             return {"message": "No valid fields to update"}, 400
@@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
         return workflow
 
 
-parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
 class PublishedRagPipelineSecondStepApi(Resource):
-    @console_ns.expect(parser_parameters)
     @setup_required
     @login_required
     @account_initialization_required
@@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
         """
        Get second step parameters of rag pipeline
        """
-        args = parser_parameters.parse_args()
-        node_id = args.get("node_id")
-        if not node_id:
-            raise ValueError("Node ID is required")
+        query = NodeIdQuery.model_validate(request.args.to_dict())
+        node_id = query.node_id
 
         rag_pipeline_service = RagPipelineService()
         variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
         return {
@@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
 class PublishedRagPipelineFirstStepApi(Resource):
-    @console_ns.expect(parser_parameters)
     @setup_required
     @login_required
     @account_initialization_required
@@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
         """
        Get first step parameters of rag pipeline
        """
-        args = parser_parameters.parse_args()
-        node_id = args.get("node_id")
-        if not node_id:
-            raise ValueError("Node ID is required")
+        query = NodeIdQuery.model_validate(request.args.to_dict())
+        node_id = query.node_id
 
         rag_pipeline_service = RagPipelineService()
         variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
         return {
@@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
 class DraftRagPipelineFirstStepApi(Resource):
-    @console_ns.expect(parser_parameters)
     @setup_required
     @login_required
     @account_initialization_required
@@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
         """
        Get first step parameters of rag pipeline
        """
-        args = parser_parameters.parse_args()
-        node_id = args.get("node_id")
-        if not node_id:
-            raise ValueError("Node ID is required")
+        query = NodeIdQuery.model_validate(request.args.to_dict())
+        node_id = query.node_id
 
         rag_pipeline_service = RagPipelineService()
         variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
         return {
@@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
 class DraftRagPipelineSecondStepApi(Resource):
-    @console_ns.expect(parser_parameters)
     @setup_required
     @login_required
     @account_initialization_required
@@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
         """
        Get second step parameters of rag pipeline
        """
-        args = parser_parameters.parse_args()
-        node_id = args.get("node_id")
-        if not node_id:
-            raise ValueError("Node ID is required")
+        query = NodeIdQuery.model_validate(request.args.to_dict())
+        node_id = query.node_id
 
         rag_pipeline_service = RagPipelineService()
         variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
@@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
         }
 
 
-parser_wf_run = (
-    reqparse.RequestParser()
-    .add_argument("last_id", type=uuid_value, location="args")
-    .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
 class RagPipelineWorkflowRunListApi(Resource):
-    @console_ns.expect(parser_wf_run)
     @setup_required
     @login_required
     @account_initialization_required
@@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
         """
        Get workflow run list
        """
-        args = parser_wf_run.parse_args()
+        query = WorkflowRunQuery.model_validate(
+            {
+                "last_id": request.args.get("last_id"),
+                "limit": request.args.get("limit", type=int, default=20),
+            }
+        )
+        args = {
+            "last_id": str(query.last_id) if query.last_id else None,
+            "limit": query.limit,
+        }
 
         rag_pipeline_service = RagPipelineService()
         result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
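Typing `last_id` as `UUID | None` moves the format check that `libs.helper.uuid_value` used to perform into the model; the handler then converts back to `str` because the service layer still expects plain strings. Sketch, with the model copied from the diff:

```python
from uuid import UUID

from pydantic import BaseModel, Field


class WorkflowRunQuery(BaseModel):
    last_id: UUID | None = None
    limit: int = Field(default=20, ge=1, le=100)


# Pydantic parses the UUID string itself; malformed IDs raise ValidationError.
q = WorkflowRunQuery.model_validate({"last_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"})
print(str(q.last_id) if q.last_id else None)
```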
@@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
         return result
 
 
-parser_var = (
-    reqparse.RequestParser()
-    .add_argument("datasource_type", type=str, required=True, location="json")
-    .add_argument("datasource_info", type=dict, required=True, location="json")
-    .add_argument("start_node_id", type=str, required=True, location="json")
-    .add_argument("start_node_title", type=str, required=True, location="json")
-)
-
-
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
 class RagPipelineDatasourceVariableApi(Resource):
-    @console_ns.expect(parser_var)
+    @console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
     @setup_required
     @login_required
     @account_initialization_required
@@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
         Set datasource variables
         """
         current_user, _ = current_account_with_tenant()
-        args = parser_var.parse_args()
+        args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
 
         rag_pipeline_service = RagPipelineService()
         workflow_node_execution = rag_pipeline_service.set_datasource_variables(
@@ -1004,6 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("type", type=str, location="args", required=False, default="all")
+        args = parser.parse_args()
+        type = args["type"]
+
         rag_pipeline_service = RagPipelineService()
-        recommended_plugins = rag_pipeline_service.get_recommended_plugins()
+        recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
         return recommended_plugins

View File

@@ -1,5 +1,10 @@
-from flask_restx import Resource, fields, reqparse
+from typing import Literal
+
+from flask import request
+from flask_restx import Resource
+from pydantic import BaseModel
 
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.datasets.error import WebsiteCrawlError
 from controllers.console.wraps import account_initialization_required, setup_required
@@ -7,48 +12,35 @@ from libs.login import login_required
 from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
 
 
+class WebsiteCrawlPayload(BaseModel):
+    provider: Literal["firecrawl", "watercrawl", "jinareader"]
+    url: str
+    options: dict[str, object]
+
+
+class WebsiteCrawlStatusQuery(BaseModel):
+    provider: Literal["firecrawl", "watercrawl", "jinareader"]
+
+
+register_schema_models(console_ns, WebsiteCrawlPayload, WebsiteCrawlStatusQuery)
+
+
 @console_ns.route("/website/crawl")
 class WebsiteCrawlApi(Resource):
     @console_ns.doc("crawl_website")
     @console_ns.doc(description="Crawl website content")
-    @console_ns.expect(
-        console_ns.model(
-            "WebsiteCrawlRequest",
-            {
-                "provider": fields.String(
-                    required=True,
-                    description="Crawl provider (firecrawl/watercrawl/jinareader)",
-                    enum=["firecrawl", "watercrawl", "jinareader"],
-                ),
-                "url": fields.String(required=True, description="URL to crawl"),
-                "options": fields.Raw(required=True, description="Crawl options"),
-            },
-        )
-    )
+    @console_ns.expect(console_ns.models[WebsiteCrawlPayload.__name__])
     @console_ns.response(200, "Website crawl initiated successfully")
     @console_ns.response(400, "Invalid crawl parameters")
     @setup_required
     @login_required
     @account_initialization_required
     def post(self):
-        parser = (
-            reqparse.RequestParser()
-            .add_argument(
-                "provider",
-                type=str,
-                choices=["firecrawl", "watercrawl", "jinareader"],
-                required=True,
-                nullable=True,
-                location="json",
-            )
-            .add_argument("url", type=str, required=True, nullable=True, location="json")
-            .add_argument("options", type=dict, required=True, nullable=True, location="json")
-        )
-        args = parser.parse_args()
+        payload = WebsiteCrawlPayload.model_validate(console_ns.payload or {})
 
         # Create typed request and validate
         try:
-            api_request = WebsiteCrawlApiRequest.from_args(args)
+            api_request = WebsiteCrawlApiRequest.from_args(payload.model_dump())
         except ValueError as e:
             raise WebsiteCrawlError(str(e))
 
@@ -65,6 +57,7 @@ class WebsiteCrawlStatusApi(Resource):
     @console_ns.doc("get_crawl_status")
     @console_ns.doc(description="Get website crawl status")
     @console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
+    @console_ns.expect(console_ns.models[WebsiteCrawlStatusQuery.__name__])
     @console_ns.response(200, "Crawl status retrieved successfully")
     @console_ns.response(404, "Crawl job not found")
     @console_ns.response(400, "Invalid provider")
@@ -72,14 +65,11 @@ class WebsiteCrawlStatusApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, job_id: str):
-        parser = reqparse.RequestParser().add_argument(
-            "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
-        )
-        args = parser.parse_args()
+        args = WebsiteCrawlStatusQuery.model_validate(request.args.to_dict())
 
         # Create typed request and validate
         try:
-            api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
+            api_request = WebsiteCrawlStatusApiRequest.from_args(args.model_dump(), job_id)
         except ValueError as e:
             raise WebsiteCrawlError(str(e))
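`Literal[...]` replaces reqparse's `choices=[...]`: an unknown provider now fails validation before the handler body runs. Sketch:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class WebsiteCrawlStatusQuery(BaseModel):
    provider: Literal["firecrawl", "watercrawl", "jinareader"]


try:
    WebsiteCrawlStatusQuery.model_validate({"provider": "scrapy"})  # not in the enum
except ValidationError as exc:
    print(exc.errors()[0]["msg"])
```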

View File

@@ -1,9 +1,11 @@
 import logging
 
 from flask import request
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError
 
 import services
+from controllers.common.schema import register_schema_model
 from controllers.console.app.error import (
     AppUnavailableError,
     AudioTooLargeError,
@@ -31,6 +33,16 @@ from .. import console_ns
 logger = logging.getLogger(__name__)
 
 
+class TextToAudioPayload(BaseModel):
+    message_id: str | None = None
+    voice: str | None = None
+    text: str | None = None
+    streaming: bool | None = Field(default=None, description="Enable streaming response")
+
+
+register_schema_model(console_ns, TextToAudioPayload)
+
+
 @console_ns.route(
     "/installed-apps/<uuid:installed_app_id>/audio-to-text",
     endpoint="installed_app_audio",
@@ -76,23 +88,15 @@ class ChatAudioApi(InstalledAppResource):
     endpoint="installed_app_text",
 )
 class ChatTextApi(InstalledAppResource):
+    @console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
     def post(self, installed_app):
-        from flask_restx import reqparse
-
         app_model = installed_app.app
         try:
-            parser = (
-                reqparse.RequestParser()
-                .add_argument("message_id", type=str, required=False, location="json")
-                .add_argument("voice", type=str, location="json")
-                .add_argument("text", type=str, location="json")
-                .add_argument("streaming", type=bool, location="json")
-            )
-            args = parser.parse_args()
+            payload = TextToAudioPayload.model_validate(console_ns.payload or {})
 
-            message_id = args.get("message_id", None)
-            text = args.get("text", None)
-            voice = args.get("voice", None)
+            message_id = payload.message_id
+            text = payload.text
+            voice = payload.voice
             response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
             return response

View File

@@ -1,9 +1,12 @@
 import logging
+from typing import Any, Literal
+from uuid import UUID
 
-from flask_restx import reqparse
+from pydantic import BaseModel, Field, field_validator
 from werkzeug.exceptions import InternalServerError, NotFound
 
 import services
+from controllers.common.schema import register_schema_models
 from controllers.console.app.error import (
     AppUnavailableError,
     CompletionRequestError,
@@ -25,7 +28,6 @@ from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
 from libs import helper
 from libs.datetime_utils import naive_utc_now
-from libs.helper import uuid_value
 from libs.login import current_user
 from models import Account
 from models.model import AppMode
@@ -38,28 +40,56 @@ from .. import console_ns
 logger = logging.getLogger(__name__)
 
 
+class CompletionMessageExplorePayload(BaseModel):
+    inputs: dict[str, Any]
+    query: str = ""
+    files: list[dict[str, Any]] | None = None
+    response_mode: Literal["blocking", "streaming"] | None = None
+    retriever_from: str = Field(default="explore_app")
+
+
+class ChatMessagePayload(BaseModel):
+    inputs: dict[str, Any]
+    query: str
+    files: list[dict[str, Any]] | None = None
+    conversation_id: str | None = None
+    parent_message_id: str | None = None
+    retriever_from: str = Field(default="explore_app")
+
+    @field_validator("conversation_id", "parent_message_id", mode="before")
+    @classmethod
+    def normalize_uuid(cls, value: str | UUID | None) -> str | None:
+        """
+        Accept blank IDs and validate UUID format when provided.
+        """
+        if not value:
+            return None
+        try:
+            return helper.uuid_value(value)
+        except ValueError as exc:
+            raise ValueError("must be a valid UUID") from exc
+
+
+register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
+
+
 # define completion api for user
 @console_ns.route(
     "/installed-apps/<uuid:installed_app_id>/completion-messages",
     endpoint="installed_app_completion",
 )
 class CompletionApi(InstalledAppResource):
+    @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
     def post(self, installed_app):
         app_model = installed_app.app
         if app_model.mode != AppMode.COMPLETION:
             raise NotCompletionAppError()
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, required=True, location="json")
-            .add_argument("query", type=str, location="json", default="")
-            .add_argument("files", type=list, required=False, location="json")
-            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
-            .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
-        )
-        args = parser.parse_args()
+        payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
 
-        streaming = args["response_mode"] == "streaming"
+        streaming = payload.response_mode == "streaming"
         args["auto_generate_name"] = False
 
         installed_app.last_used_at = naive_utc_now()
@@ -123,22 +153,15 @@ class CompletionStopApi(InstalledAppResource):
     endpoint="installed_app_chat_completion",
 )
 class ChatApi(InstalledAppResource):
+    @console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
     def post(self, installed_app):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()
 
-        parser = (
-            reqparse.RequestParser()
-            .add_argument("inputs", type=dict, required=True, location="json")
-            .add_argument("query", type=str, required=True, location="json")
-            .add_argument("files", type=list, required=False, location="json")
-            .add_argument("conversation_id", type=uuid_value, location="json")
-            .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
-            .add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
-        )
-        args = parser.parse_args()
+        payload = ChatMessagePayload.model_validate(console_ns.payload or {})
+        args = payload.model_dump(exclude_none=True)
 
         args["auto_generate_name"] = False

View File

@ -1,14 +1,18 @@
from flask_restx import marshal_with, reqparse from typing import Any
from flask_restx.inputs import int_range
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value from libs.helper import UUIDStrOrEmpty
from libs.login import current_user from libs.login import current_user
from models import Account from models import Account
from models.model import AppMode from models.model import AppMode
@@ -19,29 +23,51 @@ from services.web_conversation_service import WebConversationService
 from .. import console_ns

+
+class ConversationListQuery(BaseModel):
+    last_id: UUIDStrOrEmpty | None = None
+    limit: int = Field(default=20, ge=1, le=100)
+    pinned: bool | None = None
+
+
+class ConversationRenamePayload(BaseModel):
+    name: str | None = None
+    auto_generate: bool = False
+
+    @model_validator(mode="after")
+    def validate_name_requirement(self):
+        if not self.auto_generate:
+            if self.name is None or not self.name.strip():
+                raise ValueError("name is required when auto_generate is false")
+        return self
+
+
+register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
+
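`register_schema_models` comes from `controllers.common.schema`, which this diff does not show. Judging from the inline `console_ns.schema_model(...)` call used for `RecommendedAppsQuery` later in this diff, it plausibly just registers each model's JSON Schema on the namespace, roughly:

```python
# Hypothetical reconstruction; the real helper lives in controllers/common/schema.py
# and may differ in details.
from flask_restx import Namespace
from pydantic import BaseModel

def register_schema_models(ns: Namespace, *models: type[BaseModel]) -> None:
    for model in models:
        # Same pattern as the manual RecommendedAppsQuery registration below.
        ns.schema_model(model.__name__, model.model_json_schema(ref_template="#/definitions/{model}"))
```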
 @console_ns.route(
     "/installed-apps/<uuid:installed_app_id>/conversations",
     endpoint="installed_app_conversations",
 )
 class ConversationListApi(InstalledAppResource):
     @marshal_with(conversation_infinite_scroll_pagination_fields)
+    @console_ns.expect(console_ns.models[ConversationListQuery.__name__])
     def get(self, installed_app):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("last_id", type=uuid_value, location="args")
-            .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
-            .add_argument("pinned", type=str, choices=["true", "false", None], location="args")
-        )
-        args = parser.parse_args()
-
-        pinned = None
-        if "pinned" in args and args["pinned"] is not None:
-            pinned = args["pinned"] == "true"
+        raw_args: dict[str, Any] = {
+            "last_id": request.args.get("last_id"),
+            "limit": request.args.get("limit", default=20, type=int),
+            "pinned": request.args.get("pinned"),
+        }
+        if raw_args["last_id"] is None:
+            raw_args["last_id"] = None
+        pinned_value = raw_args["pinned"]
+        if isinstance(pinned_value, str):
+            raw_args["pinned"] = pinned_value == "true"
+        args = ConversationListQuery.model_validate(raw_args)

         try:
             if not isinstance(current_user, Account):
@@ -51,10 +77,10 @@ class ConversationListApi(InstalledAppResource):
                 session=session,
                 app_model=app_model,
                 user=current_user,
-                last_id=args["last_id"],
-                limit=args["limit"],
+                last_id=str(args.last_id) if args.last_id else None,
+                limit=args.limit,
                 invoke_from=InvokeFrom.EXPLORE,
-                pinned=pinned,
+                pinned=args.pinned,
             )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -88,6 +114,7 @@ class ConversationApi(InstalledAppResource):
 )
 class ConversationRenameApi(InstalledAppResource):
     @marshal_with(simple_conversation_fields)
+    @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
     def post(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
@@ -96,18 +123,13 @@ class ConversationRenameApi(InstalledAppResource):
         conversation_id = str(c_id)

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("name", type=str, required=False, location="json")
-            .add_argument("auto_generate", type=bool, required=False, default=False, location="json")
-        )
-        args = parser.parse_args()
+        payload = ConversationRenamePayload.model_validate(console_ns.payload or {})

         try:
             if not isinstance(current_user, Account):
                 raise ValueError("current_user must be an Account instance")
             return ConversationService.rename(
-                app_model, conversation_id, current_user, args["name"], args["auto_generate"]
+                app_model, conversation_id, current_user, payload.name, payload.auto_generate
             )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")

View File

@@ -2,7 +2,8 @@ import logging
 from typing import Any

 from flask import request
-from flask_restx import Resource, inputs, marshal_with, reqparse
+from flask_restx import Resource, marshal_with
+from pydantic import BaseModel
 from sqlalchemy import and_, select
 from werkzeug.exceptions import BadRequest, Forbidden, NotFound

@@ -18,6 +19,15 @@ from services.account_service import TenantService
 from services.enterprise.enterprise_service import EnterpriseService
 from services.feature_service import FeatureService

+
+class InstalledAppCreatePayload(BaseModel):
+    app_id: str
+
+
+class InstalledAppUpdatePayload(BaseModel):
+    is_pinned: bool | None = None
+
+
 logger = logging.getLogger(__name__)
@@ -105,26 +115,25 @@ class InstalledAppsListApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("apps")
     def post(self):
-        parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
-        args = parser.parse_args()
+        payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})

-        recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
+        recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
         if recommended_app is None:
-            raise NotFound("App not found")
+            raise NotFound("Recommended app not found")

         _, current_tenant_id = current_account_with_tenant()
-        app = db.session.query(App).where(App.id == args["app_id"]).first()
+        app = db.session.query(App).where(App.id == payload.app_id).first()
         if app is None:
-            raise NotFound("App not found")
+            raise NotFound("App entity not found")

         if not app.is_public:
             raise Forbidden("You can't install a non-public app")

         installed_app = (
             db.session.query(InstalledApp)
-            .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
+            .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id))
             .first()
         )
@@ -133,7 +142,7 @@ class InstalledAppsListApi(Resource):
             recommended_app.install_count += 1

             new_installed_app = InstalledApp(
-                app_id=args["app_id"],
+                app_id=payload.app_id,
                 tenant_id=current_tenant_id,
                 app_owner_tenant_id=app.tenant_id,
                 is_pinned=False,
@@ -163,12 +172,11 @@ class InstalledAppApi(InstalledAppResource):
         return {"result": "success", "message": "App uninstalled successfully"}, 204

     def patch(self, installed_app):
-        parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
-        args = parser.parse_args()
+        payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})

         commit_args = False
-        if "is_pinned" in args:
-            installed_app.is_pinned = args["is_pinned"]
+        if payload.is_pinned is not None:
+            installed_app.is_pinned = payload.is_pinned
             commit_args = True

         if commit_args:
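Because `is_pinned` is declared `bool | None = None`, the new `payload.is_pinned is not None` guard distinguishes an omitted field from an explicit `false`, mirroring the old `"is_pinned" in args` check:

```python
from pydantic import BaseModel

class InstalledAppUpdatePayloadDemo(BaseModel):  # mirrors the model above
    is_pinned: bool | None = None

assert InstalledAppUpdatePayloadDemo.model_validate({}).is_pinned is None  # omitted: no commit
assert InstalledAppUpdatePayloadDemo.model_validate({"is_pinned": False}).is_pinned is False  # explicit False still commits
```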

View File

@@ -1,9 +1,12 @@
 import logging
+from typing import Literal

-from flask_restx import marshal_with, reqparse
-from flask_restx.inputs import int_range
+from flask import request
+from flask_restx import marshal_with
+from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError, NotFound

+from controllers.common.schema import register_schema_models
 from controllers.console.app.error import (
     AppMoreLikeThisDisabledError,
     CompletionRequestError,
@@ -22,7 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
 from core.model_runtime.errors.invoke import InvokeError
 from fields.message_fields import message_infinite_scroll_pagination_fields
 from libs import helper
-from libs.helper import uuid_value
+from libs.helper import UUIDStrOrEmpty
 from libs.login import current_account_with_tenant
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService
@@ -40,12 +43,31 @@ from .. import console_ns
 logger = logging.getLogger(__name__)

+
+class MessageListQuery(BaseModel):
+    conversation_id: UUIDStrOrEmpty
+    first_id: UUIDStrOrEmpty | None = None
+    limit: int = Field(default=20, ge=1, le=100)
+
+
+class MessageFeedbackPayload(BaseModel):
+    rating: Literal["like", "dislike"] | None = None
+    content: str | None = None
+
+
+class MoreLikeThisQuery(BaseModel):
+    response_mode: Literal["blocking", "streaming"]
+
+
+register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
+
 @console_ns.route(
     "/installed-apps/<uuid:installed_app_id>/messages",
     endpoint="installed_app_messages",
 )
 class MessageListApi(InstalledAppResource):
     @marshal_with(message_infinite_scroll_pagination_fields)
+    @console_ns.expect(console_ns.models[MessageListQuery.__name__])
     def get(self, installed_app):
         current_user, _ = current_account_with_tenant()
         app_model = installed_app.app
@@ -53,18 +75,15 @@ class MessageListApi(InstalledAppResource):
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
             raise NotChatAppError()

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("conversation_id", required=True, type=uuid_value, location="args")
-            .add_argument("first_id", type=uuid_value, location="args")
-            .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
-        )
-        args = parser.parse_args()
+        args = MessageListQuery.model_validate(request.args.to_dict())

         try:
             return MessageService.pagination_by_first_id(
-                app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
+                app_model,
+                current_user,
+                str(args.conversation_id),
+                str(args.first_id) if args.first_id else None,
+                args.limit,
             )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
@@ -77,26 +96,22 @@ class MessageListApi(InstalledAppResource):
     endpoint="installed_app_message_feedback",
 )
 class MessageFeedbackApi(InstalledAppResource):
+    @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
     def post(self, installed_app, message_id):
         current_user, _ = current_account_with_tenant()
         app_model = installed_app.app

         message_id = str(message_id)

-        parser = (
-            reqparse.RequestParser()
-            .add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
-            .add_argument("content", type=str, location="json")
-        )
-        args = parser.parse_args()
+        payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})

         try:
             MessageService.create_feedback(
                 app_model=app_model,
                 message_id=message_id,
                 user=current_user,
-                rating=args.get("rating"),
-                content=args.get("content"),
+                rating=payload.rating,
+                content=payload.content,
             )
         except MessageNotExistsError:
             raise NotFound("Message Not Exists.")
@@ -109,6 +124,7 @@ class MessageFeedbackApi(InstalledAppResource):
     endpoint="installed_app_more_like_this",
 )
 class MessageMoreLikeThisApi(InstalledAppResource):
+    @console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
     def get(self, installed_app, message_id):
         current_user, _ = current_account_with_tenant()
         app_model = installed_app.app
@@ -117,12 +133,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
         message_id = str(message_id)

-        parser = reqparse.RequestParser().add_argument(
-            "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
-        )
-        args = parser.parse_args()
-
-        streaming = args["response_mode"] == "streaming"
+        args = MoreLikeThisQuery.model_validate(request.args.to_dict())
+        streaming = args.response_mode == "streaming"

         try:
             response = AppGenerateService.generate_more_like_this(

View File

@@ -1,4 +1,6 @@
-from flask_restx import Resource, fields, marshal_with, reqparse
+from flask import request
+from flask_restx import Resource, fields, marshal_with
+from pydantic import BaseModel, Field

 from constants.languages import languages
 from controllers.console import console_ns
@@ -35,20 +37,26 @@ recommended_app_list_fields = {
 }

-parser_apps = reqparse.RequestParser().add_argument("language", type=str, location="args")
+class RecommendedAppsQuery(BaseModel):
+    language: str | None = Field(default=None)
+
+
+console_ns.schema_model(
+    RecommendedAppsQuery.__name__,
+    RecommendedAppsQuery.model_json_schema(ref_template="#/definitions/{model}"),
+)


 @console_ns.route("/explore/apps")
 class RecommendedAppListApi(Resource):
-    @console_ns.expect(parser_apps)
+    @console_ns.expect(console_ns.models[RecommendedAppsQuery.__name__])
     @login_required
     @account_initialization_required
     @marshal_with(recommended_app_list_fields)
     def get(self):
         # language args
-        args = parser_apps.parse_args()
-
-        language = args.get("language")
+        args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore
+        language = args.language
         if language and language in languages:
             language_prefix = language
         elif current_user and current_user.interface_language:
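`request.args.to_dict(flat=True)` collapses Flask's MultiDict to one value per key before validation. A minimal sketch of the round trip, assuming Flask and Pydantic are installed:

```python
from flask import Flask, request
from pydantic import BaseModel, Field

class RecommendedAppsQueryDemo(BaseModel):  # mirrors RecommendedAppsQuery above
    language: str | None = Field(default=None)

app = Flask(__name__)

with app.test_request_context("/explore/apps?language=ja-JP"):
    args = RecommendedAppsQueryDemo.model_validate(request.args.to_dict(flat=True))
    assert args.language == "ja-JP"

with app.test_request_context("/explore/apps"):
    args = RecommendedAppsQueryDemo.model_validate(request.args.to_dict(flat=True))
    assert args.language is None  # missing param falls back to the default
```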

Some files were not shown because too many files have changed in this diff.